Coverage for mlinsights/sklapi/sklearn_base_transform_learner.py: 98%
66 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Implements a *transform* which converts a *learner* into
5a *transform*.
6"""
7import textwrap
8import numpy
9from .sklearn_base_transform import SkBaseTransform
class SkBaseTransformLearner(SkBaseTransform):
    """
    A *transform* which hides a *learner*, it converts
    method *predict* into *transform*. This way,
    two learners can be inserted into the same pipeline.
    There is another and shorter implementation
    with class @see cl TransferTransformer.

    .. exref::
        :title: Use two learners into a same pipeline
        :tag: sklearn
        :lid: ex-pipe2learner

        It is impossible to use two *learners* into a pipeline
        unless we use a class such as @see cl SkBaseTransformLearner
        which disguise a *learner* into a *transform*.

        .. runpython::
            :showcode:
            :warningout: FutureWarning

            from sklearn.model_selection import train_test_split
            from sklearn.datasets import load_iris
            from sklearn.linear_model import LogisticRegression
            from sklearn.tree import DecisionTreeClassifier
            from sklearn.metrics import accuracy_score
            from sklearn.pipeline import make_pipeline
            from mlinsights.sklapi import SkBaseTransformLearner

            data = load_iris()
            X, y = data.data, data.target
            X_train, X_test, y_train, y_test = train_test_split(X, y)

            try:
                pipe = make_pipeline(LogisticRegression(),
                                     DecisionTreeClassifier())
            except Exception as e:
                print("ERROR:")
                print(e)
                print('.')

            pipe = make_pipeline(SkBaseTransformLearner(LogisticRegression()),
                                 DecisionTreeClassifier())
            pipe.fit(X_train, y_train)
            pred = pipe.predict(X_test)
            score = accuracy_score(y_test, pred)
            print("pipeline with two learners:", score)
    """

    # Method names accepted as a string by parameter *method*.
    _KNOWN_METHODS = ('predict', 'predict_proba',
                      'decision_function', 'transform')

    def __init__(self, model=None, method=None, **kwargs):
        """
        @param      model       learner instance
        @param      method      method to call to transform the feature (see below)
        @param      kwargs      parameters, forwarded to the parent constructor

        Options for parameter *method*:

        * ``'predict'``
        * ``'predict_proba'``
        * ``'decision_function'``
        * ``'transform'``
        * a function

        If *method is None*, the constructor looks for
        ``predict_proba``, ``predict`` and ``transform`` on the class
        of *model* and keeps the **last** one which exists
        (``transform`` wins over ``predict`` which wins over
        ``predict_proba``). A ``ValueError`` is raised if *model* is
        None or if none of these methods is found.
        """
        super().__init__(**kwargs)
        self.model = model
        if model is None:
            raise ValueError("value cannot be None")  # pragma: no cover
        if method is None:
            # There is deliberately no ``break``: every candidate is
            # tested and the last available name is the one kept.
            # NOTE(review): an earlier version of this docstring said
            # ``predict_proba`` was tried first; the behaviour below is
            # unchanged, confirm which order is the intended one.
            for name in ['predict_proba', 'predict', 'transform']:
                if hasattr(model.__class__, name):
                    method = name
            if method is None:
                raise ValueError(  # pragma: no cover
                    f"Unable to guess a default method for '{repr(model)}'")
        self.method = method
        self._set_method(method)

    def _set_method(self, method):
        """
        Defines the method to use to convert the features
        into predictions. The result is stored in ``self.method_``.

        @param      method      one of the names ``'predict'``,
                                ``'predict_proba'``, ``'decision_function'``,
                                ``'transform'`` (looked up on ``self.model``)
                                or a callable used as is

        A ``ValueError`` is raised for an unknown name,
        a ``TypeError`` if *method* is neither a string nor a callable.
        """
        if isinstance(method, str):
            if method not in self._KNOWN_METHODS:
                raise ValueError(  # pragma: no cover
                    f"Unexpected method '{method}'")
            # Bound method of the current model: it must be bound again
            # every time ``self.model`` is replaced.
            self.method_ = getattr(self.model, method)
        elif callable(method):
            self.method_ = method
        else:
            raise TypeError(  # pragma: no cover
                f"Unable to find the transform method, method={method}")

    def fit(self, X, y=None, **kwargs):
        """
        Trains a model.

        @param      X           features
        @param      y           targets
        @param      kwargs      additional parameters, forwarded to
                                the method *fit* of the embedded model
        @return                 self
        """
        self.model.fit(X, y=y, **kwargs)
        return self

    def transform(self, X):
        """
        Predictions, output of the embedded learner.

        @param      X   features
        @return         predictions, a one dimension output
                        is converted into a single column
        """
        res = self.method_(X)
        if len(res.shape) == 1:
            # A pipeline step expects a matrix: (n,) becomes (n, 1).
            res = res[:, numpy.newaxis]
        return res

    ##############
    # cloning API
    ##############

    def get_params(self, deep=True):
        """
        Returns the parameters mandatory to clone the class.

        @param      deep        if True, the parameters of the embedded
                                model are added with prefix ``model__``
        @return                 dict
        """
        res = self.P.to_dict()
        res['model'] = self.model
        res['method'] = self.method
        if deep:
            par = self.model.get_params(deep)
            for k, v in par.items():
                res["model__" + k] = v
        return res

    def set_params(self, **values):
        """
        Sets parameters.

        @param      values      parameters, accepted keys are ``model``,
                                ``method`` and any parameter of the
                                embedded model prefixed by ``model__``
        @return                 self

        A ``KeyError`` is raised if no model is available,
        a ``ValueError`` if a key is not one of the accepted ones.
        """
        has_model = 'model' in values
        has_method = 'method' in values
        if not has_model and (
                not hasattr(self, 'model') or self.model is None):
            raise KeyError(  # pragma: no cover
                f"Missing key 'model' in [{', '.join(sorted(values))}]")

        if has_model:
            self.model = values.pop('model')
        new_method = values.pop('method') if has_method else None

        # Everything left must be a parameter of the embedded model.
        for k in values:
            if not k.startswith('model__'):
                raise ValueError(  # pragma: no cover
                    f"Parameter '{k}' must start with 'model__'.")
        d = len('model__')
        pars = {k[d:]: v for k, v in values.items()}
        self.model.set_params(**pars)

        if has_method:
            # Bind first so that an invalid value raises before
            # ``self.method`` is modified, then record the new value so
            # that get_params and __repr__ report it.
            self._set_method(new_method)
            self.method = new_method
        elif has_model and getattr(self, 'method', None) is not None:
            # ``self.method_`` is still bound to the previous model,
            # it is bound again to the new one.
            self._set_method(self.method)
        return self

    #################
    # common methods
    #################

    def __repr__(self):
        """
        Returns a string showing the model, the method and the other
        parameters, wrapped on several lines by ``textwrap.wrap``.
        """
        rp = repr(self.model)
        rps = repr(self.P)
        res = f"{self.__class__.__name__}(model={rp}, method={self.method}, {rps})"
        # NOTE(review): whitespace seems collapsed in the listing this
        # block was rebuilt from, confirm the value of subsequent_indent.
        return "\n".join(textwrap.wrap(res, subsequent_indent=" "))