# -*- coding: utf-8 -*-
"""
@file
@brief Grid benchmark.
"""

from time import perf_counter
from ..loghelper import noLOG
from .benchmark import BenchMark


class GridBenchMark(BenchMark):
13 """ 

14 Compares a couple of machine learning models. 

15 """ 

16 

    def __init__(self, name, datasets, clog=None, fLOG=noLOG, path_to_images=".",
                 cache_file=None, repetition=1, progressbar=None, **params):
        """
        @param      name            name of the test
        @param      datasets        list of dictionaries of dataframes
        @param      clog            see @see cl CustomLog or string
        @param      fLOG            logging function
        @param      path_to_images  path to images
        @param      cache_file      cache file
        @param      repetition      number of repetitions of the experiment (to get a confidence interval)
        @param      progressbar     relies on *tqdm*, example *tnrange*
        @param      params          extra parameters

        If *cache_file* is specified, the class stores the results of the
        method :meth:`bench <pyquickhelper.benchhelper.benchmark.GridBenchMark.bench>`.
        On a second run, it loads the cache and only runs the modified
        or new experiments (in *params_list*).

        Every dataset in *datasets* should be a dictionary with dataframes
        as values and the following keys:

        * ``'X'``: features
        * ``'Y'``: labels (optional)
        * ``'name'``: dataset name
        * ``'shortname'``: short dataset name
        """
        BenchMark.__init__(self, name=name, datasets=datasets, clog=clog,
                           fLOG=fLOG, path_to_images=path_to_images,
                           cache_file=cache_file, progressbar=progressbar,
                           **params)

        if not isinstance(datasets, list):
            raise TypeError("datasets must be a list")  # pragma: no cover
        for i, df in enumerate(datasets):
            if not isinstance(df, dict):
                raise TypeError(  # pragma: no cover
                    "Every dataset must be a dictionary, {0}th is not.".format(i))
            if "X" not in df:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} should contain key 'X'.".format(i))
            if "di" in df:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} should not contain key 'di'.".format(i))
            if "name" not in df:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} should contain key 'name'.".format(i))
            if "shortname" not in df:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} should contain key 'shortname'.".format(i))
        self._datasets = datasets
        self._repetition = repetition
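
    # Illustrative note (not from the original module): every entry of
    # *datasets* is a dictionary such as
    #   dict(name="random dataset", shortname="rnd", X=..., Y=...)
    # where 'X' (and the optional 'Y') hold dataframes and 'di' is a
    # reserved key set later by :meth:`run`.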

    def init_main(self):
        """
        Initialisation.
        """
        skip = {"X", "Y", "weight", "name", "shortname"}
        self.fLOG("[MlGridBenchmark.init] begin")
        self._datasets_info = []
        self._results = []
        for i, dd in enumerate(self._datasets):
            X = dd["X"]
            N = X.shape[0]
            Nc = X.shape[1]
            info = dict(Nrows=N, Nfeat=Nc)
            for k, v in dd.items():
                if k not in skip:
                    info[k] = v
            self.fLOG(
                "[MlGridBenchmark.init] dataset {0}: {1}".format(i, info))
            self._datasets_info.append(info)

        self.fLOG("[MlGridBenchmark.init] end")

    def init(self):
        """
        Skipped, the initialisation is done in :meth:`init_main`.
        """
        pass  # pragma: no cover

    def run(self, params_list):
        """
        Runs the benchmark: every set of parameters in *params_list*
        is run against every dataset.
        """
        self.init_main()
        self.fLOG("[MlGridBenchmark.bench] start")
        self.fLOG("[MlGridBenchmark.bench] number of datasets",
                  len(self._datasets))
        self.fLOG("[MlGridBenchmark.bench] number of experiments",
                  len(params_list))

        unique = set()
        for i, pars in enumerate(params_list):
            if "name" not in pars:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} must contain key 'name'.".format(i))
            if "shortname" not in pars:
                raise KeyError(  # pragma: no cover
                    "Dictionary {0} must contain key 'shortname'.".format(i))
            if pars["name"] in unique:
                raise ValueError(  # pragma: no cover
                    "'{0}' is duplicated.".format(pars["name"]))
            unique.add(pars["name"])
            if pars["shortname"] in unique:
                raise ValueError(  # pragma: no cover
                    "'{0}' is duplicated.".format(pars["shortname"]))
            unique.add(pars["shortname"])

        # Multiplies the experiments: one copy of every set of parameters
        # per dataset, the dataset index is stored in key 'di'.
        full_list = []
        for i in range(len(self._datasets)):
            for pars in params_list:
                pc = pars.copy()
                pc["di"] = i
                full_list.append(pc)

        # Runs the benchmark.
        res = BenchMark.run(self, full_list)

        self.fLOG("[MlGridBenchmark.bench] end")
        return res
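
    # Illustrative sketch (not from the original module): with two datasets
    # and params_list = [dict(name="A", shortname="a"),
    #                    dict(name="B", shortname="b")],
    # the expansion above produces four experiments:
    #   [{'name': 'A', 'shortname': 'a', 'di': 0},
    #    {'name': 'B', 'shortname': 'b', 'di': 0},
    #    {'name': 'A', 'shortname': 'a', 'di': 1},
    #    {'name': 'B', 'shortname': 'b', 'di': 1}]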

    def bench(self, **params):
        """
        Runs an experiment several times (*repetition* times);
        the parameter *di* is the index of the dataset to use.
        """
        if "di" not in params:
            raise KeyError(
                "key 'di' is missing from params")  # pragma: no cover
        results = []

        for iexp in range(self._repetition):

            di = params["di"]
            shortname_model = params["shortname"]
            name_model = params["name"]
            shortname_ds = self._datasets[di]["shortname"]
            name_ds = self._datasets[di]["name"]

            # preprocessing
            cl = perf_counter()
            ds, appe, pars = self.preprocess_dataset(di, **params)
            split = perf_counter() - cl

            # training
            cl = perf_counter()
            output = self.bench_experiment(ds, **pars)
            train = perf_counter() - cl

            # scoring
            cl = perf_counter()
            metrics, appe_ = self.predict_score_experiment(ds, output)
            test = perf_counter() - cl

            metrics["time_preproc"] = split
            metrics["time_train"] = train
            metrics["time_test"] = test
            metrics["_btry"] = "{0}-{1}".format(shortname_model, shortname_ds)
            metrics["_iexp"] = iexp
            metrics["model_name"] = name_model
            metrics["ds_name"] = name_ds
            appe.update(appe_)
            appe["_iexp"] = iexp
            metrics.update(appe)

            appe["_btry"] = metrics["_btry"]
            if "_i" in metrics:
                del metrics["_i"]  # pragma: no cover
            results.append((metrics, appe))

        return results

    def preprocess_dataset(self, dsi, **params):
        """
        Prepares a dataset before running the experiment,
        can be overloaded to split it into train and test sets;
        the default implementation returns the dataset unchanged.

        @param      dsi     dataset index
        @param      params  additional parameters
        @return             tuple (dataset, dictionary for metrics, parameters)
        """
        ds = self._datasets[dsi]
        appe = self._datasets_info[dsi].copy()
        params = params.copy()
        if "di" in params:
            del params["di"]
        return ds, appe, params

    def bench_experiment(self, info, **params):
        """
        Function to overload: runs one experiment
        (its duration is reported as ``time_train``).

        @param      info    dictionary with at least key ``'X'``
        @param      params  additional parameters
        @return             output of the experiment
        """
        raise NotImplementedError()  # pragma: no cover

    def predict_score_experiment(self, info, output, **params):
        """
        Function to overload: computes predictions and metrics
        (its duration is reported as ``time_test``).

        @param      info    dictionary with at least key ``'X'``
        @param      output  output of :meth:`bench_experiment`
        @param      params  additional parameters
        @return             output of the experiment, tuple of two dictionaries
        """
        raise NotImplementedError()  # pragma: no cover
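
# Illustrative usage sketch (not part of the original module): a subclass only
# has to overload ``bench_experiment`` and ``predict_score_experiment``.
# The "mean" model, the MAE metric and the random data below are assumptions
# made for the example; any estimator with a fit/predict interface would do.
#
#   import numpy
#   import pandas
#
#   class MeanBenchMark(GridBenchMark):
#       def bench_experiment(self, info, **params):
#           # "training": a constant model predicting the mean of Y
#           return float(info["Y"].mean())
#
#       def predict_score_experiment(self, info, output, **params):
#           # scoring: mean absolute error of the constant prediction
#           mae = float((info["Y"] - output).abs().mean())
#           return dict(mae=mae), dict(model="mean")
#
#   datasets = [dict(name="random dataset", shortname="rnd",
#                    X=pandas.DataFrame(numpy.random.rand(100, 3)),
#                    Y=pandas.Series(numpy.random.rand(100)))]
#   bench = MeanBenchMark("demo", datasets, repetition=2)
#   results = bench.run([dict(name="mean model", shortname="mean")])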