Coverage for mlinsights/sklapi/sklearn_base_transform_stacking.py: 96%
73 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Implémente un *transform* qui suit la même API que tout :epkg:`scikit-learn` transform.
5"""
6import textwrap
7import numpy
8from .sklearn_base_transform import SkBaseTransform
9from .sklearn_base_transform_learner import SkBaseTransformLearner
class SkBaseTransformStacking(SkBaseTransform):
    """
    A *transform* which hides several *learners*, arranged
    following the `stacking <http://blog.kaggle.com/2016/12/27/a-kagglers-guide-to-model-stacking-in-practice/>`_
    method.

    .. exref::
        :title: Stacking several learners in a scikit-learn pipeline.
        :tag: sklearn
        :lid: ex-pipe2learner2

        This *transform* assembles the outputs of several learners.
        These features are then used as the input of a stacking model.

        .. runpython::
            :showcode:
            :warningout: FutureWarning

            from sklearn.model_selection import train_test_split
            from sklearn.datasets import load_iris
            from sklearn.linear_model import LogisticRegression
            from sklearn.tree import DecisionTreeClassifier
            from sklearn.metrics import accuracy_score
            from sklearn.pipeline import make_pipeline
            from mlinsights.sklapi import SkBaseTransformStacking

            data = load_iris()
            X, y = data.data, data.target
            X_train, X_test, y_train, y_test = train_test_split(X, y)

            trans = SkBaseTransformStacking([LogisticRegression(),
                                             DecisionTreeClassifier()])
            trans.fit(X_train, y_train)
            pred = trans.transform(X_test)
            print(pred[3:])
    """

    def __init__(self, models=None, method=None, **kwargs):
        """
        @param      models      list of learners
        @param      method      name of the method to call to convert
                                features into predictions, or a list of
                                such names, one per model (see below)
        @param      kwargs      parameters forwarded to the base class

        Available options for parameter *method*:

        * ``'predict'``
        * ``'predict_proba'``
        * ``'decision_function'``

        If *method is None*, ``'predict'`` is used for every model.

        Raises *ValueError* if *models* is None or if *method* is a list
        whose length differs from *models*, *TypeError* if *models* is not
        a list or if *method* (or one of its items) is not a string.
        """
        super().__init__(**kwargs)
        if models is None:
            raise ValueError("models cannot be None")  # pragma: no cover
        if not isinstance(models, list):
            raise TypeError(  # pragma: no cover
                f"models must be a list not {type(models)}")
        if method is None:
            method = 'predict'

        # One method name per model: a single string is applied to all
        # models, a list must line up with *models*.
        if isinstance(method, list):
            if len(method) != len(models):  # pragma: no cover
                raise ValueError(
                    f"models and methods must have the same "
                    f"length: {len(models)} != {len(method)}.")
            for name in method:
                if not isinstance(name, str):
                    raise TypeError(  # pragma: no cover
                        f"Method must be a string not {type(name)}")
            methods = method
        elif isinstance(method, str):
            methods = [method] * len(models)
        else:
            raise TypeError(  # pragma: no cover
                f"Method must be a string not {type(method)}")
        # The value is stored exactly as received so that get_params
        # returns the constructor argument.
        self.method = method

        converted = []
        created = False
        for model, name in zip(models, methods):
            if isinstance(model, SkBaseTransformLearner):
                # Already wrapped: only rewrap when another method is asked.
                if name != model.method:
                    model = SkBaseTransformLearner(model.model, name)
                    created = True
            elif not hasattr(model, 'transform'):
                # A learner without *transform* is wrapped into a transform.
                model = SkBaseTransformLearner(model, name)
                created = True
            converted.append(model)

        # We keep the original list when nothing was wrapped to avoid
        # creating new objects when it is not necessary. This behavior
        # is not supported anymore by scikit-learn.
        # See sklearn.base.py
        self.models = converted if created else models

    def fit(self, X, y=None, **kwargs):
        """
        Trains every model on the same data.

        @param      X           features
        @param      y           targets
        @param      kwargs      additional parameters, forwarded to
                                the method *fit* of every model
        @return                 self
        """
        for model in self.models:
            model.fit(X, y=y, **kwargs)
        return self

    def transform(self, X):
        """
        Calls the method *transform* of every model and stacks
        the outputs horizontally (:epkg:`numpy` *hstack*).

        @param      X       features
        @return             predictions
        """
        outputs = [model.transform(X) for model in self.models]
        return numpy.hstack(outputs)

    ##############
    # cloning API
    ##############

    def get_params(self, deep=True):
        """
        Returns the parameters which define the object.
        It follows :epkg:`scikit-learn` API.

        @param      deep        if True, the parameters of every model
                                are added under the key
                                ``models_<index>__<name>``
        @return                 dict
        """
        # self.P is not defined in this class, it is expected to come
        # from the base class.
        res = self.P.to_dict()
        res['models'] = self.models
        res['method'] = self.method
        if deep:
            for i, model in enumerate(self.models):
                for name, value in model.get_params(deep).items():
                    res[f"models_{i}__" + name] = value
        return res

    def set_params(self, **values):
        """
        Sets the parameters.

        @param      values      parameters, ``models``, ``method`` or
                                ``models_<index>__<name>`` to change
                                parameter *name* of model *index*
        @return                 self

        Raises *ValueError* if a key is not ``models``, ``method`` and
        does not follow the pattern ``models_<index>__<name>``.
        """
        if 'models' in values:
            self.models = values.pop('models')
        if 'method' in values:
            self.method = values.pop('method')

        prefix = 'models_'
        pars = [{} for _ in self.models]
        for key, value in values.items():
            if not key.startswith(prefix):
                raise ValueError(  # pragma: no cover
                    f"Parameter '{key}' must start with 'models_'.")
            # The name is everything after the first '__', whatever the
            # number of digits of the index.
            index, sep, name = key[len(prefix):].partition('__')
            if not sep or not name:
                raise ValueError(
                    f"Parameter '{key}' must follow the pattern "
                    f"'models_<index>__<name>'.")
            pars[int(index)][name] = value

        # Models are only modified once every key was successfully parsed.
        for par, model in zip(pars, self.models):
            if par:
                model.set_params(**par)
        return self

    #################
    # common methods
    #################

    def __repr__(self):
        """
        Returns a wrapped string showing the class name, the models
        (the wrapped estimator when a model has an attribute *model*),
        their methods (None when a model has no attribute *method*)
        and the representation of *self.P*.
        """
        rps = repr(self.P)
        models = ", ".join(
            repr(m.model if hasattr(m, 'model') else m)
            for m in self.models)
        methods = ", ".join(
            repr(m.method if hasattr(m, 'method') else None)
            for m in self.models)
        res = "{0}([{1}], [{2}], {3})".format(
            self.__class__.__name__, models, methods, rps)
        return "\n".join(textwrap.wrap(res, subsequent_indent="    "))