Coverage for mlinsights/mlmodel/sklearn_transform_inv_fct.py: 96%

105 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-28 08:46 +0100

1""" 

2@file 

3@brief Implements a transform which modifies the target 

4and applies the reverse transformation on the target. 

5""" 

6import numpy 

7from sklearn.exceptions import NotFittedError 

8from sklearn.neighbors import NearestNeighbors 

9from .sklearn_transform_inv import BaseReciprocalTransformer 

10 

11 

class FunctionReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to apply a function on the target,
    predict, then transform the target back before scoring.
    The transform implements a series of predefined functions:

    .. runpython::
        :showcode:

        import pprint
        from mlinsights.mlmodel.sklearn_transform_inv_fct import FunctionReciprocalTransformer
        pprint.pprint(FunctionReciprocalTransformer.available_fcts())
    """

    @staticmethod
    def available_fcts():
        """
        Returns the predefined functions as a dictionary
        ``{name: (function, name of the reciprocal function)}``.
        Every reciprocal name is itself a key of the dictionary,
        which is what @see me get_fct_inv relies on.
        """
        return {
            'log': (numpy.log, 'exp'),
            'exp': (numpy.exp, 'log'),
            'log(1+x)': (lambda x: numpy.log(x + 1), 'exp(x)-1'),
            'log1p': (numpy.log1p, 'expm1'),
            # the reciprocal of exp(x) - 1 is log(1 + x), not log(x)
            'exp(x)-1': (lambda x: numpy.exp(x) - 1, 'log(1+x)'),
            'expm1': (numpy.expm1, 'log1p'),
        }

    def __init__(self, fct, fct_inv=None):
        """
        @param      fct     function name or numerical function
        @param      fct_inv must be left to None if *fct* is a function
                            name, reciprocal function otherwise

        Raises *ValueError* if *fct* is an unknown function name,
        if *fct_inv* is given together with a function name
        or if *fct_inv* is missing when *fct* is not a name.
        """
        BaseReciprocalTransformer.__init__(self)
        if isinstance(fct, str):
            if fct_inv is not None:
                raise ValueError(  # pragma: no cover
                    "If fct is a function name, fct_inv must not be specified.")
            opts = self.__class__.available_fcts()
            if fct not in opts:
                raise ValueError(  # pragma: no cover
                    f"Unknown fct '{fct}', it should be in {list(sorted(opts))}.")
        else:
            if fct_inv is None:
                raise ValueError(
                    "If fct is callable, fct_inv must be specified.")
        # parameters are stored unchanged, they are resolved in fit
        self.fct = fct
        self.fct_inv = fct_inv

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Just defines *fct_* and *fct_inv_*. *X*, *y*, *sample_weight*
        are not used. Returns *self*.
        """
        if callable(self.fct):
            self.fct_ = self.fct
            self.fct_inv_ = self.fct_inv
        else:
            # predefined function: *fct_inv_* is the *name*
            # of the reciprocal function, not a callable
            opts = self.__class__.available_fcts()
            self.fct_, self.fct_inv_ = opts[self.fct]
        return self

    def _check_is_fitted(self):
        # NotFittedError also derives from AttributeError, the exception
        # previously raised when accessing a missing *fct_*.
        if not hasattr(self, 'fct_'):
            raise NotFittedError(  # pragma: no cover
                f"This instance {type(self)} is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this method.")

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor.
        """
        self._check_is_fitted()
        if isinstance(self.fct_inv_, str):
            # predefined function, the reciprocal is looked up by name
            res = FunctionReciprocalTransformer(self.fct_inv_)
        else:
            # custom functions, both roles are swapped
            res = FunctionReciprocalTransformer(self.fct_inv_, self.fct_)
        return res.fit()

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        *X* is returned unchanged, the function is only applied on *y*.
        If *y* is None, the returned value for *y*
        is None as well.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        return X, self.fct_(y)

95 

96 

class PermutationReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to permute targets,
    predict, then permute the target back before scoring.
    nan values remain nan values. Once fitted, the transform
    has attribute ``permutation_`` which keeps
    track of the permutation to apply.
    """

    def __init__(self, random_state=None, closest=False):
        """
        @param      random_state    random state
        @param      closest         if True, finds the closest permuted element
        """
        BaseReciprocalTransformer.__init__(self)
        self.random_state = random_state
        self.closest = closest

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Defines a random permutation over the targets.
        Every distinct value of *y* (nan excluded) is mapped
        to an integer in ``[0, number of distinct values)``.
        *X* and *sample_weight* are not used. Returns *self*.
        Raises *RuntimeError* if *y* is None.
        """
        if y is None:
            raise RuntimeError(  # pragma: no cover
                "targets cannot be empty.")
        num = numpy.issubdtype(y.dtype, numpy.floating)
        # rank of every distinct value in order of first appearance
        ranks = {}
        for u in y.ravel():
            if num and numpy.isnan(u):
                continue
            if u in ranks:
                continue
            ranks[u] = len(ranks)

        lin = numpy.arange(len(ranks))
        if self.random_state is None:
            lin = numpy.random.permutation(lin)
        else:
            rs = numpy.random.RandomState(  # pylint: disable=E1101
                self.random_state)  # pylint: disable=E1101
            lin = rs.permutation(lin)

        self.permutation_ = {u: lin[rank] for u, rank in ranks.items()}
        # The nearest neighbours index built by _find_closest depends
        # on the keys of the permutation, it must not survive a new fit.
        self.__dict__.pop('knn_', None)
        self.__dict__.pop('knn_perm_', None)
        return self

    def _check_is_fitted(self):
        if not hasattr(self, 'permutation_'):
            raise NotFittedError(  # pragma: no cover
                f"This instance {type(self)} is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this method.")

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor. Its permutation is the inverse
        of ``permutation_``.
        """
        self._check_is_fitted()
        res = PermutationReciprocalTransformer(
            self.random_state, closest=self.closest)
        res.permutation_ = {v: k for k, v in self.permutation_.items()}
        return res

    def _find_closest(self, cl):
        # Returns the key of *permutation_* which is the closest to *cl*.
        # The index is built once and cached in *knn_*, *knn_perm_*.
        if not hasattr(self, 'knn_'):
            keys = numpy.array(list(self.permutation_))
            keys = keys.reshape((len(keys), 1))
            knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
            knn.fit(keys)
            # attributes are only set once the index is usable
            self.knn_perm_ = keys
            self.knn_ = knn
        ind = self.knn_.kneighbors([[cl]], return_distance=False)
        # *ind* has shape (1, 1): a scalar is extracted explicitly because
        # converting an array with ndim > 0 to a scalar is deprecated.
        res = self.knn_perm_[ind[0, 0], 0]
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.floating):
            return float(res)
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.integer):
            return int(res)
        raise NotImplementedError(  # pragma: no cover
            f"The function does not work for type {self.knn_perm_.dtype}.")

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        *X* is returned unchanged.
        If *y* is None, the returned value for *y*
        is None as well.

        If *y* is a vector or contains integers or strings, it is
        considered as labels and every value is replaced by its permuted
        value (nan values are left unchanged). Otherwise, *y* must be
        a matrix (probabilities or raw scores) and its columns are permuted.
        Raises *RuntimeError* if a label is unknown and *closest* is False
        or if *y* is neither labels nor a matrix.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        # ``numpy.str`` was removed in numpy 1.24,
        # the kind of the dtype is checked instead.
        is_label = (len(y.shape) == 1 or y.dtype.kind in 'US' or
                    numpy.issubdtype(y.dtype, numpy.integer))
        if is_label:
            # permutes classes
            yp = y.copy().ravel()
            num = numpy.issubdtype(y.dtype, numpy.floating)
            # values are read from *y*, results are written into the copy
            for i, value in enumerate(y.ravel()):
                if num and numpy.isnan(value):
                    continue
                if value in self.permutation_:
                    cl = value
                elif self.closest:
                    cl = self._find_closest(value)
                else:
                    raise RuntimeError(
                        f"Unable to find key {value!r} in "
                        f"{list(sorted(self.permutation_))!r}.")
                yp[i] = self.permutation_[cl]
            return X, yp.reshape(y.shape)

        # y is probabilities or raw score
        if len(y.shape) != 2:
            raise RuntimeError(
                f"yp should be a matrix but has shape {y.shape}.")
        # keys sorted by permuted value, the rank becomes the new column
        ordered = sorted((v, k) for k, v in self.permutation_.items())
        new_perm = {key: rank for rank, (_, key) in enumerate(ordered)}
        yp = y.copy()
        # NOTE(review): column indices are used as keys, this presumably
        # expects the fitted targets to be 0..n-1 -- confirm with callers.
        for i in range(y.shape[1]):
            yp[:, new_perm[i]] = y[:, i]
        return X, yp