Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# -*- coding: utf-8 -*-
"""
@file
@brief About Machine Learning Benchmark, defines @see cl MlGridBenchMark.
"""
import os

import numpy
from sklearn.model_selection import train_test_split
from sklearn.base import ClusterMixin
from sklearn.metrics import silhouette_score
from pyquickhelper.loghelper import noLOG
from pyquickhelper.benchhelper import GridBenchMark
class MlGridBenchMark(GridBenchMark):
    """
    The class tests a list of models over a list of datasets.
    """

    def __init__(self, name, datasets, clog=None, fLOG=noLOG, path_to_images=".",
                 cache_file=None, progressbar=None, graphx=None, graphy=None,
                 **params):
        """
        @param      name            name of the test
        @param      datasets        list of dictionaries of dataframes
        @param      clog            see @see cl CustomLog or string
        @param      fLOG            logging function
        @param      params          extra parameters
        @param      path_to_images  path to images and intermediate results
        @param      cache_file      cache file
        @param      progressbar     relies on *tqdm*, example *tnrange*
        @param      graphx          list of variables to use as X axis
        @param      graphy          list of variables to use as Y axis

        If *cache_file* is specified, the class will store the results of the
        method :meth:`bench <pyquickhelper.benchhelper.benchmark.GridBenchMark.bench>`.
        On a second run, the function loads the cache
        and runs modified or new runs (in *param_list*).

        Each item of *datasets* should be a dictionary with dataframes as values
        with the following keys:

        * ``'X'``: features
        * ``'Y'``: labels (optional)
        """
        GridBenchMark.__init__(self, name=name, datasets=datasets, clog=clog, fLOG=fLOG,
                               path_to_images=path_to_images, cache_file=cache_file,
                               progressbar=progressbar, **params)
        # variables used by method graphs to build one graph per couple (x, y)
        self._xaxis = graphx
        self._yaxis = graphy

    def preprocess_dataset(self, dsi, **params):
        """
        Splits the dataset into train and test.

        @param      dsi         dataset index
        @param      params      additional parameters, *train_size* and
                                *test_size* are given to *train_test_split*
                                and removed from the returned parameters
        @return                 tuple *(train, test)*, dictionary for metrics,
                                remaining parameters

        Keys ``'X'``, ``'Y'``, ``'weight'``, ``'group'`` are split when they
        are present in the dataset. If the dataset holds a true value for
        key ``'no_split'``, the same dataset is returned as train and test.
        """
        ds, appe, params = GridBenchMark.preprocess_dataset(
            self, dsi, **params)

        if "no_split" in ds and ds["no_split"]:
            self.fLOG("[MlGridBenchMark.preprocess_dataset] no split")
            return (ds, ds), appe, params

        self.fLOG("[MlGridBenchMark.preprocess_dataset] split train test")
        names = [key for key in ("X", "Y", "weight", "group") if key in ds]
        if not names:
            raise ValueError(  # pragma: no cover
                "No dataframe or matrix was found.")
        mats = [ds[key] for key in names]

        # options consumed by train_test_split must not reach the next steps
        options = {key: params.pop(key)
                   for key in ("train_size", "test_size") if key in params}

        # train_test_split returns [a_train, a_test, b_train, b_test, ...]
        res = train_test_split(*mats, **options)
        train = dict(zip(names, res[0::2]))
        test = dict(zip(names, res[1::2]))

        self.fLOG("[MlGridBenchMark.preprocess_dataset] done")
        return (train, test), appe, params

    @staticmethod
    def _check_train_test(ds):
        """
        Checks *ds* is a tuple *(train, test)*.

        @param      ds      object to check
        @raise              TypeError if *ds* is not a tuple of two elements
        """
        # both conditions are required: a tuple AND exactly two elements
        if not isinstance(ds, tuple) or len(ds) != 2:
            raise TypeError(  # pragma: no cover
                "ds must a tuple with two dictionaries train, test")

    def bench_experiment(self, ds, **params):  # pylint: disable=W0237
        """
        Calls method *fit* on the train part of the dataset.

        @param      ds      tuple *(train, test)*
        @param      params  must contain key ``'model'``, a callable
                            with no argument which creates the model to train,
                            other parameters are sent to method *fit*
        @return             trained model
        @raise              TypeError if *ds* is not a tuple of two elements,
                            KeyError if ``'model'`` is missing
        """
        self._check_train_test(ds)
        if "model" not in params:
            raise KeyError(  # pragma: no cover
                "params must contains key 'model'")
        # we assume model is a function which creates a model
        model = params.pop("model")()
        return self.fit(ds[0], model, **params)

    def predict_score_experiment(self, ds, model, **params):  # pylint: disable=W0237
        """
        Calls method *score* on the test part of the dataset.

        @param      ds      tuple *(train, test)*
        @param      model   trained model
        @param      params  additional parameters sent to method *score*,
                            must not contain key ``'model'``
        @return             what method *score* returns
        @raise              TypeError if *ds* is not a tuple of two elements,
                            KeyError if ``'model'`` is in *params*
        """
        self._check_train_test(ds)
        if "model" in params:
            raise KeyError(  # pragma: no cover
                "params must not contains key 'model'")
        return self.score(ds[1], model, **params)

    def fit(self, ds, model, **params):
        """
        Trains a model.

        @param      ds          dictionary with the data to use for training,
                                key ``'X'`` is mandatory, keys ``'Y'`` and
                                ``'weight'`` are optional
        @param      model       model to train
        @param      params      additional parameters, only ``'train_params'``
                                is used, a dictionary sent to ``model.fit``
        @return                 trained model
        @raise                  KeyError if ``'X'`` is missing or
                                if ``'model'`` is in *params*
        """
        if "X" not in ds:
            raise KeyError(  # pragma: no cover
                "ds must contain key 'X'")
        if "model" in params:
            raise KeyError(  # pragma: no cover
                "params must not contain key 'model', this is the model to train")
        X = ds["X"]
        Y = ds.get("Y", None)
        weight = ds.get("weight", None)
        self.fLOG("[MlGridBenchMark.fit] fit", params)

        train_params = params.get("train_params", {})

        if weight is not None:
            # NOTE(review): scikit-learn estimators usually name this
            # argument ``sample_weight``, confirm the benchmarked models
            # accept ``weight``.
            model.fit(X=X, y=Y, weight=weight, **train_params)
        else:
            model.fit(X=X, y=Y, **train_params)
        self.fLOG("[MlGridBenchMark.fit] Done.")
        return model

    def score(self, ds, model, **params):
        """
        Scores a model.

        @param      ds          dictionary with the data to use for testing,
                                key ``'X'`` is mandatory, key ``'Y'`` is optional
        @param      model       trained model
        @param      params      unused
        @return                 dictionary of metrics, dictionary of
                                additional information (empty)
        @raise                  NotImplementedError if *ds* contains weights,
                                AttributeError if the model is a clustering
                                model and cluster labels cannot be obtained

        Metric ``'own_score'`` is added if the model has a method *score*,
        metric ``'silhouette'`` is added if the model is a *ClusterMixin*.
        """
        X = ds["X"]
        Y = ds.get("Y", None)

        if "weight" in ds:
            raise NotImplementedError(  # pragma: no cover
                "weight are not used yet")

        metrics = {}
        appe = {}

        if hasattr(model, "score"):
            metrics["own_score"] = model.score(X, Y)

        if isinstance(model, ClusterMixin):
            # add silhouette
            if hasattr(model, "predict"):
                ypred = model.predict(X)
            elif hasattr(model, "transform"):
                ypred = model.transform(X)
            elif hasattr(model, "labels_"):
                ypred = model.labels_
            else:
                # explicit failure instead of an unbound local variable
                raise AttributeError(  # pragma: no cover
                    "Unable to get cluster labels, the model has none of "
                    "'predict', 'transform', 'labels_'.")
            if len(ypred.shape) > 1 and ypred.shape[1] > 1:
                # one column per cluster, keeps the index of the highest value
                ypred = numpy.argmax(ypred, axis=1)
            metrics["silhouette"] = silhouette_score(X, ypred)

        return metrics, appe

    def end(self):
        """
        nothing to do
        """

    def graphs(self, path_to_images):
        """
        Plots multiples graphs, one per couple *(x, y)* taken from
        *graphx* and *graphy* given to the constructor.

        @param      path_to_images      where to store images, if None,
                                        images are not saved
        @return                         list of *LocalGraph*, empty if
                                        *graphx* or *graphy* is None
        """
        import matplotlib.pyplot as plt  # pylint: disable=C0415
        import matplotlib.cm as mcm  # pylint: disable=C0415
        df = self.to_df()

        def local_graph(vx, vy, ax=None, text=True, figsize=(5, 5)):
            # scatter plot of vy against vx, one color per value of _btry
            btrys = set(df["_btry"])
            if len(btrys) == 0:
                raise ValueError("The benchmark is empty.")  # pragma: no cover
            ymin = df[vy].min()
            ymax = df[vy].max()
            # vertical shift applied to the text placed above the mean point
            decy = (ymax - ymin) / 50
            colors = mcm.rainbow(numpy.linspace(0, 1, len(btrys)))
            if ax is None:
                _, ax = plt.subplots(1, 1, figsize=figsize)  # pragma: no cover
                ax.grid(True)  # pragma: no cover
            for i, btry in enumerate(sorted(btrys)):
                subset = df[df["_btry"] == btry]
                if subset.shape[0] > 0:
                    tx = subset[vx].mean()
                    ty = subset[vy].mean()
                    if not numpy.isnan(tx) and not numpy.isnan(ty):
                        subset.plot(x=vx, y=vy, kind="scatter",
                                    label=btry, ax=ax, color=colors[i])
                        if text:
                            ax.text(tx, ty + decy, btry, size='small',
                                    color=colors[i], ha='center', va='bottom')
            ax.set_xlabel(vx)
            ax.set_ylabel(vy)
            return ax

        res = []
        if self._xaxis is not None and self._yaxis is not None:
            for vx in self._xaxis:
                for vy in self._yaxis:
                    self.fLOG("Plotting {0} x {1}".format(vx, vy))

                    # vx, vy are bound as default values to avoid
                    # the late binding of the loop variables
                    def func_graph(ax=None, text=True, vx=vx, vy=vy, **kwargs):
                        return local_graph(vx, vy, ax=ax, text=text, **kwargs)

                    if path_to_images is not None:
                        img = os.path.join(
                            path_to_images, "img-{0}-{1}x{2}.png".format(self.Name, vx, vy))
                        gr = self.LocalGraph(
                            func_graph, img, root=path_to_images)
                        self.fLOG("Saving '{0}'".format(img))
                        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
                        gr.plot(ax=ax, text=True)
                        fig.savefig(img)
                        self.fLOG("Done")
                        res.append(gr)
                        # only closes the figure created here,
                        # figures owned by the caller stay open
                        plt.close(fig)
                    else:
                        gr = self.LocalGraph(func_graph)
                        res.append(gr)
        return res

    def plot_graphs(self, grid=None, text=True, **kwargs):
        """
        Plots all graphs in the same graphs.

        @param      grid        grid of axes, if None, the method creates
                                a grid with two columns, a 1D array of axes
                                is considered as a single row
        @param      text        add legend title on the graph
        @param      kwargs      additional parameters sent to every graph,
                                *figsize* is used to create the grid
                                if *grid* is None
        @return                 grid
        @raise                  ValueError if there is no graph to plot or
                                if *grid* has not enough axes
        """
        nb = len(self.Graphs)
        if nb == 0:
            raise ValueError("No graph to plot.")  # pragma: no cover

        if grid is None:
            import matplotlib.pyplot as plt  # pylint: disable=C0415
            nrows, ncols = (nb + 1) // 2, 2
            fg = kwargs.pop('figsize', (5 * nrows, 10))
            # squeeze=False keeps a 2D array of axes even with a single row,
            # otherwise grid[y, x] fails for one or two graphs
            _, grid = plt.subplots(nrows, ncols, figsize=fg, squeeze=False)
            axes = grid
        else:
            axes = numpy.atleast_2d(grid)
            if axes.shape[0] * axes.shape[1] < nb:
                raise ValueError(  # pragma: no cover
                    "The graph is not big enough {0} < {1}".format(axes.shape, nb))

        ncols = axes.shape[1]
        for i, gr in enumerate(self.Graphs):
            self.fLOG("Plot graph {0}/{1}".format(i + 1, nb))
            # graphs fill the grid row by row
            y, x = divmod(i, ncols)
            gr.plot(ax=axes[y, x], text=text, **kwargs)
        return grid