Coverage for mlinsights/mlmodel/transfer_transformer.py: 88%
41 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1"""
2@file
3@brief Implements a transformer which wraps a predictor
4to do transfer learning.
5"""
6import inspect
7from sklearn.base import BaseEstimator, TransformerMixin
8from .sklearn_testing import clone_with_fitted_parameters
class TransferTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps a predictor or a transformer in a transformer.
    By default the wrapped model is frozen: it is not trained again
    and only computes the predictions. With *trainable=True*, the
    wrapped model is trained when this transformer is fitted.

    .. index:: transfer learning, frozen model
    """

    def __init__(self, estimator, method=None, copy_estimator=True,
                 trainable=False):
        """
        :param estimator: estimator to wrap in a transformer; when
            *copy_estimator* is True, it is copied together with its
            fitted parameters when :meth:`fit` is called
        :param method: if None, guess what method should be called,
            *transform* for a transformer,
            *predict_proba* for a classifier,
            *decision_function* if found,
            *predict* otherwise
        :param copy_estimator: copy the model instead of taking a reference
        :param trainable: the transferred model must be trained when
            this transformer is fitted
        :raises AttributeError: if *method* is None and none of the
            candidate methods exists, or if *estimator* has no
            attribute named *method*
        """
        TransformerMixin.__init__(self)
        BaseEstimator.__init__(self)
        self.estimator = estimator
        self.copy_estimator = copy_estimator
        self.trainable = trainable
        if method is None:
            # The first candidate found wins, so a transformer keeps its
            # own *transform* even if it also exposes *predict*.
            for candidate in ("transform", "predict_proba",
                              "decision_function", "predict"):
                if hasattr(estimator, candidate):
                    method = candidate
                    break
            else:
                raise AttributeError(  # pragma: no cover
                    "Cannot find a method transform, predict_proba, decision_function, "
                    f"predict in object {type(estimator)}.")
        if not hasattr(estimator, method):
            raise AttributeError(  # pragma: no cover
                f"Cannot find method '{method}' in object {type(estimator)}")
        # NOTE(review): the resolved method name is stored in place of None,
        # so a later set_params(estimator=...) does not re-guess it.
        self.method = method

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Sets up the wrapped estimator and trains it only if
        *trainable* is True.

        :param X: training data, only used if *trainable* is True
        :param y: targets, only used if *trainable* is True and the
            wrapped estimator's *fit* has a parameter named *y*
        :param sample_weight: sample weights, only used if *trainable*
            is True and the wrapped estimator's *fit* has a parameter
            named *sample_weight*
        :return: self: returns an instance of self.

        Fitted attributes:

        * `estimator_`: the wrapped estimator, a copy of *estimator*
          if *copy_estimator* is True, *estimator* itself otherwise
        """
        if self.copy_estimator:
            self.estimator_ = clone_with_fitted_parameters(self.estimator)
            # Sanity check on the copy; presumably raises when the clone
            # differs from the original estimator (helper defined elsewhere).
            from .sklearn_testing import assert_estimator_equal  # pylint: disable=C0415
            assert_estimator_equal(self.estimator_, self.estimator)
        else:
            self.estimator_ = self.estimator
        if self.trainable:
            params = inspect.signature(self.estimator_.fit).parameters
            args = (X, y) if 'y' in params else (X,)
            # sample_weight is passed by keyword: the third positional
            # parameter of *fit* is not always sample_weight (for example
            # SGDClassifier.fit(X, y, coef_init, intercept_init, sample_weight)).
            kwargs = ({'sample_weight': sample_weight}
                      if 'sample_weight' in params else {})
            self.estimator_.fit(*args, **kwargs)
        return self

    def transform(self, X):
        """
        Runs the predictions of the wrapped estimator.

        :param X: numpy array or sparse matrix of shape [n_samples,n_features],
            data to transform
        :return: transformed *X*, i.e. the output of the method named by
            *method* called on `estimator_`
        """
        meth = getattr(self.estimator_, self.method)
        return meth(X)