Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Implements a transform which modifies the target 

4and applies the reverse transformation on the target. 

5""" 

6import numpy 

7from sklearn.exceptions import NotFittedError 

8from sklearn.neighbors import NearestNeighbors 

9from .sklearn_transform_inv import BaseReciprocalTransformer 

10 

11 

class FunctionReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to apply a function on the target,
    predict, then transform the target back before scoring.
    The transform implements a series of predefined functions:

    .. runpython::
        :showcode:

        import pprint
        from mlinsights.mlmodel.sklearn_transform_inv_fct import FunctionReciprocalTransformer
        pprint.pprint(FunctionReciprocalTransformer.available_fcts())
    """

    @staticmethod
    def available_fcts():
        """
        Returns the predefined functions as a dictionary
        ``{name: (function, name of the reciprocal function)}``.
        Every reciprocal name is a key of the same dictionary.
        """
        return {
            'log': (numpy.log, 'exp'),
            'exp': (numpy.exp, 'log'),
            'log(1+x)': (lambda x: numpy.log(x + 1), 'exp(x)-1'),
            'log1p': (numpy.log1p, 'expm1'),
            # the reciprocal of exp(x) - 1 is log(1 + x), not log(x)
            'exp(x)-1': (lambda x: numpy.exp(x) - 1, 'log(1+x)'),
            'expm1': (numpy.expm1, 'log1p'),
        }

    def __init__(self, fct, fct_inv=None):
        """
        @param      fct         function name or numerical function
        @param      fct_inv     must be None if *fct* is a function name,
                                reciprocal function otherwise

        Raises *ValueError* if *fct* is an unknown name, if *fct_inv*
        is given with a function name or missing with a callable.
        """
        BaseReciprocalTransformer.__init__(self)
        if isinstance(fct, str):
            if fct_inv is not None:
                raise ValueError(  # pragma: no cover
                    "If fct is a function name, fct_inv must not be specified.")
            opts = self.__class__.available_fcts()
            if fct not in opts:
                raise ValueError(  # pragma: no cover
                    "Unknown fct '{}', it should in {}.".format(
                        fct, list(sorted(opts))))
        else:
            if fct_inv is None:
                raise ValueError(
                    "If fct is callable, fct_inv must be specified.")
        self.fct = fct
        self.fct_inv = fct_inv

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Just defines *fct_* and *fct_inv_*. The data is not used.
        *fct_inv_* is a function name when *fct* is a function name,
        a callable otherwise. Returns *self*.
        """
        if callable(self.fct):
            self.fct_ = self.fct
            self.fct_inv_ = self.fct_inv
        else:
            opts = self.__class__.available_fcts()
            self.fct_, self.fct_inv_ = opts[self.fct]
        return self

    def _check_is_fitted(self):
        """
        Raises *NotFittedError* if method *fit* was not called.
        """
        if not hasattr(self, 'fct_'):
            raise NotFittedError(  # pragma: no cover
                "This instance {} is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this method.".format(
                    type(self)))

    def get_fct_inv(self):
        """
        Returns a trained transform which reverse the target
        after a predictor.
        """
        self._check_is_fitted()
        if isinstance(self.fct_inv_, str):
            # predefined function: the name is enough to build the reciprocal
            res = FunctionReciprocalTransformer(self.fct_inv_)
        else:
            # custom functions: swaps the function and its reciprocal
            res = FunctionReciprocalTransformer(self.fct_inv_, self.fct_)
        return res.fit()

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        *X* is returned unchanged, the function is applied on *y*.
        If *y* is None, the returned value for *y*
        is None as well.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        return X, self.fct_(y)

96 

97 

class PermutationReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to permute targets,
    predict, then permute the target back before scoring.
    nan values remain nan values. Once fitted, the transform
    has attribute ``permutation_`` which keeps
    track of the permutation to apply.
    """

    def __init__(self, random_state=None, closest=False):
        """
        @param      random_state    random state
        @param      closest         if True, finds the closest permuted element
        """
        BaseReciprocalTransformer.__init__(self)
        self.random_state = random_state
        self.closest = closest

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Defines a random permutation over the targets.
        Every distinct value of *y* (nan excluded) is mapped to
        a distinct integer in ``[0, number of distinct values)``.
        Returns *self*.
        """
        if y is None:
            raise RuntimeError(  # pragma: no cover
                "targets cannot be empty.")
        num = numpy.issubdtype(y.dtype, numpy.floating)

        # numbers distinct values by order of first appearance
        perm = {}
        for u in y.ravel():
            if num and numpy.isnan(u):
                continue
            if u in perm:
                continue
            perm[u] = len(perm)

        lin = numpy.arange(len(perm))
        if self.random_state is None:
            lin = numpy.random.permutation(lin)
        else:
            rs = numpy.random.RandomState(  # pylint: disable=E1101
                self.random_state)  # pylint: disable=E1101
            lin = rs.permutation(lin)

        for u in perm:
            perm[u] = lin[perm[u]]
        self.permutation_ = perm

        # The nearest neighbours cache is built from the keys of the
        # previous permutation, it must not survive a new call to fit.
        for name in ('knn_', 'knn_perm_'):
            if hasattr(self, name):
                delattr(self, name)
        return self

    def _check_is_fitted(self):
        """
        Raises *NotFittedError* if attribute ``permutation_`` is missing.
        """
        if not hasattr(self, 'permutation_'):
            raise NotFittedError(  # pragma: no cover
                "This instance {} is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this method.".format(
                    type(self)))

    def get_fct_inv(self):
        """
        Returns a trained transform which reverse the target
        after a predictor.
        """
        self._check_is_fitted()
        res = PermutationReciprocalTransformer(
            self.random_state, closest=self.closest)
        res.permutation_ = {v: k for k, v in self.permutation_.items()}
        return res

    def _find_closest(self, cl):
        """
        Returns the key of ``permutation_`` which is the closest
        to *cl*. The nearest neighbours model is built once
        and cached in ``knn_``, ``knn_perm_``.
        """
        if not hasattr(self, 'knn_'):
            keys = numpy.array(list(self.permutation_))
            keys = keys.reshape((len(keys), 1))
            knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
            knn.fit(keys)
            # the cache is stored only once the model is fitted
            self.knn_perm_ = keys
            self.knn_ = knn
        ind = self.knn_.kneighbors([[cl]], return_distance=False)
        # ind has shape (1, 1), picks the scalar before any conversion
        res = self.knn_perm_[ind[0, 0], 0]
        if self.knn_perm_.dtype in (numpy.float32, numpy.float64):
            return float(res)
        if self.knn_perm_.dtype in (numpy.int32, numpy.int64):
            return int(res)
        raise NotImplementedError(  # pragma: no cover
            "The function does not work for type {}.".format(
                self.knn_perm_.dtype))

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        If *y* is None, the returned value for *y*
        is None as well. *X* is returned unchanged.
        If *y* is a vector or contains strings or integers,
        every value is replaced by its permuted value.
        Otherwise, *y* must be a matrix and its columns are permuted.
        Raises *RuntimeError* if a value is unknown and *closest* is False.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        # numpy.str does not exist anymore (removed in numpy 1.24)
        labels = (len(y.shape) == 1 or
                  y.dtype in (numpy.int32, numpy.int64) or
                  numpy.issubdtype(y.dtype, numpy.str_))
        if labels:
            # permutes classes
            yp = y.copy().ravel()
            num = numpy.issubdtype(y.dtype, numpy.floating)
            for i in range(len(yp)):  # pylint: disable=C0200
                if num and numpy.isnan(yp[i]):
                    continue
                if yp[i] not in self.permutation_:
                    if self.closest:
                        cl = self._find_closest(yp[i])
                    else:
                        raise RuntimeError("Unable to find key '{}' in {}.".format(
                            yp[i], list(sorted(self.permutation_))))
                else:
                    cl = yp[i]
                yp[i] = self.permutation_[cl]
            return X, yp.reshape(y.shape)

        # y is probabilities or raw score
        if len(y.shape) != 2:
            raise RuntimeError(
                "yp should be a matrix but has shape {}.".format(y.shape))
        # ranks the keys by their permuted value,
        # column i of y goes to column new_perm[i]
        ranked = [(v, k) for k, v in self.permutation_.items()]
        ranked.sort()
        new_perm = {}
        for _, current in ranked:
            new_perm[current] = len(new_perm)
        yp = y.copy()
        for i in range(y.shape[1]):
            yp[:, new_perm[i]] = y[:, i]
        return X, yp

218 return X, yp