Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Helpers to visualize a pipeline. 

4""" 

5import pprint 

6from collections import OrderedDict 

7import numpy 

8import pandas 

9from sklearn.base import TransformerMixin, ClassifierMixin, RegressorMixin 

10from sklearn.pipeline import Pipeline, FeatureUnion 

11from sklearn.compose import ColumnTransformer, TransformedTargetRegressor 

12from ..helpers.pipeline import enumerate_pipeline_models 

13 

14 

def _pipeline_info(pipe, data, context, former_data=None):
    """
    Internal function to convert a pipeline into
    some graph.

    @param pipe *scikit-learn* object (pipeline, transformer,
                predictor) or the string ``'passthrough'``
    @param data current mapping from column names to graph node
                identifiers (:class:`OrderedDict`) or a plain list
                of names
    @param former_data previous mapping, used to resolve integer
                column indices back into names
    @param context dictionary with keys ``'n'`` (counter used to
                build unique names) and ``'names'`` (names already
                assigned), mutated in place
    @return list of dictionaries, one per graph node, with keys
                ``'name'``, ``'type'``, ``'inputs'``, ``'outputs'``
    """
    def _get_name(context, prefix='-v-', info=None, data=None):
        # Produces a name unique across the whole graph by appending
        # an increasing counter to *prefix*; registers it in
        # context['names'] so later calls cannot reuse it.
        if info is None:
            raise RuntimeError("info should not be None")  # pragma: no cover
        if isinstance(prefix, list):
            return [_get_name(context, el, info, data) for el in prefix]
        if isinstance(prefix, int):
            # an integer is a column index: resolve it through former_data
            prefix = former_data[prefix]
        if isinstance(prefix, int):
            raise TypeError(  # pragma: no cover
                "prefix must be a string.\ninfo={}".format(info))
        sug = "%s%d" % (prefix, context['n'])
        while sug in context['names']:
            context['n'] += 1
            sug = "%s%d" % (prefix, context['n'])
        context['names'][sug] = info
        return sug

    def _get_name_simple(name, data):
        # Maps an integer column index to its name; strings pass through.
        if isinstance(name, str):
            return name
        res = data[name]
        if isinstance(res, int):
            raise RuntimeError(  # pragma: no cover
                "Column name is still a number and not a name: {} and {}."
                "".format(name, data))
        return res

    if isinstance(pipe, Pipeline):
        # Chain the steps: outputs of each step become the inputs
        # of the next one.
        infos = []
        for _, model in pipe.steps:
            info = _pipeline_info(model, data, context)
            data = info[-1]["outputs"]
            infos.extend(info)
        return infos

    if isinstance(pipe, ColumnTransformer):
        infos = []
        outputs = []
        for _, model, vs in pipe.transformers:
            if all(isinstance(o, int) for o in vs):
                # Columns selected by position.
                new_data = []
                if isinstance(data, OrderedDict):
                    new_data = [_[1] for _ in data.items()]
                else:
                    mx = max(vs)
                    while len(new_data) < mx:
                        if len(data) > len(new_data):
                            new_data.append(data[len(new_data)])
                        else:
                            new_data.append(data[-1])
            else:
                # Columns selected by name.
                new_data = OrderedDict()
                for v in vs:
                    new_data[v] = data.get(v, v)

            info = _pipeline_info(
                model, new_data, context, former_data=new_data)
            outputs.extend(info[-1]['outputs'])
            infos.extend(info)

        final_hat = False
        if pipe.remainder == "passthrough":
            # Columns not consumed by any transformer are forwarded as is.
            # BUG FIX: the original iterated ``info`` (last transformer
            # only) instead of ``infos`` (all of them), and called
            # ``merged.union(d)`` whose result was discarded (set.union
            # is not in-place), so earlier transformers' inputs leaked
            # into the remainder.
            done = [set(d['inputs']) for d in infos]
            merged = done[0]
            for d in done[1:]:
                merged |= d
            new_data = OrderedDict(
                [(k, v) for k, v in data.items() if k not in merged])

            info = _pipeline_info(
                "passthrough", new_data, context, former_data=new_data)
            outputs.extend(info[-1]['outputs'])
            infos.extend(info)
            final_hat = True

        if len(pipe.transformers) > 1 or final_hat:
            # A union node gathers the concatenated outputs.
            info = {'name': 'union', 'inputs': outputs, 'type': 'transform'}
            info['outputs'] = [_get_name(context, info=info)]
            infos.append(info)
        return infos

    if isinstance(pipe, FeatureUnion):
        infos = []
        outputs = []
        for _, model in pipe.transformer_list:
            info = _pipeline_info(model, data, context)
            # Rename every output so the concatenated branches do not clash.
            new_outputs = []
            for o in info[-1]['outputs']:
                add = _get_name(context, prefix=o, info=info)
                outputs.append(add)
                new_outputs.append(add)
            info[-1]['outputs'] = new_outputs
            infos.extend(info)
        if len(pipe.transformer_list) > 1:
            info = {'name': 'union', 'inputs': outputs, 'type': 'transform'}
            info['outputs'] = [_get_name(context, info=info)]
            infos.append(info)
        return infos

    if isinstance(pipe, TransformedTargetRegressor):
        raise NotImplementedError(  # pragma: no cover
            "Not yet implemented for TransformedTargetRegressor.")

    if isinstance(pipe, TransformerMixin):
        info = {'name': pipe.__class__.__name__, 'type': 'transform'}
        if len(data) == 1:
            info['outputs'] = data
            info['inputs'] = data
            info = [info]
        else:
            # Several inputs: insert a union node merging them first.
            info['inputs'] = [_get_name(context, info=info)]
            info['outputs'] = [_get_name(context, info=info)]
            info = [{'name': 'union', 'outputs': info['inputs'],
                     'inputs': data, 'type': 'transform'}, info]
        return info

    if isinstance(pipe, ClassifierMixin):
        info = {'name': pipe.__class__.__name__, 'type': 'classifier'}
        exp = ['PredictedLabel', 'Probabilities']
        if len(data) == 1:
            info['outputs'] = exp
            info['inputs'] = data
            info = [info]
        else:
            info['outputs'] = exp
            info['inputs'] = [_get_name(context, info=info)]
            info = [{'name': 'union', 'outputs': info['inputs'], 'inputs': data,
                     'type': 'transform'}, info]
        return info

    if isinstance(pipe, RegressorMixin):
        info = {'name': pipe.__class__.__name__, 'type': 'regressor'}
        exp = ['Prediction']
        if len(data) == 1:
            info['outputs'] = exp
            info['inputs'] = data
            info = [info]
        else:
            info['outputs'] = exp
            info['inputs'] = [_get_name(context, info=info)]
            info = [{'name': 'union', 'outputs': info['inputs'], 'inputs': data,
                     'type': 'transform'}, info]
        return info

    if isinstance(pipe, str):
        if pipe == "passthrough":
            info = {'name': 'Identity', 'type': 'transform'}
            info['inputs'] = [_get_name_simple(n, former_data) for n in data]
            if isinstance(data, (OrderedDict, dict)) and len(data) > 1:
                info['outputs'] = [
                    _get_name(context, data=k, info=info)
                    for k in data]
            else:
                # NOTE(review): unlike every other branch this stores a
                # plain string, not a list — TODO confirm downstream code
                # relies on it before changing.
                info['outputs'] = _get_name(context, data=data, info=info)
            info = [info]
        else:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented for keyword '{}'.".format(type(pipe)))
        return info

    raise NotImplementedError(  # pragma: no cover
        "Not yet implemented for {}.".format(type(pipe)))

189 

190 

def pipeline2dot(pipe, data, **params):
    """
    Exports a *scikit-learn* pipeline to
    :epkg:`DOT` language. See :ref:`visualizepipelinerst`
    for an example.

    @param      pipe        *scikit-learn* pipeline
    @param      data        training data as a dataframe or a numpy array,
                            or just a list with the variable names
    @param      params      additional params to draw the graph
    @return                 string

    Default options for the graph are:

    ::

        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
        }
    """
    # Build the initial column mapping: column name -> DOT record port.
    raw_data = data
    data = OrderedDict()
    if isinstance(raw_data, pandas.DataFrame):
        for k, c in enumerate(raw_data.columns):
            data[c] = 'sch0:f%d' % k
    elif isinstance(raw_data, numpy.ndarray):
        if len(raw_data.shape) != 2:
            # BUG FIX: the message formatted ``data.shape`` but *data*
            # was already rebound to an OrderedDict (no ``shape``); use
            # the original array instead.
            raise NotImplementedError(  # pragma: no cover
                "Unexpected training data dimension: {}.".format(
                    raw_data.shape))
        for i in range(raw_data.shape[1]):
            data['X%d' % i] = 'sch0:f%d' % i
    elif not isinstance(raw_data, list):
        raise TypeError(  # pragma: no cover
            "Unexpected data type: {}.".format(type(raw_data)))

    options = {
        'orientation': 'portrait',
        'ranksep': '0.25',
        'nodesep': '0.05',
        'width': '0.5',
        'height': '0.1',
    }
    options.update(params)

    exp = ["digraph{"]
    for opt in ['orientation', 'pad', 'nodesep', 'ranksep']:
        if opt in options:
            exp.append("  {}={};".format(opt, options[opt]))
    fontsize = 8
    # First entry describes the input schema; _pipeline_info appends
    # one dictionary per graph node.
    info = [dict(schema_after=data)]
    names = OrderedDict()
    for d in data:
        names[d] = info
    info.extend(_pipeline_info(pipe, data, context=dict(n=0, names=names)))
    columns = OrderedDict()

    for i, line in enumerate(info):
        if i == 0:
            # Input schema rendered as a record node ``sch0``.
            schema = line['schema_after']
            labs = []
            for c, col in enumerate(schema):
                columns[col] = 'sch0:f{0}'.format(c)
                labs.append("<f{0}> {1}".format(c, col))
            node = '  sch0[label="{0}",shape=record,fontsize={1}];'.format(
                "|".join(labs), params.get('fontsize', fontsize))
            exp.append(node)
        else:
            exp.append('')
            # Transformers are cyan, predictors yellow.
            if line['type'] == 'transform':
                node = '  node{0}[label="{1}",shape=box,style="filled' \
                       ',rounded",color=cyan,fontsize={2}];'.format(
                           i, line['name'],
                           int(params.get('fontsize', fontsize) * 1.5))
            else:
                node = '  node{0}[label="{1}",shape=box,style="filled,' \
                       'rounded",color=yellow,fontsize={2}];'.format(
                           i, line['name'],
                           int(params.get('fontsize', fontsize) * 1.5))
            exp.append(node)

            # Edges from every input port to this node.
            for inp in line['inputs']:
                if isinstance(inp, int):
                    raise IndexError(  # pragma: no cover
                        "Unable to guess columns {} in\n{}\n---\n{}".format(
                            inp, pprint.pformat(columns), '\n'.join(exp)))
                else:
                    nc = columns.get(inp, inp)
                    edge = '  {0} -> node{1};'.format(nc, i)
                    exp.append(edge)

            # Record node with this node's outputs, then edges to it.
            labs = []
            for c, out in enumerate(line['outputs']):
                columns[out] = 'sch{0}:f{1}'.format(i, c)
                labs.append("<f{0}> {1}".format(c, out))
            node = '  sch{0}[label="{1}",shape=record,fontsize={2}];'.format(
                i, "|".join(labs), params.get('fontsize', fontsize))
            exp.append(node)

            for out in line['outputs']:
                nc = columns[out]
                edge = '  node{1} -> {0};'.format(nc, i)
                if edge not in exp:
                    exp.append(edge)

    exp.append('}')
    return "\n".join(exp)

302 

303 

def pipeline2str(pipe, indent=3):
    """
    Exports a *scikit-learn* pipeline to text.

    @param      pipe        *scikit-learn* pipeline
    @param      indent      number of spaces per nesting level
    @return                 str

    .. runpython::
        :showcode:

        from sklearn.linear_model import LogisticRegression
        from sklearn.impute import SimpleImputer
        from sklearn.preprocessing import OneHotEncoder
        from sklearn.preprocessing import StandardScaler, MinMaxScaler
        from sklearn.compose import ColumnTransformer
        from sklearn.pipeline import Pipeline

        from mlinsights.plotting import pipeline2str

        numeric_features = ['age', 'fare']
        numeric_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler())])

        categorical_features = ['embarked', 'sex', 'pclass']
        categorical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore'))])

        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, numeric_features),
                ('cat', categorical_transformer, categorical_features),
            ])

        clf = Pipeline(steps=[('preprocessor', preprocessor),
                              ('classifier', LogisticRegression(solver='lbfgs'))])
        text = pipeline2str(clf)
        print(text)
    """
    lines = []
    # One line per model, indented by its depth in the pipeline.
    for coordinates, model, subset in enumerate_pipeline_models(pipe):
        prefix = " " * (indent * (len(coordinates) - 1))
        label = model.__class__.__name__
        if subset is None:
            lines.append("{}{}".format(prefix, label))
        else:
            # subset lists the columns this model applies to
            cols = ','.join(str(c) for c in subset)
            lines.append("{}{}({})".format(prefix, label, cols))
    return "\n".join(lines)