Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to test a model which follows :epkg:`scikit-learn` API.
4"""
5import copy
6import pickle
7import pprint
8from unittest import TestCase
9from io import BytesIO
10from numpy import ndarray
11from numpy.testing import assert_almost_equal
12from pandas.testing import assert_frame_equal
13from sklearn.base import BaseEstimator
14from sklearn.model_selection import train_test_split
15from sklearn.base import clone
16from sklearn.pipeline import make_pipeline
17from sklearn.model_selection import GridSearchCV
def train_test_split_with_none(X, y=None, sample_weight=None, random_state=0):
    """
    Splits into train and test data even if they are None.

    @param      X               X
    @param      y               y (may be None)
    @param      sample_weight   sample weight (may be None)
    @param      random_state    random state given to
                                :epkg:`scikit-learn:model_selection:train_test_split`
    @return                     ``X_train, y_train, w_train, X_test, y_test, w_test``,
                                train and test parts are None when the
                                corresponding input is None

    Only the inputs which are not None are given to *train_test_split*.
    Every resulting (train, test) pair is then stored at the position of
    its input, so a missing *y* never moves *sample_weight* into *y*.
    """
    inputs = [X, y, sample_weight]
    not_none = [value for value in inputs if value is not None]
    # train_test_split returns train1, test1, train2, test2, ... following
    # the order of the arrays it receives.
    parts = iter(train_test_split(*not_none, random_state=random_state))
    trains = []
    tests = []
    for value in inputs:
        if value is None:
            trains.append(None)
            tests.append(None)
        else:
            trains.append(next(parts))
            tests.append(next(parts))
    X_train, y_train, w_train = trains
    X_test, y_test, w_test = tests
    return X_train, y_train, w_train, X_test, y_test, w_test
def test_sklearn_pickle(fct_model, X, y=None, sample_weight=None, **kwargs):
    """
    Creates a model, fit, predict and check the prediction
    are similar after the model was pickled, unpickled.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      kwargs          additional parameters for :epkg:`numpy:testing:assert_almost_equal`
    @return                     model, unpickled model

    The model is fitted without *y* when both *y* and *sample_weight* are
    None. When weights are given but the model's *fit* raises a
    *TypeError*, it is fitted again without weights.
    Models without a *predict* method are compared through *transform*.

    :raises:
        AssertionError
    """
    X_train, y_train, w_train, X_test, _, __ = train_test_split_with_none(
        X, y, sample_weight)
    model = fct_model()
    if y_train is None and w_train is None:
        model.fit(X_train)
    elif w_train is None:
        model.fit(X_train, y_train)
    else:
        try:
            model.fit(X_train, y_train, w_train)
        except TypeError:
            # Do not accept weights?
            model.fit(X_train, y_train)

    def _predict(mod):
        # Transformers have no predict method.
        if hasattr(mod, 'predict'):
            return mod.predict(X_test)
        return mod.transform(X_test)

    # The original model predicts before being serialized.
    pred1 = _predict(model)
    model2 = pickle.loads(pickle.dumps(model))
    pred2 = _predict(model2)
    if isinstance(pred1, ndarray):
        assert_almost_equal(pred1, pred2, **kwargs)
    else:
        # NOTE(review): any non-ndarray output is assumed to be a DataFrame.
        assert_frame_equal(pred1, pred2, **kwargs)
    return model, model2
92def _get_test_instance():
93 try:
94 from pyquickhelper.pycode import ExtTestCase # pylint: disable=C0415
95 cls = ExtTestCase
96 except ImportError: # pragma: no cover
98 class _ExtTestCase(TestCase):
99 "simple test classe with a more methods"
101 def assertIsInstance(self, inst, cltype):
102 "checks that one instance is from one type"
103 if not isinstance(inst, cltype):
104 raise AssertionError(
105 "Unexpected type {} != {}.".format(
106 type(inst), cltype))
108 cls = _ExtTestCase
109 return cls()
def test_sklearn_clone(fct_model, ext=None, copy_fitted=False):
    """
    Tests that a cloned model is similar to the original one.

    @param      fct_model       function which creates the model
    @param      ext             unit test class instance
    @param      copy_fitted     copy fitted parameters as well
    @return                     model, cloned model

    :raises:
        AssertionError
    """
    conv = fct_model()
    p1 = conv.get_params(deep=True)
    if copy_fitted:
        cloned = clone_with_fitted_parameters(conv)
    else:
        cloned = clone(conv)
    p2 = cloned.get_params(deep=True)
    if ext is None:
        ext = _get_test_instance()
    try:
        ext.assertEqual(set(p1), set(p2))
    except AssertionError as e:  # pragma no cover
        raise AssertionError(
            "Differences between\n----\n{0}\n----\n{1}".format(
                pprint.pformat(p1), pprint.pformat(p2))) from e

    for k in sorted(p1):
        v1, v2 = p1[k], p2[k]
        if isinstance(v1, BaseEstimator) and isinstance(v2, BaseEstimator):
            # Cloned sub-estimators are new instances, only compared
            # when fitted parameters were copied.
            if copy_fitted:
                assert_estimator_equal(v1, v2, ext)
        elif isinstance(v1, list) and isinstance(v2, list):
            _assert_list_equal(v1, v2, ext)
        else:
            try:
                ext.assertEqual(v1, v2)
            except AssertionError as e:  # pragma no cover
                raise AssertionError(
                    "Difference for key '{0}'\n==1 {1}\n==2 {2}".format(
                        k, v1, v2)) from e
    return conv, cloned
157def _assert_list_equal(l1, l2, ext):
158 if len(l1) != len(l2):
159 raise AssertionError( # pragma no cover
160 "Lists have different length {0} != {1}".format(len(l1), len(l2)))
161 for a, b in zip(l1, l2):
162 if isinstance(a, tuple) and isinstance(b, tuple):
163 _assert_tuple_equal(a, b, ext)
164 else:
165 ext.assertEqual(a, b)
168def _assert_dict_equal(a, b, ext):
169 if not isinstance(a, dict): # pragma no cover
170 raise TypeError('a is not dict but {0}'.format(type(a)))
171 if not isinstance(b, dict): # pragma no cover
172 raise TypeError('b is not dict but {0}'.format(type(b)))
173 rows = []
174 for key in sorted(b):
175 if key not in a:
176 rows.append("** Added key '{0}' in b".format(key))
177 elif isinstance(a[key], BaseEstimator) and isinstance(b[key], BaseEstimator):
178 assert_estimator_equal(a[key], b[key], ext)
179 else:
180 if a[key] != b[key]:
181 rows.append(
182 "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
183 key, id(a[key]), id(b[key]), a[key], b[key]))
184 for key in sorted(a):
185 if key not in b:
186 rows.append("** Removed key '{0}' in a".format(key))
187 if len(rows) > 0:
188 raise AssertionError( # pragma: no cover
189 "Dictionaries are different\n{0}".format('\n'.join(rows)))
192def _assert_tuple_equal(t1, t2, ext):
193 if len(t1) != len(t2): # pragma no cover
194 raise AssertionError(
195 "Lists have different length {0} != {1}".format(len(t1), len(t2)))
196 for a, b in zip(t1, t2):
197 if isinstance(a, BaseEstimator) and isinstance(b, BaseEstimator):
198 assert_estimator_equal(a, b, ext)
199 else:
200 ext.assertEqual(a, b)
def assert_estimator_equal(esta, estb, ext=None):
    """
    Checks that two models are equal.

    @param      esta        first estimator
    @param      estb        second estimator
    @param      ext         unit test class

    The function raises an exception if the comparison fails.
    Both estimators must have the same type and the same parameters.
    Every attribute whose name ends with ``_`` or starts with ``_``
    (names starting or ending with ``__`` excluded) must exist in both
    estimators. Attributes of *esta* must be equal to those of *estb*.
    Sub-estimators are compared recursively and numpy arrays with
    :epkg:`numpy:testing:assert_array_equal`.
    """
    from numpy.testing import assert_array_equal  # pylint: disable=C0415

    def _is_checked(att):
        # Trailing '_' or leading '_' attributes, double underscores excluded.
        return ((att.endswith('_') and not att.endswith('__')) or
                (att.startswith('_') and not att.startswith('__')))

    if ext is None:
        ext = _get_test_instance()
    ext.assertIsInstance(esta, estb.__class__)
    ext.assertIsInstance(estb, esta.__class__)
    _assert_dict_equal(esta.get_params(), estb.get_params(), ext)
    for att in esta.__dict__:
        if not _is_checked(att):
            continue
        if not hasattr(estb, att):  # pragma no cover
            raise AssertionError("Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                att, esta.__class__, list(sorted(esta.__dict__)), list(sorted(estb.__dict__))))
        va = getattr(esta, att)
        vb = getattr(estb, att)
        if isinstance(va, BaseEstimator):
            assert_estimator_equal(va, vb, ext)
        elif isinstance(va, ndarray) and isinstance(vb, ndarray):
            # unittest's assertEqual cannot compare arrays (ambiguous truth value).
            assert_array_equal(va, vb)
        else:
            ext.assertEqual(va, vb)
    for att in estb.__dict__:
        if _is_checked(att) and not hasattr(esta, att):  # pragma no cover
            raise AssertionError("Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                att, estb.__class__, list(sorted(esta.__dict__)), list(sorted(estb.__dict__))))
def test_sklearn_grid_search_cv(fct_model, X, y=None, sample_weight=None, **grid_params):
    """
    Creates a model, checks that a grid search works with it.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      grid_params     parameter to use to run the grid search.
    @return                     dictionary with results

    The model is wrapped into a pipeline. Sample weights, if any, are
    given to the model's *fit* through the pipeline parameter
    ``<step name>__sample_weight``.

    :raises:
        AssertionError, ValueError if *grid_params* is empty
    """
    X_train, y_train, w_train, X_test, y_test, w_test = (
        train_test_split_with_none(X, y, sample_weight))
    model = fct_model()
    pipe = make_pipeline(model)
    # Step name used by make_pipeline: the lower-cased class name.
    name = model.__class__.__name__.lower()
    parameters = {name + "__" + k: v for k, v in grid_params.items()}
    if len(parameters) == 0:
        raise ValueError(
            "Some parameters must be tested when running grid search.")
    clf = GridSearchCV(pipe, parameters)
    if y_train is None and w_train is None:
        clf.fit(X_train)
    elif w_train is None:
        clf.fit(X_train, y_train)
    else:
        # The third positional argument of GridSearchCV.fit is *groups*,
        # weights must be routed to the pipeline step as a fit parameter.
        clf.fit(X_train, y_train, **{name + "__sample_weight": w_train})
    score = clf.score(X_test, y_test)
    ext = _get_test_instance()
    ext.assertIsInstance(score, float)
    return dict(model=clf, X_train=X_train, y_train=y_train, w_train=w_train,
                X_test=X_test, y_test=y_test, w_test=w_test, score=score)
def clone_with_fitted_parameters(est):
    """
    Clones an estimator with the fitted results.

    @param      est         estimator
    @return                 cloned object

    Estimators are cloned with *clone*. Every attribute of *est* missing
    in the clone is then copied recursively if its name ends with ``_`` or
    starts with ``_`` (double underscores excluded). Attributes present in
    both are adjusted in place, and sub-estimators are replaced by their
    own fitted clones. Lists, tuples and dictionaries are cloned element by
    element. Any other object is deep-copied.

    :raises:
        RuntimeError if an attribute cannot be migrated
    """
    def adjust(obj1, obj2):
        # Copies the fitted state of obj1 into obj2 in place.
        if ((isinstance(obj1, list) and isinstance(obj2, list)) or
                (isinstance(obj1, tuple) and isinstance(obj2, tuple))):
            for a, b in zip(obj1, obj2):
                adjust(a, b)
        elif isinstance(obj1, dict) and isinstance(obj2, dict):
            # Pairs values by key, both dictionaries may not share the
            # same iteration order.
            for key, value in obj1.items():
                if key in obj2:
                    adjust(value, obj2[key])
        elif isinstance(obj1, BaseEstimator) and isinstance(obj2, BaseEstimator):
            for k in obj1.__dict__:
                v1 = getattr(obj1, k)
                if hasattr(obj2, k):
                    if callable(v1):
                        raise RuntimeError(  # pragma: no cover
                            "Cannot migrate trained parameters for {}.".format(obj1))
                    if isinstance(v1, BaseEstimator):
                        setattr(obj2, k, clone_with_fitted_parameters(v1))
                    else:
                        adjust(v1, getattr(obj2, k))
                elif (k.endswith('_') and not k.endswith('__')) or \
                        (k.startswith('_') and not k.startswith('__')):
                    setattr(obj2, k, clone_with_fitted_parameters(v1))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Cloned object is missing '{0}' in {1}.".format(k, obj2))

    if isinstance(est, BaseEstimator):
        cloned = clone(est)
        adjust(est, cloned)
        res = cloned
    elif isinstance(est, list):
        res = [clone_with_fitted_parameters(o) for o in est]
    elif isinstance(est, tuple):
        res = tuple(clone_with_fitted_parameters(o) for o in est)
    elif isinstance(est, dict):
        res = {k: clone_with_fitted_parameters(v) for k, v in est.items()}
    else:
        res = copy.deepcopy(est)
    return res