Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a transformer which wraps a predictor
4to do transfer learning.
5"""
6import inspect
7from sklearn.base import BaseEstimator, TransformerMixin
8from .sklearn_testing import clone_with_fitted_parameters
class TransferTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps a predictor or a transformer in a transformer.
    By default the wrapped model is frozen: :meth:`fit` does not
    train it and :meth:`transform` only computes its predictions
    (or its transformation). Set *trainable* to True to have
    :meth:`fit` train it.

    .. index:: transfer learning, frozen model
    """

    def __init__(self, estimator, method=None, copy_estimator=True,
                 trainable=False):
        """
        :param estimator: estimator to wrap in a transformer; when
            *copy_estimator* is True, :meth:`fit` copies it together
            with its fitted parameters
        :param method: name of the method :meth:`transform` calls on
            the wrapped estimator; if None, the first attribute found
            among *transform*, *predict_proba*, *decision_function*,
            *predict* is used
        :param copy_estimator: copy the model in :meth:`fit` instead
            of keeping a reference to it
        :param trainable: if True, the transferred model is trained
            in :meth:`fit`
        :raises AttributeError: if *method* is None and none of the
            candidate methods exists, or if *method* is not an
            attribute of *estimator*
        """
        TransformerMixin.__init__(self)
        BaseEstimator.__init__(self)
        self.estimator = estimator
        self.copy_estimator = copy_estimator
        self.trainable = trainable
        # NOTE(review): the method is resolved here, so get_params()
        # returns the resolved name rather than None. scikit-learn
        # recommends __init__ only store parameters; kept for
        # backward compatibility.
        if method is None:
            # Probe order matters: a transformer is used as such, then
            # probabilities are preferred over scores over raw labels.
            if hasattr(estimator, "transform"):
                method = "transform"
            elif hasattr(estimator, "predict_proba"):
                method = "predict_proba"
            elif hasattr(estimator, "decision_function"):
                method = "decision_function"
            elif hasattr(estimator, "predict"):
                method = "predict"
            else:
                raise AttributeError(  # pragma: no cover
                    "Cannot find a method transform, predict_proba, decision_function, "
                    "predict in object {}".format(type(estimator)))
        if not hasattr(estimator, method):
            raise AttributeError(  # pragma: no cover
                "Cannot find method '{}' in object {}".format(
                    method, type(estimator)))
        self.method = method

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Stores the wrapped estimator (or a copy of it) and, if
        *trainable* is True, trains it.

        :param X: training data, only used if *trainable* is True
        :param y: targets, only used if *trainable* is True and the
            wrapped estimator's *fit* has a parameter named *y*
        :param sample_weight: sample weights, only used if *trainable*
            is True and the wrapped estimator's *fit* has a parameter
            named *sample_weight*; otherwise they are silently ignored
        :return: self

        Fitted attributes:

        * `estimator_`: the wrapped estimator, a copy of *estimator*
          if *copy_estimator* is True, *estimator* itself otherwise
        """
        if self.copy_estimator:
            self.estimator_ = clone_with_fitted_parameters(self.estimator)
            # Sanity check on the copy; presumably raises if it differs
            # from the original (helper defined in .sklearn_testing).
            from .sklearn_testing import assert_estimator_equal  # pylint: disable=C0415
            assert_estimator_equal(self.estimator_, self.estimator)
        else:
            self.estimator_ = self.estimator
        if self.trainable:
            pars = inspect.signature(self.estimator_.fit).parameters
            kwargs = {}
            if 'sample_weight' in pars:
                # Passed by keyword: sample_weight is not guaranteed to be
                # the third positional parameter of the wrapped fit.
                kwargs['sample_weight'] = sample_weight
            if 'y' in pars:
                self.estimator_.fit(X, y, **kwargs)
            else:
                self.estimator_.fit(X, **kwargs)
        return self

    def transform(self, X):
        """
        Calls the method selected in the constructor (attribute
        *method*) of the fitted estimator on *X*.

        :param X: data to transform, numpy array or sparse matrix
            of shape [n_samples, n_features]
        :return: the output of that method on *X*
        """
        meth = getattr(self.estimator_, self.method)
        return meth(X)