Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
@file
@brief Combines a *k-means* followed by a predictor.
"""
5import textwrap
6import inspect
7import numpy
8from sklearn.linear_model import LogisticRegression
9from sklearn.cluster import KMeans
10from sklearn.base import BaseEstimator, ClassifierMixin, clone
class ClassifierAfterKMeans(BaseEstimator, ClassifierMixin):
    """
    Applies a *k-means* (see :epkg:`sklearn:cluster:KMeans`)
    for each class, then adds the distance to each cluster
    as a feature for a classifier.
    See notebook :ref:`logisticregressionclusteringrst`.
    """

    def __init__(self, estimator=None, clus=None, **kwargs):
        """
        @param      estimator   :epkg:`sklearn:linear_model:LogisticRegression`
                                by default
        @param      clus        clustering applied on each class,
                                by default k-means with two clusters,
                                it must expose a method *transform*
        @param      kwargs      sent to :meth:`set_params
                                <mlinsights.mlmodel.classification_kmeans.
                                ClassifierAfterKMeans.set_params>`,
                                see its documentation to understand how to
                                specify parameters
        """
        ClassifierMixin.__init__(self)
        BaseEstimator.__init__(self)
        if estimator is None:
            estimator = LogisticRegression()
        if clus is None:
            clus = KMeans(n_clusters=2)
        self.estimator = estimator
        self.clus = clus
        # The clustering output is used as features, which requires
        # a *transform* method on the clustering model.
        if not hasattr(clus, "transform"):
            raise AttributeError(  # pragma: no cover
                "clus does not have a transform method.")
        if kwargs:
            self.set_params(**kwargs)

    def fit(self, X, y, sample_weight=None):
        """
        Runs a *k-means* on each class
        then trains a classifier on the
        extended set of features.

        :param X: numpy array or sparse matrix of shape [n_samples,n_features]
            Training data
        :param y: numpy array of shape [n_samples, n_targets]
            Target values. Will be cast to X's dtype if necessary
        :param sample_weight: numpy array of shape [n_samples]
            Individual weights for each sample
        :return: self : returns an instance of self.

        Fitting attributes:

        * `labels_`: sorted list of the classes found in *y*
        * `clus_`: dictionary ``{class: fitted clustering model}``
        * `estimator_`: trained classifier
        """
        self.labels_ = sorted(set(y))
        self.clus_ = {}
        # Weights are forwarded to the clustering only if its *fit*
        # method declares a parameter *sample_weight*.
        sig = inspect.signature(self.clus.fit)
        weighted_clus = (sample_weight is not None and
                         'sample_weight' in sig.parameters)
        # Classes are processed in sorted order so that the sequence
        # of trainings does not depend on the iteration order of a set.
        for cl in self.labels_:
            model = clone(self.clus)
            mask = y == cl
            if weighted_clus:
                model.fit(X[mask], sample_weight=sample_weight[mask])
            else:
                model.fit(X[mask])
            self.clus_[cl] = model

        extX = self.transform_features(X)
        # *sample_weight* is only given when specified, the final
        # estimator may not support this parameter.
        if sample_weight is None:
            self.estimator_ = self.estimator.fit(extX, y)
        else:
            self.estimator_ = self.estimator.fit(
                extX, y, sample_weight=sample_weight)
        return self

    def transform_features(self, X):
        """
        Applies all the clustering objects
        on every observations and extends the list of
        features.

        @param      X       features
        @return             extended features, the outputs of every
                            clustering model stacked horizontally,
                            models are ordered by class
        """
        preds = [model.transform(X)
                 for _, model in sorted(self.clus_.items())]
        return numpy.hstack(preds)

    def predict(self, X):
        """
        Runs the predictions.

        @param      X       features
        @return             predictions of the trained classifier
                            on the extended features
        """
        extX = self.transform_features(X)
        return self.estimator_.predict(extX)

    def predict_proba(self, X):
        """
        Converts predictions into probabilities.

        @param      X       features
        @return             output of *predict_proba* of the trained
                            classifier on the extended features
        """
        extX = self.transform_features(X)
        return self.estimator_.predict_proba(extX)

    def decision_function(self, X):
        """
        Calls *decision_function*.

        @param      X       features
        @return             output of *decision_function* of the trained
                            classifier on the extended features
        """
        extX = self.transform_features(X)
        return self.estimator_.decision_function(extX)

    def get_params(self, deep=True):
        """
        Returns the parameters for both
        the clustering and the classifier.

        @param      deep        unused here
        @return                 dict

        :meth:`set_params <mlinsights.mlmodel.classification_kmeans.
        ClassifierAfterKMeans.set_params>`
        describes the pattern parameters names follow.
        """
        res = {}
        for k, v in self.clus.get_params().items():
            res["c_" + k] = v
        for k, v in self.estimator.get_params().items():
            res["e_" + k] = v
        return res

    def set_params(self, **values):
        """
        Sets the parameters before training.
        Every parameter prefixed by ``'e_'`` is an estimator
        parameter, every parameter prefixed by ``'c_'`` is for
        the :epkg:`sklearn:cluster:KMeans`.

        @param      values      parameters to set, names are prefixed
                                by ``'e_'`` or ``'c_'``
        @return                 self

        A *ValueError* is raised if a name has neither prefix.
        """
        pc, pe = {}, {}
        for k, v in values.items():
            if k.startswith('e_'):
                pe[k[2:]] = v
            elif k.startswith('c_'):
                pc[k[2:]] = v
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected parameter name '{0}'".format(k))
        self.clus.set_params(**pc)
        self.estimator.set_params(**pe)
        # Returning self allows calls to be chained as
        # *scikit-learn* does with its own estimators.
        return self

    def __repr__(self):  # pylint: disable=W0222
        """
        Overloads `repr` as *scikit-learn* now relies
        on the constructor signature.
        """
        el = ', '.join(['%s=%r' % (k, v)
                        for k, v in self.get_params().items()])
        text = "%s(%s)" % (self.__class__.__name__, el)
        lines = textwrap.wrap(text, subsequent_indent=' ')
        return "\n".join(lines)