Coverage for mlinsights/sklapi/sklearn_base_transform_stacking.py: 96%

73 statements  


# -*- coding: utf-8 -*-
"""
@file
@brief Implements a *transform* which follows the same API as any
:epkg:`scikit-learn` transform.
"""
import textwrap
import numpy
from .sklearn_base_transform import SkBaseTransform
from .sklearn_base_transform_learner import SkBaseTransformLearner


class SkBaseTransformStacking(SkBaseTransform):
    """
    A *transform* which hides several *learners*, arranged
    according to the `stacking <http://blog.kaggle.com/2016/12/27/a-kagglers-guide-to-model-stacking-in-practice/>`_ method.

    .. exref::
        :title: Stacking several learners in a scikit-learn pipeline.
        :tag: sklearn
        :lid: ex-pipe2learner2

        This *transform* assembles the outputs of several learners.
        These features are used as input to a stacking model.

        .. runpython::
            :showcode:
            :warningout: FutureWarning

            from sklearn.model_selection import train_test_split
            from sklearn.datasets import load_iris
            from sklearn.linear_model import LogisticRegression
            from sklearn.tree import DecisionTreeClassifier
            from sklearn.metrics import accuracy_score
            from sklearn.pipeline import make_pipeline
            from mlinsights.sklapi import SkBaseTransformStacking

            data = load_iris()
            X, y = data.data, data.target
            X_train, X_test, y_train, y_test = train_test_split(X, y)

            trans = SkBaseTransformStacking([LogisticRegression(),
                                             DecisionTreeClassifier()])
            trans.fit(X_train, y_train)
            pred = trans.transform(X_test)
            print(pred[3:])
    """

    def __init__(self, models=None, method=None, **kwargs):
        """
        @param      models      list of learners
        @param      method      method or list of methods (one per model)
                                to call to convert features into predictions
                                (see below)
        @param      kwargs      parameters

        Available options for parameter *method*:

        * ``'predict'``
        * ``'predict_proba'``
        * ``'decision_function'``
        * a function

        If *method* is None, ``'predict'`` is used for every model.
        """
        super().__init__(**kwargs)
        if models is None:
            raise ValueError("models cannot be None")  # pragma: no cover
        if not isinstance(models, list):
            raise TypeError(  # pragma: no cover
                f"models must be a list not {type(models)}")
        if method is None:
            method = 'predict'
        if not isinstance(method, (str, list)):
            raise TypeError(  # pragma: no cover
                f"method must be a string or a list not {type(method)}")
        self.method = method
        if isinstance(method, list):
            if len(method) != len(models):  # pragma: no cover
                raise ValueError(
                    f"models and methods must have the same "
                    f"length: {len(models)} != {len(method)}.")
        else:
            method = [method for m in models]

        def convert2transform(c, new_learners):
            "converts a learner into a transform"
            m, me = c
            if isinstance(m, SkBaseTransformLearner):
                if me == m.method:
                    return m
                res = SkBaseTransformLearner(m.model, me)
                new_learners.append(res)
                return res
            if hasattr(m, 'transform'):
                return m
            res = SkBaseTransformLearner(m, me)
            new_learners.append(res)
            return res

        new_learners = []
        res = list(map(lambda c: convert2transform(
            c, new_learners), zip(models, method)))
        if len(new_learners) == 0:
            # Reuse the models given by the user when no wrapping was
            # needed: creating new objects in the constructor is not
            # supported anymore by scikit-learn (see sklearn/base.py).
            self.models = models
        else:
            self.models = res
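
    # Illustration (assumed usage, not part of the original source): only
    # estimators without a ``transform`` method are wrapped. Something like
    #   SkBaseTransformStacking([LogisticRegression()], 'decision_function')
    # internally stores
    #   [SkBaseTransformLearner(LogisticRegression(), 'decision_function')]
    # while a model which already exposes ``transform`` is kept unchanged.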

    def fit(self, X, y=None, **kwargs):
        """
        Trains a model.

        @param      X       features
        @param      y       targets
        @param      kwargs  additional parameters
        @return             self
        """
        for m in self.models:
            m.fit(X, y=y, **kwargs)
        return self

    def transform(self, X):
        """
        Calls the learners' predictions to convert
        the features.

        @param      X   features
        @return         predictions
        """
        Xs = [m.transform(X) for m in self.models]
        return numpy.hstack(Xs)
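
    # Illustration (assumed shapes, not from the original source): transform
    # concatenates one block of columns per wrapped model. Assuming each
    # wrapped model returns a 2-D array, e.g. 'predict_proba' on two 3-class
    # iris classifiers, the stacked output has shape (n_samples, 3 + 3).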

    ##############
    # cloning API
    ##############

    def get_params(self, deep=True):
        """
        Returns the parameters which define the object.
        It follows :epkg:`scikit-learn` API.

        @param      deep    if True, the parameters of the wrapped
                            models are included as well
        @return             dict
        """
        res = self.P.to_dict()
        res['models'] = self.models
        res['method'] = self.method
        if deep:
            for i, m in enumerate(self.models):
                par = m.get_params(deep)
                for k, v in par.items():
                    res[f"models_{i}__" + k] = v
        return res

    def set_params(self, **values):
        """
        Sets the parameters.

        @param      values  parameters; keys prefixed with
                            ``models_<i>__`` are forwarded to the i-th model
        """
        if 'models' in values:
            self.models = values['models']
            del values['models']
        if 'method' in values:
            self.method = values['method']
            del values['method']
        for k, v in values.items():
            if not k.startswith('models_'):
                raise ValueError(  # pragma: no cover
                    f"Parameter '{k}' must start with 'models_'.")
        d = len('models_')
        pars = [{} for m in self.models]
        for k, v in values.items():
            si = k[d:].split('__', 1)
            i = int(si[0])
            pars[i][si[1]] = v
        for p, m in zip(pars, self.models):
            if p:
                m.set_params(**p)
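
    # Illustration (assumed call, not from the original source): a key such as
    # ``models_1__<param>`` is split into the model index (1) and the parameter
    # name, so ``{'<param>': value}`` ends up in ``self.models[1].set_params``;
    # this mirrors the ``models_{i}__`` prefix produced by get_params above.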

    #################
    # common methods
    #################

    def __repr__(self):
        """
        usual
        """
        rps = repr(self.P)
        res = "{0}([{1}], [{2}], {3})".format(
            self.__class__.__name__,
            ", ".join(repr(m.model if hasattr(m, 'model') else m)
                      for m in self.models),
            ", ".join(repr(m.method if hasattr(m, 'method') else None)
                      for m in self.models),
            rps)
        return "\n".join(textwrap.wrap(res, subsequent_indent=" "))
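
A minimal end-to-end sketch of the stacking pattern this class implements (not part
of the module above; the dataset, estimators and ``max_iter`` value are only
assumptions for illustration): the two level-0 classifiers are turned into features
with ``predict_proba`` and a logistic regression is trained on top of them.

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline
    from sklearn.tree import DecisionTreeClassifier
    from mlinsights.sklapi import SkBaseTransformStacking

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # level 0: the predictions of both classifiers become the features,
    # level 1: a logistic regression is trained on those stacked features
    pipe = make_pipeline(
        SkBaseTransformStacking([LogisticRegression(max_iter=1000),
                                 DecisionTreeClassifier()],
                                'predict_proba'),
        LogisticRegression(max_iter=1000))
    pipe.fit(X_train, y_train)
    print(accuracy_score(y_test, pipe.predict(X_test)))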