Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to visualize a pipeline.
4"""
5import pprint
6from collections import OrderedDict
7import numpy
8import pandas
9from sklearn.base import TransformerMixin, ClassifierMixin, RegressorMixin
10from sklearn.pipeline import Pipeline, FeatureUnion
11from sklearn.compose import ColumnTransformer, TransformedTargetRegressor
12from ..helpers.pipeline import enumerate_pipeline_models
def _pipeline_info(pipe, data, context, former_data=None):
    """
    Internal function to convert a pipeline into
    some graph.

    The function walks through *pipe* recursively and returns
    a flat list of nodes. Every node is a dictionary with keys
    ``'name'``, ``'type'`` (``'transform'``, ``'classifier'`` or
    ``'regressor'``), ``'inputs'`` and ``'outputs'``.

    @param      pipe            *scikit-learn* model or pipeline, or
                                the keyword ``'passthrough'``
    @param      data            input variables, either an ordered
                                dictionary ``{column name: value}``
                                (values are DOT ports such as ``'sch0:f0'``
                                when called from *pipeline2dot*) or a list
                                of variable names (outputs of a previous step)
    @param      context         dictionary with keys ``'n'`` (counter used
                                to build new names) and ``'names'`` (names
                                already used), modified inplace
    @param      former_data     variables the current step was built from,
                                used to resolve integer names
    @return                     list of nodes

    Transformers declared as ``'drop'`` and empty column selections
    are skipped, as *scikit-learn* does.

    @raises     NotImplementedError     for unsupported models or keywords
    """
    def _get_name(context, prefix='-v-', info=None, data=None):
        # Builds a new unique name starting with *prefix* and registers
        # it in context['names'] (*data* is unused).
        if info is None:
            raise RuntimeError("info should not be None")  # pragma: no cover
        if isinstance(prefix, list):
            return [_get_name(context, el, info, data) for el in prefix]
        if isinstance(prefix, int):
            prefix = former_data[prefix]
            if isinstance(prefix, int):
                raise TypeError(  # pragma: no cover
                    "prefix must be a string.\ninfo={}".format(info))
        sug = "%s%d" % (prefix, context['n'])
        while sug in context['names']:
            context['n'] += 1
            sug = "%s%d" % (prefix, context['n'])
        context['names'][sug] = info
        return sug

    def _get_name_simple(name, data):
        # Resolves a column given by its position in *data* into its name.
        if isinstance(name, str):
            return name
        res = data[name]
        if isinstance(res, int):
            raise RuntimeError(  # pragma: no cover
                "Column name is still a number and not a name: {} and {}."
                "".format(name, data))
        return res

    def _select(items, indices):
        # Picks items[i] for every index; an index beyond the known items
        # maps to the last one because a previous step may have merged
        # several columns into a single variable.
        return [items[i] if i < len(items) else items[-1] for i in indices]

    def _is_drop(model):
        # scikit-learn uses the keyword 'drop' to discard columns.
        return isinstance(model, str) and model == 'drop'

    def _single_model(model, inputs, kind, outputs=None):
        # One node for *model*. When it receives several variables, a
        # 'union' node first merges them into a single new variable.
        # outputs=None means a transformer: it keeps its single input name
        # or produces a new variable.
        info = {'name': model.__class__.__name__, 'type': kind}
        if len(inputs) == 1:
            info['inputs'] = inputs
            info['outputs'] = inputs if outputs is None else outputs
            return [info]
        info['inputs'] = [_get_name(context, info=info)]
        if outputs is None:
            outputs = [_get_name(context, info=info)]
        info['outputs'] = outputs
        return [{'name': 'union', 'outputs': info['inputs'],
                 'inputs': inputs, 'type': 'transform'}, info]

    if isinstance(pipe, Pipeline):
        infos = []
        for _, model in pipe.steps:
            info = _pipeline_info(model, data, context)
            # every step consumes the outputs of the previous one
            data = info[-1]["outputs"]
            infos.extend(info)
        return infos

    if isinstance(pipe, ColumnTransformer):
        infos = []
        outputs = []
        # input columns selected by any transformer (including dropped
        # ones), they are excluded from the remainder
        consumed = set()
        n_active = 0
        for _, model, vs in pipe.transformers:
            if isinstance(vs, (str, int)):
                # a single column may be given as a scalar
                vs = [vs]
            if all(isinstance(o, int) for o in vs):
                # columns given by position
                if isinstance(data, dict):
                    consumed.update(_select(list(data), vs))
                    new_data = _select(list(data.values()), vs)
                else:
                    new_data = _select(data, vs)
                    consumed.update(new_data)
            else:
                # columns given by name
                new_data = OrderedDict(
                    (v, data.get(v, v) if isinstance(data, dict) else v)
                    for v in vs)
                consumed.update(vs)
            if _is_drop(model) or len(new_data) == 0:
                # scikit-learn produces nothing for these
                continue
            n_active += 1
            info = _pipeline_info(
                model, new_data, context, former_data=new_data)
            outputs.extend(info[-1]['outputs'])
            infos.extend(info)

        final_hat = False
        if pipe.remainder == "passthrough":
            if isinstance(data, dict):
                new_data = OrderedDict(
                    (k, v) for k, v in data.items() if k not in consumed)
            else:
                new_data = [d for d in data if d not in consumed]
            if new_data:
                info = _pipeline_info(
                    "passthrough", new_data, context, former_data=new_data)
                outputs.extend(info[-1]['outputs'])
                infos.extend(info)
                final_hat = True

        if n_active > 1 or final_hat:
            info = {'name': 'union', 'inputs': outputs, 'type': 'transform'}
            info['outputs'] = [_get_name(context, info=info)]
            infos.append(info)
        return infos

    if isinstance(pipe, FeatureUnion):
        infos = []
        outputs = []
        n_active = 0
        for _, model in pipe.transformer_list:
            if _is_drop(model):
                continue
            n_active += 1
            info = _pipeline_info(model, data, context)
            # every branch gets its own output names
            new_outputs = [_get_name(context, prefix=o, info=info)
                           for o in info[-1]['outputs']]
            outputs.extend(new_outputs)
            info[-1]['outputs'] = new_outputs
            infos.extend(info)
        if n_active > 1:
            info = {'name': 'union', 'inputs': outputs, 'type': 'transform'}
            info['outputs'] = [_get_name(context, info=info)]
            infos.append(info)
        return infos

    if isinstance(pipe, TransformedTargetRegressor):
        raise NotImplementedError(  # pragma: no cover
            "Not yet implemented for TransformedTargetRegressor.")

    if isinstance(pipe, TransformerMixin):
        return _single_model(pipe, data, 'transform')

    if isinstance(pipe, ClassifierMixin):
        return _single_model(pipe, data, 'classifier',
                             ['PredictedLabel', 'Probabilities'])

    if isinstance(pipe, RegressorMixin):
        return _single_model(pipe, data, 'regressor', ['Prediction'])

    if isinstance(pipe, str):
        if pipe != "passthrough":
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented for keyword '{}'.".format(pipe))
        info = {'name': 'Identity', 'type': 'transform'}
        info['inputs'] = [_get_name_simple(n, former_data) for n in data]
        if isinstance(data, dict) and len(data) > 1:
            info['outputs'] = [
                _get_name(context, data=k, info=info) for k in data]
        else:
            # outputs must always be a list of names
            info['outputs'] = [_get_name(context, data=data, info=info)]
        return [info]

    raise NotImplementedError(  # pragma: no cover
        "Not yet implemented for {}.".format(type(pipe)))
def pipeline2dot(pipe, data, **params):
    """
    Exports a *scikit-learn* pipeline to
    :epkg:`DOT` language. See :ref:`visualizepipelinerst`
    for an example.

    @param      pipe        *scikit-learn* pipeline
    @param      data        training data as a dataframe or a numpy array,
                            or just a list with the variable names
    @param      params      additional params to draw the graph
    @return                 string

    Default options for the graph are:

    ::

        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
        }

    Only options ``orientation``, ``pad``, ``nodesep`` and ``ranksep``
    are written in the graph header. Option ``fontsize`` (default 8)
    sets the font size of the records; model nodes use 1.5 times that size.

    @raises     NotImplementedError     if *data* is a numpy array
                                        with a dimension other than 2
    @raises     TypeError               if *data* has an unexpected type
    """
    raw_data = data
    if isinstance(raw_data, pandas.DataFrame):
        col_names = list(raw_data.columns)
    elif isinstance(raw_data, numpy.ndarray):
        if len(raw_data.shape) != 2:
            raise NotImplementedError(  # pragma: no cover
                "Unexpected training data dimension: {}.".format(
                    raw_data.shape))
        col_names = ['X%d' % i for i in range(raw_data.shape[1])]
    elif isinstance(raw_data, list):
        # the caller directly gives the variable names
        col_names = raw_data
    else:
        raise TypeError(  # pragma: no cover
            "Unexpected data type: {}.".format(type(raw_data)))
    # maps every input column to its port in the first record 'sch0'
    data = OrderedDict()
    for k, name in enumerate(col_names):
        data[name] = 'sch0:f%d' % k

    options = {
        'orientation': 'portrait',
        'ranksep': '0.25',
        'nodesep': '0.05',
        'width': '0.5',
        'height': '0.1',
    }
    options.update(params)
    fontsize = options.get('fontsize', 8)

    exp = ["digraph{"]
    for opt in ['orientation', 'pad', 'nodesep', 'ranksep']:
        if opt in options:
            exp.append(" {}={};".format(opt, options[opt]))

    info = [dict(schema_after=data)]
    # input columns are registered so that no new name collides with them
    names = OrderedDict((d, info) for d in data)
    info.extend(_pipeline_info(pipe, data, context=dict(n=0, names=names)))

    # variable name -> DOT port of the record which holds it
    columns = OrderedDict()
    labs = []
    for c, col in enumerate(info[0]['schema_after']):
        columns[col] = 'sch0:f{0}'.format(c)
        labs.append("<f{0}> {1}".format(c, col))
    exp.append(' sch0[label="{0}",shape=record,fontsize={1}];'.format(
        "|".join(labs), fontsize))

    for i, line in enumerate(info[1:], start=1):
        exp.append('')
        color = 'cyan' if line['type'] == 'transform' else 'yellow'
        exp.append(
            ' node{0}[label="{1}",shape=box,style="filled,rounded",'
            'color={2},fontsize={3}];'.format(
                i, line['name'], color, int(fontsize * 1.5)))

        for inp in line['inputs']:
            if isinstance(inp, int):
                raise IndexError(  # pragma: no cover
                    "Unable to guess columns {} in\n{}\n---\n{}".format(
                        inp, pprint.pformat(columns), '\n'.join(exp)))
            exp.append(' {0} -> node{1};'.format(columns.get(inp, inp), i))

        labs = []
        for c, out in enumerate(line['outputs']):
            columns[out] = 'sch{0}:f{1}'.format(i, c)
            labs.append("<f{0}> {1}".format(c, out))
        exp.append(' sch{0}[label="{1}",shape=record,fontsize={2}];'.format(
            i, "|".join(labs), fontsize))

        # an output edge starts from node{i}, so duplicates can only
        # come from the same line (repeated output names)
        seen = set()
        for out in line['outputs']:
            edge = ' node{1} -> {0};'.format(columns[out], i)
            if edge not in seen:
                seen.add(edge)
                exp.append(edge)

    exp.append('}')
    return "\n".join(exp)
def pipeline2str(pipe, indent=3):
    """
    Exports a *scikit-learn* pipeline to text.

    @param      pipe        *scikit-learn* pipeline
    @param      indent      number of spaces added for every nesting level
    @return                 str, one line per model, the selected columns
                            are appended between parentheses when known

    .. runpython::
        :showcode:

        from sklearn.linear_model import LogisticRegression
        from sklearn.impute import SimpleImputer
        from sklearn.preprocessing import OneHotEncoder
        from sklearn.preprocessing import StandardScaler, MinMaxScaler
        from sklearn.compose import ColumnTransformer
        from sklearn.pipeline import Pipeline

        from mlinsights.plotting import pipeline2str

        numeric_features = ['age', 'fare']
        numeric_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler())])

        categorical_features = ['embarked', 'sex', 'pclass']
        categorical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore'))])

        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, numeric_features),
                ('cat', categorical_transformer, categorical_features),
            ])

        clf = Pipeline(steps=[('preprocessor', preprocessor),
                              ('classifier', LogisticRegression(solver='lbfgs'))])
        text = pipeline2str(clf)
        print(text)
    """
    def _describe(coor, model, vs):
        # Indents the class name by depth and appends the columns, if any.
        label = (" " * indent) * (len(coor) - 1) + model.__class__.__name__
        if vs is None:
            return label
        return label + "(" + ",".join(str(v) for v in vs) + ")"

    return "\n".join(
        _describe(*item) for item in enumerate_pipeline_models(pipe))