Coverage for mlinsights/mlmodel/sklearn_transform_inv_fct.py: 96%
105 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1"""
2@file
3@brief Implements a transform which modifies the target
4and applies the reverse transformation on the target.
5"""
6import numpy
7from sklearn.exceptions import NotFittedError
8from sklearn.neighbors import NearestNeighbors
9from .sklearn_transform_inv import BaseReciprocalTransformer
class FunctionReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to apply a function on the target,
    predict, then transform the target back before scoring.
    The transform implements a series of predefined functions:

    .. runpython::
        :showcode:

        import pprint
        from mlinsights.mlmodel.sklearn_transform_inv_fct import FunctionReciprocalTransformer
        pprint.pprint(FunctionReciprocalTransformer.available_fcts())
    """

    @staticmethod
    def available_fcts():
        """
        Returns the predefined functions.

        @return     dictionary ``{name: (function, name of its reciprocal)}``,
                    every reciprocal name is also a key of the dictionary
                    and the reciprocal of the reciprocal is the function
                    itself
        """
        return {
            'log': (numpy.log, 'exp'),
            'exp': (numpy.exp, 'log'),
            'log(1+x)': (lambda x: numpy.log(x + 1), 'exp(x)-1'),
            'log1p': (numpy.log1p, 'expm1'),
            # the reciprocal of exp(x)-1 is log(1+x), not log
            'exp(x)-1': (lambda x: numpy.exp(x) - 1, 'log(1+x)'),
            'expm1': (numpy.expm1, 'log1p'),
        }

    def __init__(self, fct, fct_inv=None):
        """
        @param      fct         function name of numerical function
                                (a key of @see me available_fcts)
                                or a callable
        @param      fct_inv     must be None if *fct* is a function name,
                                reciprocal function otherwise

        @raises     ValueError  if *fct* is a name and *fct_inv* is given,
                                if *fct* is an unknown name, or if *fct*
                                is not a name and *fct_inv* is missing
        """
        BaseReciprocalTransformer.__init__(self)
        if isinstance(fct, str):
            if fct_inv is not None:
                raise ValueError(  # pragma: no cover
                    "If fct is a function name, fct_inv must not be specified.")
            opts = self.__class__.available_fcts()
            if fct not in opts:
                raise ValueError(  # pragma: no cover
                    f"Unknown fct '{fct}', it should be in {sorted(opts)}.")
        else:
            if fct_inv is None:
                raise ValueError(
                    "If fct is callable, fct_inv must be specified.")
        self.fct = fct
        self.fct_inv = fct_inv

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Just defines *fct* and *fct_inv*.

        @param      X               ignored
        @param      y               ignored
        @param      sample_weight   ignored
        @return                     self

        After fitting, ``fct_`` is a callable and ``fct_inv_`` is either
        a callable or the name of a predefined function.
        """
        if callable(self.fct):
            self.fct_ = self.fct
            self.fct_inv_ = self.fct_inv
        else:
            opts = self.__class__.available_fcts()
            self.fct_, self.fct_inv_ = opts[self.fct]
        return self

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor. The transform must be fitted first.
        """
        if isinstance(self.fct_inv_, str):
            # a predefined name carries its own reciprocal
            res = FunctionReciprocalTransformer(self.fct_inv_)
        else:
            # callables: swap the function and its reciprocal
            res = FunctionReciprocalTransformer(self.fct_inv_, self.fct_)
        return res.fit()

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        If *y* is None, the returned value for *y*
        is None as well. *X* is returned unchanged.
        """
        if y is None:
            return X, None
        return X, self.fct_(y)
class PermutationReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to permute targets,
    predict, then permute the target back before scoring.
    nan values remain nan values. Once fitted, the transform
    has attribute ``permutation_`` which keeps
    track of the permutation to apply.
    """

    def __init__(self, random_state=None, closest=False):
        """
        @param      random_state    random state
        @param      closest         if True, finds the closest permuted element
                                    when a target is not a known key
        """
        BaseReciprocalTransformer.__init__(self)
        self.random_state = random_state
        self.closest = closest

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Defines a random permutation over the targets.

        @param      X               ignored
        @param      y               targets (numpy array), nan values are ignored
        @param      sample_weight   ignored
        @return                     self
        """
        if y is None:
            raise RuntimeError(  # pragma: no cover
                "targets cannot be empty.")
        num = numpy.issubdtype(y.dtype, numpy.floating)
        # distinct values in order of first appearance -> position
        perm = {}
        for u in y.ravel():
            if num and numpy.isnan(u):
                continue
            if u not in perm:
                perm[u] = len(perm)

        lin = numpy.arange(len(perm))
        if self.random_state is None:
            lin = numpy.random.permutation(lin)
        else:
            rs = numpy.random.RandomState(  # pylint: disable=E1101
                self.random_state)  # pylint: disable=E1101
            lin = rs.permutation(lin)

        self.permutation_ = {u: lin[pos] for u, pos in perm.items()}
        # The nearest-neighbour index cached by _find_closest refers to the
        # keys of a previous fit, it must be rebuilt for the new ones.
        for name in ('knn_', 'knn_perm_'):
            if hasattr(self, name):
                delattr(self, name)
        return self

    def _check_is_fitted(self):
        if not hasattr(self, 'permutation_'):
            raise NotFittedError(  # pragma: no cover
                f"This instance {type(self)} is not fitted yet. Call 'fit' with "
                f"appropriate arguments before using this method.")

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor.
        """
        self._check_is_fitted()
        res = PermutationReciprocalTransformer(
            self.random_state, closest=self.closest)
        res.permutation_ = {v: k for k, v in self.permutation_.items()}
        return res

    def _find_closest(self, cl):
        """
        Returns the key of ``permutation_`` closest to *cl*. The
        nearest-neighbour index is built on first use and cached in
        attributes ``knn_`` and ``knn_perm_``.
        """
        if not hasattr(self, 'knn_'):
            self.knn_ = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
            self.knn_perm_ = numpy.array(list(self.permutation_))
            self.knn_perm_ = self.knn_perm_.reshape((len(self.knn_perm_), 1))
            self.knn_.fit(self.knn_perm_)
        ind = self.knn_.kneighbors([[cl]], return_distance=False)
        # ind has shape (1, 1); extract a scalar because converting a
        # non 0-d array with float()/int() is deprecated in numpy>=1.25
        res = self.knn_perm_[ind[0, 0], 0]
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.floating):
            return float(res)
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.integer):
            return int(res)
        raise NotImplementedError(  # pragma: no cover
            f"The function does not work for type {self.knn_perm_.dtype}.")

    def _permute_labels(self, y):
        """
        Applies ``permutation_`` element-wise on labels *y*,
        nan values are kept, the output has the shape of *y*.
        """
        yp = y.copy().ravel()
        num = numpy.issubdtype(y.dtype, numpy.floating)
        # only the current position is written, so iterating while
        # assigning is safe
        for i, value in enumerate(yp):
            if num and numpy.isnan(value):
                continue
            if value in self.permutation_:
                cl = value
            elif self.closest:
                cl = self._find_closest(value)
            else:
                raise RuntimeError(
                    f"Unable to find key {value!r} in "
                    f"{list(sorted(self.permutation_))!r}.")
            yp[i] = self.permutation_[cl]
        return yp.reshape(y.shape)

    def _permute_columns(self, y):
        """
        Moves the columns of a matrix of probabilities or raw scores,
        column *i* being associated to the key *i* of ``permutation_``.
        """
        if len(y.shape) != 2:
            raise RuntimeError(
                f"yp should be a matrix but has shape {y.shape}.")
        # new position of every key: its rank once keys are sorted
        # by the value they are mapped to
        ordered = sorted((v, k) for k, v in self.permutation_.items())
        new_perm = {k: rank for rank, (_, k) in enumerate(ordered)}
        yp = y.copy()
        for i in range(y.shape[1]):
            yp[:, new_perm[i]] = y[:, i]
        return yp

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        If *y* is None, the returned value for *y*
        is None as well.

        Vectors, string arrays and int32/int64 arrays are considered
        as labels and permuted element-wise, any other matrix is
        considered as probabilities or raw scores and its columns
        are permuted.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        # numpy.str was removed in numpy 1.24, check the dtype kind instead
        if (len(y.shape) == 1 or y.dtype.kind in ('U', 'S') or
                y.dtype in (numpy.int32, numpy.int64)):
            return X, self._permute_labels(y)
        return X, self._permute_columns(y)