Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Helpers to test a model which follows :epkg:`scikit-learn` API. 

4""" 

5import copy 

6import pickle 

7import pprint 

8from unittest import TestCase 

9from io import BytesIO 

10from numpy import ndarray 

11from numpy.testing import assert_almost_equal 

12from pandas.testing import assert_frame_equal 

13from sklearn.base import BaseEstimator 

14from sklearn.model_selection import train_test_split 

15from sklearn.base import clone 

16from sklearn.pipeline import make_pipeline 

17from sklearn.model_selection import GridSearchCV 

18 

19 

def train_test_split_with_none(X, y=None, sample_weight=None, random_state=0):
    """
    Splits into train and test data even if they are None.

    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      random_state    random state, forwarded to
                                :epkg:`scikit-learn:model_selection:train_test_split`
    @return                     ``X_train, y_train, w_train, X_test, y_test, w_test``,
                                an input which is None produces None
                                for both its train and test parts

    Only the inputs which are not None are given to
    :epkg:`scikit-learn:model_selection:train_test_split`. Every piece
    of the result is then stored at the position of the input
    it comes from, so that weights are never returned as labels
    when *y* is None and *sample_weight* is not.
    """
    inputs = (X, y, sample_weight)
    # Positions of the inputs which really take part in the split.
    kept = [pos for pos, value in enumerate(inputs) if value is not None]
    res = train_test_split(
        *[inputs[pos] for pos in kept], random_state=random_state)
    trains = [None] * len(inputs)
    tests = [None] * len(inputs)
    # train_test_split returns (train, test) pairs in the order
    # of its arguments: pair number *rank* belongs to input *pos*.
    for rank, pos in enumerate(kept):
        trains[pos] = res[2 * rank]
        tests[pos] = res[2 * rank + 1]
    X_train, y_train, w_train = trains
    X_test, y_test, w_test = tests
    return X_train, y_train, w_train, X_test, y_test, w_test

44 

45 

def test_sklearn_pickle(fct_model, X, y=None, sample_weight=None, **kwargs):
    """
    Trains a model, serializes it with :epkg:`pickle`, restores it
    and checks the restored model returns the same outputs
    as the original one on the test data.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      kwargs          additional parameters given to
                                :epkg:`numpy:testing:assert_almost_equal`
                                (or to ``assert_frame_equal`` if the output
                                is not a numpy array)
    @return                     model, unpickled model

    :raises:
        AssertionError
    """
    (X_train, y_train, w_train,
     X_test, _unused_y, _unused_w) = train_test_split_with_none(
        X, y, sample_weight)
    model = fct_model()
    if y_train is not None or w_train is not None:
        try:
            model.fit(X_train, y_train, w_train)
        except TypeError:
            # Presumably fit does not accept weights: train without them.
            model.fit(X_train, y_train)
    else:
        model.fit(X_train)

    def compute_output(estimator):
        # Prediction if the estimator has one, transformed data otherwise.
        if hasattr(estimator, 'predict'):
            return estimator.predict(X_test)
        return estimator.transform(X_test)

    pred1 = compute_output(model)

    # Round trip through pickle, everything stays in memory.
    model2 = pickle.loads(pickle.dumps(model))
    pred2 = compute_output(model2)

    compare = assert_almost_equal if isinstance(
        pred1, ndarray) else assert_frame_equal
    compare(pred1, pred2, **kwargs)
    return model, model2

90 

91 

def _get_test_instance():
    """
    Creates the object used to run assertions.

    @return     an instance of ``ExtTestCase`` from *pyquickhelper*
                if this module can be imported, otherwise an instance
                of a local subclass of ``unittest.TestCase``
    """
    try:
        from pyquickhelper.pycode import ExtTestCase  # pylint: disable=C0415
    except ImportError:  # pragma: no cover

        class _ExtTestCase(TestCase):
            "simple test class with one more method"

            def assertIsInstance(self, inst, cltype):
                "checks that one instance is from one type"
                if isinstance(inst, cltype):
                    return
                raise AssertionError(
                    "Unexpected type {} != {}.".format(
                        type(inst), cltype))

        return _ExtTestCase()
    return ExtTestCase()

110 

111 

def test_sklearn_clone(fct_model, ext=None, copy_fitted=False):
    """
    Tests that a cloned model is similar to the original one.

    @param      fct_model       function which creates the model
    @param      ext             unit test class instance, a default one
                                is created if None, it is used for every
                                comparison including the comparison
                                of nested estimators
    @param      copy_fitted     copy fitted parameters as well
    @return                     model, cloned model

    The parameters returned by ``get_params(deep=True)`` must have
    the same names for both models. Values are then compared one by one:
    lists go through ``_assert_list_equal``, nested estimators
    are only compared if *copy_fitted* is True, anything else
    is compared with ``ext.assertEqual``.

    :raises:
        AssertionError
    """
    conv = fct_model()
    p1 = conv.get_params(deep=True)
    if copy_fitted:
        cloned = clone_with_fitted_parameters(conv)
    else:
        cloned = clone(conv)
    p2 = cloned.get_params(deep=True)
    if ext is None:
        ext = _get_test_instance()
    try:
        ext.assertEqual(set(p1), set(p2))
    except AssertionError as e:  # pragma no cover
        raise AssertionError(
            "Differences between\n----\n{0}\n----\n{1}".format(
                pprint.pformat(p1), pprint.pformat(p2))) from e

    for k in sorted(p1):
        v1, v2 = p1[k], p2[k]
        if isinstance(v1, BaseEstimator) and isinstance(v2, BaseEstimator):
            if copy_fitted:
                # The same test instance is used for nested estimators.
                assert_estimator_equal(v1, v2, ext)
        elif isinstance(v1, list) and isinstance(v2, list):
            _assert_list_equal(v1, v2, ext)
        else:
            try:
                ext.assertEqual(v1, v2)
            except AssertionError as e:  # pragma no cover
                # Keeps the original failure as the cause.
                raise AssertionError(
                    "Difference for key '{0}'\n==1 {1}\n==2 {2}".format(
                        k, v1, v2)) from e
    return conv, cloned

155 

156 

def _assert_list_equal(l1, l2, ext):
    """
    Checks that two lists have the same length and equal items.

    @param      l1      first list
    @param      l2      second list
    @param      ext     unit test class instance

    Two items which are both tuples are compared with
    ``_assert_tuple_equal``, any other pair with ``ext.assertEqual``.
    The function raises *AssertionError* if the lengths differ.
    """
    size1, size2 = len(l1), len(l2)
    if size1 != size2:
        raise AssertionError(  # pragma no cover
            "Lists have different length {0} != {1}".format(size1, size2))
    for left, right in zip(l1, l2):
        both_tuples = isinstance(left, tuple) and isinstance(right, tuple)
        if both_tuples:
            _assert_tuple_equal(left, right, ext)
            continue
        ext.assertEqual(left, right)

166 

167 

def _assert_dict_equal(a, b, ext):
    """
    Checks that two dictionaries hold the same keys and equal values.

    @param      a       first dictionary
    @param      b       second dictionary
    @param      ext     unit test class instance, only given to
                        ``assert_estimator_equal`` when both values
                        are estimators

    All differences are collected and reported in a single
    *AssertionError*. *TypeError* is raised if *a* or *b*
    is not a dictionary.
    """
    if not isinstance(a, dict):  # pragma no cover
        raise TypeError('a is not dict but {0}'.format(type(a)))
    if not isinstance(b, dict):  # pragma no cover
        raise TypeError('b is not dict but {0}'.format(type(b)))

    rows = []
    for key in sorted(b):
        if key not in a:
            rows.append("** Added key '{0}' in b".format(key))
            continue
        left, right = a[key], b[key]
        if isinstance(left, BaseEstimator) and isinstance(right, BaseEstimator):
            assert_estimator_equal(left, right, ext)
        elif left != right:
            rows.append(
                "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                    key, id(left), id(right), left, right))
    # Keys only present in the first dictionary.
    rows.extend("** Removed key '{0}' in a".format(key)
                for key in sorted(a) if key not in b)
    if rows:
        raise AssertionError(  # pragma: no cover
            "Dictionaries are different\n{0}".format('\n'.join(rows)))

190 

191 

def _assert_tuple_equal(t1, t2, ext):
    """
    Checks that two tuples have the same length and equal items.

    @param      t1      first tuple
    @param      t2      second tuple
    @param      ext     unit test class instance

    Two items which are both estimators are compared with
    ``assert_estimator_equal``, any other pair with ``ext.assertEqual``.
    The function raises *AssertionError* if the lengths differ.
    """
    if len(t1) != len(t2):  # pragma no cover
        # The message names tuples, the function does not compare lists.
        raise AssertionError(
            "Tuples have different length {0} != {1}".format(len(t1), len(t2)))
    for a, b in zip(t1, t2):
        if isinstance(a, BaseEstimator) and isinstance(b, BaseEstimator):
            assert_estimator_equal(a, b, ext)
        else:
            ext.assertEqual(a, b)

201 

202 

def assert_estimator_equal(esta, estb, ext=None):
    """
    Checks that two models are equal.

    @param      esta    first estimator
    @param      estb    second estimator
    @param      ext     unit test class, a default one is created if None

    The function checks each estimator is an instance of the class
    of the other one, compares the parameters returned by ``get_params``,
    then compares the attributes of *esta* whose name starts or ends
    with a single underscore. It finally checks that every attribute
    of *estb* ending with a single underscore exists in *esta*.
    The function raises an exception if the comparison fails.
    """
    def single_trailing(name):
        return name.endswith('_') and not name.endswith('__')

    def single_leading(name):
        return name.startswith('_') and not name.startswith('__')

    if ext is None:
        ext = _get_test_instance()
    ext.assertIsInstance(esta, estb.__class__)
    ext.assertIsInstance(estb, esta.__class__)
    _assert_dict_equal(esta.get_params(), estb.get_params(), ext)

    for att in esta.__dict__:
        if not (single_trailing(att) or single_leading(att)):
            continue
        if not hasattr(estb, att):  # pragma no cover
            raise AssertionError("Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                att, esta.__class__, sorted(esta.__dict__), sorted(estb.__dict__)))
        value_a = getattr(esta, att)
        value_b = getattr(estb, att)
        if isinstance(value_a, BaseEstimator):
            # Nested estimators are compared recursively.
            assert_estimator_equal(value_a, value_b, ext)
        else:
            ext.assertEqual(value_a, value_b)

    for att in estb.__dict__:
        if single_trailing(att) and not hasattr(esta, att):  # pragma no cover
            raise AssertionError("Missing fitted attribute\n==1 {}\n==2 {}".format(
                sorted(esta.__dict__), sorted(estb.__dict__)))

234 

235 

def test_sklearn_grid_search_cv(fct_model, X, y=None, sample_weight=None, **grid_params):
    """
    Creates a model, checks that a grid search works with it.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      grid_params     parameter to use to run the grid search.
    @return                     dictionary with results

    The model is wrapped into a pipeline with a single step named
    after the lower-cased class name of the model. Keys of *grid_params*
    are prefixed with this name to build the grid. Training weights,
    if any, are sent to the model as the fit parameter
    ``<step name>__sample_weight``: the third positional argument
    of ``GridSearchCV.fit`` is not the sample weights.

    :raises:
        AssertionError, ValueError if *grid_params* is empty
    """
    X_train, y_train, w_train, X_test, y_test, w_test = (
        train_test_split_with_none(X, y, sample_weight))
    model = fct_model()
    pipe = make_pipeline(model)
    # make_pipeline names a step after the lower-cased class name.
    name = model.__class__.__name__.lower()
    parameters = {name + "__" + k: v for k, v in grid_params.items()}
    if not parameters:
        raise ValueError(
            "Some parameters must be tested when running grid search.")
    clf = GridSearchCV(pipe, parameters)
    if y_train is None and w_train is None:
        clf.fit(X_train)
    elif w_train is None:
        clf.fit(X_train, y_train)
    else:
        # Weights are routed to the step of the pipeline holding the model.
        clf.fit(X_train, y_train, **{name + "__sample_weight": w_train})
    score = clf.score(X_test, y_test)
    ext = _get_test_instance()
    ext.assertIsInstance(score, float)
    return dict(model=clf, X_train=X_train, y_train=y_train, w_train=w_train,
                X_test=X_test, y_test=y_test, w_test=w_test, score=score)

271 

272 

def clone_with_fitted_parameters(est):
    """
    Clones an estimator with the fitted results.

    @param      est     estimator
    @return             cloned object

    Lists, tuples and dictionaries are rebuilt with every item cloned
    the same way. An estimator is cloned with *clone*, then its attributes
    are walked to bring the fitted results into the clone. Any other
    object is copied with ``copy.deepcopy``. *RuntimeError* is raised
    if an attribute shared by both estimators is callable, or if
    the clone misses an attribute whose name does not look like
    a fitted one (single leading or trailing underscore).
    """
    def looks_fitted(name):
        trailing = name.endswith('_') and not name.endswith('__')
        leading = name.startswith('_') and not name.startswith('__')
        return trailing or leading

    def adjust(obj1, obj2):
        both_lists = isinstance(obj1, list) and isinstance(obj2, list)
        both_tuples = isinstance(obj1, tuple) and isinstance(obj2, tuple)
        if both_lists or both_tuples:
            for a, b in zip(obj1, obj2):
                adjust(a, b)
            return
        if isinstance(obj1, dict) and isinstance(obj2, dict):
            # Keys are paired in iteration order.
            for a, b in zip(obj1, obj2):
                adjust(obj1[a], obj2[b])
            return
        if not (isinstance(obj1, BaseEstimator) and
                isinstance(obj2, BaseEstimator)):
            # Nothing to migrate for any other pair of objects.
            return
        for k in obj1.__dict__:
            if hasattr(obj2, k):
                v1 = getattr(obj1, k)
                if callable(v1):
                    raise RuntimeError(  # pragma: no cover
                        "Cannot migrate trained parameters for {}.".format(obj1))
                if isinstance(v1, BaseEstimator):
                    setattr(obj2, k, clone_with_fitted_parameters(v1))
                else:
                    adjust(v1, getattr(obj2, k))
            elif looks_fitted(k):
                # Only the original has this attribute: copy it over.
                setattr(obj2, k, clone_with_fitted_parameters(getattr(obj1, k)))
            else:
                raise RuntimeError(  # pragma: no cover
                    "Cloned object is missing '{0}' in {1}.".format(k, obj2))

    if isinstance(est, BaseEstimator):
        cloned = clone(est)
        adjust(est, cloned)
        return cloned
    if isinstance(est, list):
        return [clone_with_fitted_parameters(o) for o in est]
    if isinstance(est, tuple):
        return tuple(clone_with_fitted_parameters(o) for o in est)
    if isinstance(est, dict):
        return {k: clone_with_fitted_parameters(v) for k, v in est.items()}
    return copy.deepcopy(est)

322 return res