Coverage for mlprodict/plotting/text_plot.py: 92%
518 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# pylint: disable=R0912,R0914,C0302
2"""
3@file
4@brief Text representations of graphs.
5"""
6import pprint
7from collections import OrderedDict
8import numpy
9from onnx import TensorProto, AttributeProto
10from onnx.numpy_helper import to_array
11from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
12from ..tools.graphs import onnx2bigraph
13from ..onnx_tools.onnx2py_helper import _var_as_dict, get_tensor_shape
def onnx_text_plot(model_onnx, recursive=False, graph_type='basic',
                   grid=5, distance=5):
    """
    Uses @see fn onnx2bigraph to convert the ONNX graph
    into text.

    :param model_onnx: onnx representation
    :param recursive: @see fn onnx2bigraph
    :param graph_type: @see fn onnx2bigraph
    :param grid: @see me display_structure
    :param distance: @see fn display_structure
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnx_conv import to_onnx
        from mlprodict import __max_supported_opset__ as opv
        from mlprodict.plotting.plotting import onnx_text_plot
        from mlprodict.npy.xop import loadop

        OnnxAdd, OnnxSub = loadop('Add', 'Sub')

        idi = numpy.identity(2).astype(numpy.float32)
        A = OnnxAdd('X', idi, op_version=opv)
        B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
        onx = B.to_onnx({'X': idi, 'W': idi})
        print(onnx_text_plot(onx))
    """
    # BUGFIX: *recursive*, *graph_type*, *grid* and *distance* were documented
    # as forwarded but previously ignored, so callers could not affect the
    # conversion or the layout.  They are now passed through.
    bigraph = onnx2bigraph(
        model_onnx, recursive=recursive, graph_type=graph_type)
    graph = bigraph.display_structure(grid=grid, distance=distance)
    return graph.to_text()
def onnx_text_plot_tree(node):
    """
    Gives a textual representation of a tree ensemble.

    :param node: `TreeEnsemble*`
    :return: text

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeRegressor
        from mlprodict.onnx_conv import to_onnx
        from mlprodict.plotting.plotting import onnx_text_plot_tree

        iris = load_iris()
        X, y = iris.data.astype(numpy.float32), iris.target
        clr = DecisionTreeRegressor(max_depth=3)
        clr.fit(X, y)
        onx = to_onnx(clr, X)
        res = onnx_text_plot_tree(onx.graph.node[0])
        print(res)
    """
    def rule(r):
        # Maps the ONNX branching mode (bytes) onto its comparison operator.
        if r == b'BRANCH_LEQ':
            return '<='
        if r == b'BRANCH_LT':  # pragma: no cover
            return '<'
        if r == b'BRANCH_GEQ':  # pragma: no cover
            return '>='
        if r == b'BRANCH_GT':  # pragma: no cover
            return '>'
        if r == b'BRANCH_EQ':  # pragma: no cover
            return '=='
        if r == b'BRANCH_NEQ':  # pragma: no cover
            return '!='
        # BUGFIX: the original message formatted {rule!r} (the function
        # object) instead of the unexpected value {r!r}.
        raise ValueError(  # pragma: no cover
            f"Unexpected rule {r!r}.")

    class Node:
        "Node representation."

        def __init__(self, i, atts):
            # Optional attributes may be absent from *atts*.
            self.nodes_hitrates = None
            self.nodes_missing_value_tracks_true = None
            # Copies the i-th element of every 'nodes_*' attribute.
            for k, v in atts.items():
                if k.startswith('nodes'):
                    setattr(self, k, v[i])
            self.depth = 0
            self.true_false = ''

        def process_node(self):
            "node to string"
            if self.nodes_modes == b'LEAF':  # pylint: disable=E1101
                text = "%s y=%r f=%r i=%r" % (
                    self.true_false,
                    self.target_weights, self.target_ids,  # pylint: disable=E1101
                    self.target_nodeids)  # pylint: disable=E1101
            else:
                text = "%s X%d %s %r" % (
                    self.true_false,
                    self.nodes_featureids,  # pylint: disable=E1101
                    rule(self.nodes_modes),  # pylint: disable=E1101
                    self.nodes_values)  # pylint: disable=E1101
            # hitrates equal to 1 are the default and not worth displaying
            if self.nodes_hitrates and self.nodes_hitrates != 1:
                text += f" hi={self.nodes_hitrates!r}"
            if self.nodes_missing_value_tracks_true:
                text += f" miss={self.nodes_missing_value_tracks_true!r}"
            return f"{' ' * self.depth}{text}"

    def process_tree(atts, treeid):
        "tree to string"
        rows = [f'treeid={treeid!r}']
        if 'base_values' in atts:
            if treeid < len(atts['base_values']):
                rows.append(f"base_value={atts['base_values'][treeid]!r}")

        # Restricts every per-node/per-target attribute to this tree only.
        short = {}
        for prefix in ['nodes', 'target', 'class']:
            if (f'{prefix}_treeids') not in atts:
                continue
            idx = [i for i in range(len(atts[f'{prefix}_treeids']))
                   if atts[f'{prefix}_treeids'][i] == treeid]
            for k, v in atts.items():
                if k.startswith(prefix):
                    if 'classlabels' in k:
                        short[k] = list(v)
                    else:
                        short[k] = [v[i] for i in idx]

        nodes = OrderedDict()
        for i in range(len(short['nodes_treeids'])):
            nodes[i] = Node(i, short)
        # Regressors store 'target_*', classifiers store 'class_*'.
        prefix = 'target' if 'target_treeids' in short else 'class'
        for i in range(len(short[f'{prefix}_treeids'])):
            idn = short[f'{prefix}_nodeids'][i]
            node = nodes[idn]
            node.target_nodeids = idn
            node.target_ids = short[f'{prefix}_ids'][i]
            node.target_weights = short[f'{prefix}_weights'][i]

        def iterate(nodes, node, depth=0, true_false=''):
            # Depth-first traversal, false branch first, yielding each node
            # annotated with its depth and which branch it belongs to.
            node.depth = depth
            node.true_false = true_false
            yield node
            if node.nodes_falsenodeids > 0:
                for n in iterate(nodes, nodes[node.nodes_falsenodeids],
                                 depth=depth + 1, true_false='F'):
                    yield n
                for n in iterate(nodes, nodes[node.nodes_truenodeids],
                                 depth=depth + 1, true_false='T'):
                    yield n

        for node in iterate(nodes, nodes[0]):
            rows.append(node.process_node())
        return rows

    if node.op_type in ("TreeEnsembleRegressor", "TreeEnsembleClassifier"):
        d = {k: v['value'] for k, v in _var_as_dict(node)['atts'].items()}
        atts = {}
        for k, v in d.items():
            atts[k] = v if isinstance(v, int) else list(v)
        trees = list(sorted(set(atts['nodes_treeids'])))
        if 'n_targets' in atts:
            rows = [f"n_targets={atts['n_targets']!r}"]
        else:
            rows = ['n_classes=%r' % len(
                atts.get('classlabels_int64s',
                         atts.get('classlabels_strings', [])))]
        rows.append(f'n_trees={len(trees)!r}')
        for tree in trees:
            r = process_tree(atts, tree)
            rows.append('----')
            rows.extend(r)
        return "\n".join(rows)

    raise NotImplementedError(  # pragma: no cover
        f"Type {node.op_type!r} cannot be displayed.")
194def _append_succ_pred(subgraphs, successors, predecessors, node_map, node, prefix="",
195 parent_node_name=None):
196 node_name = prefix + node.name + "#" + "|".join(node.output)
197 node_map[node_name] = node
198 successors[node_name] = []
199 predecessors[node_name] = []
200 for name in node.input:
201 predecessors[node_name].append(name)
202 if name not in successors:
203 successors[name] = []
204 successors[name].append(node_name)
205 for name in node.output:
206 successors[node_name].append(name)
207 predecessors[name] = [node_name]
208 if node.op_type in {'If', 'Scan', 'Loop', 'Expression'}:
209 for att in node.attribute:
210 if (att.type != AttributeProto.GRAPH or # pylint: disable=E1101
211 not hasattr(att, 'g') or att.g is None):
212 continue
213 subgraphs.append((node, att.name, att.g))
214 _append_succ_pred_s(subgraphs, successors, predecessors, node_map,
215 att.g.node, prefix=node_name + ":/:",
216 parent_node_name=node_name,
217 parent_graph=att.g)
220def _append_succ_pred_s(subgraphs, successors, predecessors, node_map, nodes, prefix="",
221 parent_node_name=None, parent_graph=None):
222 for node in nodes:
223 _append_succ_pred(subgraphs, successors, predecessors, node_map, node,
224 prefix=prefix, parent_node_name=parent_node_name)
225 if parent_node_name is not None:
226 unknown = set()
227 known = {}
228 for i in parent_graph.initializer:
229 known[i.name] = None
230 for i in parent_graph.input:
231 known[i.name] = None
232 for n in parent_graph.node:
233 for i in n.input:
234 if i not in known:
235 unknown.add(i)
236 for i in n.output:
237 known[i] = n
238 if len(unknown) > 0:
239 # These inputs are coming from the graph below.
240 for name in unknown:
241 successors[name].append(parent_node_name)
242 predecessors[parent_node_name].append(name)
def graph_predecessors_and_successors(graph):
    """
    Returns the successors and the predecessors within on ONNX graph.
    """
    name_to_node = {}
    succ = {}
    pred = {}
    collected_subgraphs = []
    _append_succ_pred_s(collected_subgraphs, succ, pred,
                        name_to_node, graph.node)
    return collected_subgraphs, pred, succ, name_to_node
def get_hidden_inputs(nodes):
    """
    Returns the list of hidden inputs used by subgraphs.

    :param nodes: list of nodes
    :return: list of names
    """
    consumed = set()
    produced = set()
    for nd in nodes:
        consumed.update(nd.input)
        produced.update(nd.output)
        for att in nd.attribute:
            sub = getattr(att, 'g', None)
            if att.type != AttributeProto.GRAPH or sub is None:  # pylint: disable=E1101
                continue
            # names the subgraph needs but does not define locally
            hidden = get_hidden_inputs(sub.node)
            local = set(i.name for i in sub.initializer)
            local |= set(i.name for i in sub.sparse_initializer)
            consumed |= hidden - (local & hidden)
    # keep only the names consumed but never produced by these nodes
    return consumed - (produced & consumed)
def reorder_nodes_for_display(nodes, verbose=False):
    """
    Reorders the node with breadth first search (BFS).

    :param nodes: list of ONNX nodes
    :param verbose: display intermediate informations
    :return: reordered list of nodes
    """
    class temp:
        "Fake GraphProto."

        def __init__(self, nodes):
            self.node = nodes

    _, predecessors, successors, dnodes = graph_predecessors_and_successors(
        temp(nodes))
    local_variables = get_hidden_inputs(nodes)

    all_outputs = set()
    all_inputs = set(local_variables)
    for node in nodes:
        all_outputs |= set(node.output)
        all_inputs |= set(node.input)
    common = all_outputs & all_inputs

    successors = {k: set(v) for k, v in successors.items()}
    predecessors = {k: set(v) for k, v in predecessors.items()}
    if verbose:
        pprint.pprint(  # pragma: no cover
            ["[reorder_nodes_for_display]", "predecessors",
             predecessors, "successors", successors])

    known = all_inputs - common
    new_nodes = []
    done = set()

    def _find_sequence(node_name, known, done):
        # Follows single-successor links from *node_name* to build the
        # longest chain of nodes ready to be displayed consecutively.
        inputs = dnodes[node_name].input
        if any(map(lambda i: i not in known, inputs)):
            return []

        res = [node_name]
        while res[-1] in successors:
            next_names = successors[res[-1]]
            if res[-1] not in dnodes:
                next_names = set(v for v in next_names if v not in known)
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    inputs = dnodes[next_name].input
                    if any(map(lambda i: i not in known, inputs)):
                        break
                    # BUGFIX: was ``res.extend(next_name)`` which extended
                    # the list with the *characters* of the name, truncating
                    # the chain; append keeps walking the graph as intended
                    # (see the symmetric branch below).
                    res.append(next_name)
                else:
                    break
            else:
                next_names = set(v for v in next_names if v not in done)
                if len(next_names) == 1:
                    next_name = next_names.pop()
                    res.append(next_name)
                else:
                    break

        # *res* mixes node keys and result names; keep only pending nodes.
        return [r for r in res if r in dnodes and r not in done]

    while len(done) < len(nodes):
        # possible
        possibles = OrderedDict()
        for k, v in dnodes.items():
            if k in done:
                continue
            if ':/:' in k:
                # node part of a sub graph (assuming :/: is never used in a node name)
                continue
            if predecessors[k] <= known:
                possibles[k] = v

        sequences = OrderedDict()
        for k, v in possibles.items():
            if k in done:
                continue
            sequences[k] = _find_sequence(k, known, done)
            if verbose:
                print("[reorder_nodes_for_display] * sequence(%s)=%s - %r" % (
                    k, ",".join(sequences[k]), list(sequences)))

        if len(sequences) == 0:
            raise RuntimeError(  # pragma: no cover
                "Unexpected empty sequence (len(possibles)=%d, "
                "len(done)=%d, len(nodes)=%d). This is usually due to "
                "a name used both as result name and node node. "
                "known=%r." % (len(possibles), len(done), len(nodes), known))

        # find the best sequence
        best = None
        for k, v in sequences.items():
            if best is None or len(v) > len(sequences[best]):
                # if the sequence of successors is longer
                best = k
            elif len(v) == len(sequences[best]):
                if len(new_nodes) > 0:
                    # then choose the next successor sharing input with
                    # previous output
                    so = set(new_nodes[-1].output)
                    first1 = dnodes[sequences[best][0]]
                    first2 = dnodes[v[0]]
                    if len(set(first1.input) & so) < len(set(first2.input) & so):
                        best = k
                else:
                    first1 = dnodes[sequences[best][0]]
                    first2 = dnodes[v[0]]
                    if first1.op_type > first2.op_type:
                        best = k
                    elif (first1.op_type == first2.op_type and
                            first1.name > first2.name):
                        best = k

        if best is None:
            raise RuntimeError(  # pragma: no cover
                f"Wrong implementation (len(sequence)={len(sequences)}).")
        if verbose:
            print("[reorder_nodes_for_display] BEST: sequence(%s)=%s" % (
                best, ",".join(sequences[best])))

        # process the sequence
        for k in sequences[best]:
            v = dnodes[k]
            new_nodes.append(v)
            if verbose:
                print(
                    f"[reorder_nodes_for_display] + {v.name!r} ({v.op_type!r})")
            done.add(k)
            known |= set(v.output)

    if len(new_nodes) != len(nodes):
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different. "
            "len(nodes=%d) != %d=len(new_nodes). done=\n%r"
            "\n%s\n----------\n%s" % (
                len(nodes), len(new_nodes), done,
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in nodes),
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in new_nodes)))
    n0s = set(n.name for n in nodes)
    n1s = set(n.name for n in new_nodes)
    if n0s != n1s:
        raise RuntimeError(  # pragma: no cover
            "The returned new nodes are different.\n"
            "%r !=\n%r\ndone=\n%r"
            "\n----------\n%s\n----------\n%s" % (
                n0s, n1s, done,
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in nodes),
                "\n".join("%d - %s - %s - %s" % (
                    (n.name + "".join(n.output)) in done,
                    n.op_type, n.name, n.name + "".join(n.output))
                    for n in new_nodes)))
    return new_nodes
def _get_type(obj0):
    """
    Guesses the numpy dtype of an ONNX object, either a tensor
    (``data_type`` + typed data fields or raw bytes) or a value info
    (``type.tensor_type.elem_type``).
    """
    obj = obj0
    if hasattr(obj, 'data_type'):
        # TensorProto: look at the typed data fields first.
        for proto_type, field in (
                (TensorProto.FLOAT, 'float_data'),  # pylint: disable=E1101
                (TensorProto.DOUBLE, 'double_data'),  # pylint: disable=E1101
                (TensorProto.INT64, 'int64_data'),  # pylint: disable=E1101
                (TensorProto.INT32, 'int32_data')):  # pylint: disable=E1101
            if obj.data_type == proto_type and hasattr(obj, field):
                return TENSOR_TYPE_TO_NP_TYPE[proto_type]
        # fall back on the raw bytes when present
        if hasattr(obj, 'raw_data') and len(obj.raw_data) > 0:
            return to_array(obj).dtype
        raise RuntimeError(  # pragma: no cover
            f"Unable to guess type from {obj0!r}.")
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        obj = obj.tensor_type
    if hasattr(obj, 'elem_type'):
        return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?')
    raise RuntimeError(  # pragma: no cover
        f"Unable to guess type from {obj0!r}.")
def _get_shape(obj):
    """
    Guesses the shape of an ONNX object, either a tensor
    (via :func:`to_array` or the typed data fields) or a value info
    (via :func:`get_tensor_shape`).
    """
    try:
        return to_array(obj).shape
    except Exception:  # pylint: disable=W0703
        # not a readable tensor, inspect the fields manually
        pass
    obj0 = obj
    if hasattr(obj, 'data_type'):
        # 1-D length of whichever typed data field matches data_type
        for proto_type, field in (
                (TensorProto.FLOAT, 'float_data'),  # pylint: disable=E1101
                (TensorProto.DOUBLE, 'double_data'),  # pylint: disable=E1101
                (TensorProto.INT64, 'int64_data'),  # pylint: disable=E1101
                (TensorProto.INT32, 'int32_data')):  # pylint: disable=E1101
            if obj.data_type == proto_type and hasattr(obj, field):
                return (len(getattr(obj, field)), )
        if hasattr(obj, 'raw_data') and len(obj.raw_data) > 0:
            return to_array(obj).shape
        raise RuntimeError(  # pragma: no cover
            f"Unable to guess type from {obj0!r}, "
            f"data_type is {obj.data_type!r}.")
    if hasattr(obj, 'type'):
        obj = obj.type
    if hasattr(obj, 'tensor_type'):
        return get_tensor_shape(obj)
    raise RuntimeError(  # pragma: no cover
        f"Unable to guess type from {obj0!r}.")
def onnx_simple_text_plot(model, verbose=False, att_display=None,  # pylint: disable=R0915
                          add_links=False, recursive=False, functions=True,
                          raise_exc=True, sub_graphs_names=None,
                          level=1, indent=True):
    """
    Displays an ONNX graph into text.

    :param model: ONNX graph
    :param verbose: display debugging information
    :param att_display: list of attributes to display, if None,
        a default list if used
    :param add_links: displays links of the right side
    :param recursive: display subgraphs as well
    :param functions: display functions as well
    :param raise_exc: raises an exception if the model is not valid,
        otherwise tries to continue
    :param sub_graphs_names: list of sub-graphs names
    :param level: sub-graph level
    :param indent: use indentation or not
    :return: str

    An ONNX graph is printed the following way:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False)
        print(text)

    The same graphs with links.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_simple_text_plot
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_simple_text_plot(onx, verbose=False, add_links=True)
        print(text)

    Visually, it looks like the following:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.onnxrt import OnnxInference
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        model_onnx = to_onnx(model, x.astype(numpy.float32),
                             target_opset=15)
        oinf = OnnxInference(model_onnx, inplace=False)

        print("DOT-SECTION", oinf.to_dot())
    """
    use_indentation = indent
    if att_display is None:
        # default set of attributes worth displaying inline
        att_display = [
            'activations',
            'align_corners',
            'allowzero',
            'alpha',
            'auto_pad',
            'axis',
            'axes',
            'batch_axis',
            'batch_dims',
            'beta',
            'bias',
            'blocksize',
            'case_change_action',
            'ceil_mode',
            'center_point_box',
            'clip',
            'coordinate_transformation_mode',
            'count_include_pad',
            'cubic_coeff_a',
            'decay_factor',
            'detect_negative',
            'detect_positive',
            'dilation',
            'dilations',
            'direction',
            'dtype',
            'end',
            'epsilon',
            'equation',
            'exclusive',
            'exclude_outside',
            'extrapolation_value',
            'fmod',
            'gamma',
            'group',
            'hidden_size',
            'high',
            'ignore_index',
            'input_forget',
            'is_case_sensitive',
            'k',
            'keepdims',
            'kernel_shape',
            'lambd',
            'largest',
            'layout',
            'linear_before_reset',
            'locale',
            'low',
            'max_gram_length',
            'max_skip_count',
            'mean',
            'min_gram_length',
            'mode',
            'momentum',
            'nearest_mode',
            'ngram_counts',
            'ngram_indexes',
            'noop_with_empty_axes',
            'norm_coefficient',
            'norm_coefficient_post',
            'num_scan_inputs',
            'output_height',
            'output_padding',
            'output_shape',
            'output_width',
            'p',
            'padding_mode',
            'pads',
            'perm',
            'pooled_shape',
            'reduction',
            'reverse',
            'sample_size',
            'sampling_ratio',
            'scale',
            'scan_input_axes',
            'scan_input_directions',
            'scan_output_axes',
            'scan_output_directions',
            'seed',
            'select_last_index',
            'size',
            'sorted',
            'spatial_scale',
            'start',
            'storage_order',
            'strides',
            'time_axis',
            'to',
            'training_mode',
            'transA',
            'transB',
            'type',
            'upper',
            'xs',
            'y',
            'zs',
        ]

    if sub_graphs_names is None:
        sub_graphs_names = {}

    def _get_subgraph_name(idg):
        # assigns a stable short name (G1, G2, ...) to every subgraph id
        if idg in sub_graphs_names:
            return sub_graphs_names[idg]
        g = "G%d" % (len(sub_graphs_names) + 1)
        sub_graphs_names[idg] = g
        return g

    def str_node(indent, node):
        # renders one node as "  OpType[domain](inputs, atts) -> outputs"
        atts = []
        if hasattr(node, 'attribute'):
            for att in node.attribute:
                done = True
                if hasattr(att, "ref_attr_name") and att.ref_attr_name:
                    atts.append(f"{att.name}=${att.ref_attr_name}")
                    continue
                if att.name in att_display:
                    if att.type == AttributeProto.INT:  # pylint: disable=E1101
                        atts.append("%s=%d" % (att.name, att.i))
                    elif att.type == AttributeProto.FLOAT:  # pylint: disable=E1101
                        atts.append(f"{att.name}={att.f:1.2f}")
                    elif att.type == AttributeProto.INTS:  # pylint: disable=E1101
                        atts.append("%s=%s" % (att.name, str(
                            list(att.ints)).replace(" ", "")))
                    else:
                        done = False
                elif (att.type == AttributeProto.GRAPH and  # pylint: disable=E1101
                        hasattr(att, 'g') and att.g is not None):
                    atts.append(f"{att.name}={_get_subgraph_name(id(att.g))}")
                else:
                    done = False
                if done:
                    continue
                if att.type in (AttributeProto.TENSOR,  # pylint: disable=E1101
                                AttributeProto.TENSORS,  # pylint: disable=E1101
                                AttributeProto.SPARSE_TENSOR,  # pylint: disable=E1101
                                AttributeProto.SPARSE_TENSORS):  # pylint: disable=E1101
                    try:
                        val = str(to_array(att.t).tolist())
                    except TypeError as e:  # pragma: no cover
                        raise TypeError(
                            "Unable to display tensor type %r.\n%s" % (
                                att.type, str(att))) from e
                    if "\n" in val:
                        # BUGFIX: was ``val.split(...) + "..."`` which added a
                        # str to a list (TypeError); keep the first line only.
                        val = val.split("\n", maxsplit=1)[0] + "..."
                    if len(val) > 10:
                        val = val[:10] + "..."
                elif att.type == AttributeProto.STRING:  # pylint: disable=E1101
                    val = str(att.s)
                elif att.type == AttributeProto.STRINGS:  # pylint: disable=E1101
                    n_val = list(att.strings)
                    if len(n_val) < 5:
                        val = ",".join(map(str, n_val))
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:2])),
                            ",".join(map(str, n_val[-2:])))
                elif att.type == AttributeProto.INT:  # pylint: disable=E1101
                    val = str(att.i)
                elif att.type == AttributeProto.FLOAT:  # pylint: disable=E1101
                    val = str(att.f)
                elif att.type == AttributeProto.INTS:  # pylint: disable=E1101
                    n_val = list(att.ints)
                    if len(n_val) < 6:
                        val = f"[{','.join(map(str, n_val))}]"
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:3])),
                            ",".join(map(str, n_val[-3:])))
                elif att.type == AttributeProto.FLOATS:  # pylint: disable=E1101
                    n_val = list(att.floats)
                    if len(n_val) < 5:
                        val = f"[{','.join(map(str, n_val))}]"
                    else:
                        val = "%d:[%s...%s]" % (
                            len(n_val),
                            ",".join(map(str, n_val[:2])),
                            ",".join(map(str, n_val[-2:])))
                else:
                    val = '.%d' % att.type
                atts.append(f"{att.name}={val}")
        inputs = list(node.input)
        if len(atts) > 0:
            inputs.extend(atts)
        if node.domain in ('', 'ai.onnx.ml'):
            domain = ''
        else:
            domain = f'[{node.domain}]'
        return "%s%s%s(%s) -> %s" % (
            " " * indent, node.op_type, domain,
            ", ".join(inputs), ", ".join(node.output))

    rows = []
    if hasattr(model, 'opset_import'):
        for opset in model.opset_import:
            rows.append(
                f"opset: domain={opset.domain!r} version={opset.version!r}")
    if hasattr(model, 'graph'):
        if model.doc_string:
            rows.append(f'doc_string: {model.doc_string}')
        main_model = model
        model = model.graph
    else:
        main_model = None

    # inputs
    line_name_new = {}
    line_name_in = {}
    if level == 0:
        rows.append("----- input ----")
    for inp in model.input:
        if isinstance(inp, str):
            rows.append(f"input: {inp!r}")
        else:
            line_name_new[inp.name] = len(rows)
            rows.append("input: name=%r type=%r shape=%r" % (
                inp.name, _get_type(inp), _get_shape(inp)))
    if hasattr(model, 'attribute'):
        for att in model.attribute:
            if isinstance(att, str):
                rows.append(f"attribute: {att!r}")
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Not yet introduced in onnx.")

    # initializer
    if hasattr(model, 'initializer'):
        if len(model.initializer) and level == 0:
            rows.append("----- initializer ----")
        for init in model.initializer:
            if numpy.prod(_get_shape(init)) < 5:
                content = f" -- {to_array(init).ravel()!r}"
            else:
                content = ""
            line_name_new[init.name] = len(rows)
            rows.append("init: name=%r type=%r shape=%r%s" % (
                init.name, _get_type(init), _get_shape(init), content))
    if level == 0:
        rows.append("----- main graph ----")

    # successors, predecessors, it needs to support subgraphs
    subgraphs = graph_predecessors_and_successors(model)[0]

    # walk through nodes
    init_names = set()
    indents = {}
    for inp in model.input:
        if isinstance(inp, str):
            indents[inp] = 0
            init_names.add(inp)
        else:
            indents[inp.name] = 0
            init_names.add(inp.name)
    if hasattr(model, 'initializer'):
        for init in model.initializer:
            indents[init.name] = 0
            init_names.add(init.name)

    try:
        nodes = reorder_nodes_for_display(model.node, verbose=verbose)
    except RuntimeError as e:  # pragma: no cover
        if raise_exc:
            raise e
        else:
            rows.append(f"ERROR: {e}")
        nodes = model.node

    previous_indent = None
    previous_out = None
    previous_in = None
    for node in nodes:
        add_break = False
        name = node.name + "#" + "|".join(node.output)
        if name in indents:
            indent = indents[name]
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print(f"[onnx_simple_text_plot] break1 {node.op_type}")
                add_break = True
        elif previous_in is not None and set(node.input) == previous_in:
            indent = previous_indent
        else:
            # indent below the shallowest non-initializer input
            inds = [indents.get(i, 0)
                    for i in node.input if i not in init_names]
            if len(inds) == 0:
                indent = 0
            else:
                mi = min(inds)
                indent = mi
            if previous_indent is not None and indent < previous_indent:
                if verbose:
                    print(  # pragma: no cover
                        f"[onnx_simple_text_plot] break2 {node.op_type}")
                add_break = True
        if not add_break and previous_out is not None:
            if len(set(node.input) & previous_out) == 0:
                if verbose:
                    print(f"[onnx_simple_text_plot] break3 {node.op_type}")
                add_break = True
                indent = 0

        if add_break and verbose:
            print("[onnx_simple_text_plot] add break")
        for n in node.input:
            if n in line_name_in:
                line_name_in[n].append(len(rows))
            else:
                line_name_in[n] = [len(rows)]
        for n in node.output:
            line_name_new[n] = len(rows)
        rows.append(str_node(indent if use_indentation else 0, node))
        indents[name] = indent

        for o in node.output:
            indents[o] = indent + 1

        previous_indent = indents[name]
        previous_out = set(node.output)
        previous_in = set(node.input)

    # outputs
    if level == 0:
        rows.append("----- output ----")
    for out in model.output:
        if isinstance(out, str):
            if out in line_name_in:
                line_name_in[out].append(len(rows))
            else:
                line_name_in[out] = [len(rows)]
            rows.append(f"output: name={out!r} type={'?'} shape={'?'}")
        else:
            if out.name in line_name_in:
                line_name_in[out.name].append(len(rows))
            else:
                line_name_in[out.name] = [len(rows)]
            rows.append("output: name=%r type=%r shape=%r" % (
                out.name, _get_type(out), _get_shape(out)))

    if add_links:

        def _mark_link(rows, lengths, r1, r2, d):
            # draws an ASCII link between the row defining a result (r1)
            # and a row consuming it (r2)
            maxl = max(lengths[r1], lengths[r2]) + d * 2
            maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2

            if rows[r1][-1] == '|':
                p1, p2 = rows[r1][:lengths[r1] + 2], rows[r1][lengths[r1] + 2:]
                rows[r1] = p1 + p2.replace(' ', '-')
            rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+"

            if rows[r2][-1] == " ":
                rows[r2] += "<"
            elif rows[r2][-1] == '|':
                if "<" not in rows[r2]:
                    p = lengths[r2]
                    rows[r2] = rows[r2][:p] + '<' + rows[r2][p + 1:]
                p1, p2 = rows[r2][:lengths[r2] + 2], rows[r2][lengths[r2] + 2:]
                rows[r2] = p1 + p2.replace(' ', '-')
            rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+"

            for r in range(r1 + 1, r2):
                if len(rows[r]) < maxl:
                    rows[r] += " " * (maxl - len(rows[r]) - 1)
                rows[r] += "|"

        diffs = []
        for n, r1 in line_name_new.items():
            if n not in line_name_in:
                continue
            r2s = line_name_in[n]
            for r2 in r2s:
                if r1 >= r2:
                    continue
                diffs.append((r2 - r1, (n, r1, r2)))
        diffs.sort()
        for i in range(len(rows)):  # pylint: disable=C0200
            rows[i] += " "
        lengths = [len(r) for r in rows]

        for d, (n, r1, r2) in diffs:
            if d == 1 and len(line_name_in[n]) == 1:
                # no line for link to the next node
                continue
            _mark_link(rows, lengths, r1, r2, d)

    # subgraphs
    if recursive:
        for node, name, g in subgraphs:
            rows.append('----- subgraph ---- %s - %s - att.%s=%s -- level=%d -- %s -> %s' % (
                node.op_type, node.name, name, _get_subgraph_name(id(g)),
                level, ','.join(i.name for i in g.input),
                ','.join(i.name for i in g.output)))
            res = onnx_simple_text_plot(
                g, verbose=verbose, att_display=att_display,
                add_links=add_links, recursive=recursive,
                sub_graphs_names=sub_graphs_names, level=level + 1,
                raise_exc=raise_exc)
            rows.append(res)

    # functions
    if functions and main_model is not None:
        for fct in main_model.functions:
            rows.append(f'----- function name={fct.name} domain={fct.domain}')
            if fct.doc_string:
                rows.append(f'----- doc_string: {fct.doc_string}')
            res = onnx_simple_text_plot(
                fct, verbose=verbose, att_display=att_display,
                add_links=add_links, recursive=recursive,
                functions=False, sub_graphs_names=sub_graphs_names,
                level=1)
            rows.append(res)

    return "\n".join(rows)
def onnx_text_plot_io(model, verbose=False, att_display=None):
    """
    Displays information about input and output types.

    :param model: ONNX graph
    :param verbose: display debugging information
        (currently unused by the implementation)
    :param att_display: currently unused by the implementation,
        kept for signature compatibility
    :return: str

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from sklearn.cluster import KMeans
        from mlprodict.plotting.plotting import onnx_text_plot_io
        from mlprodict.onnx_conv import to_onnx

        x = numpy.random.randn(10, 3)
        y = numpy.random.randn(10)
        model = KMeans(3)
        model.fit(x, y)
        onx = to_onnx(model, x.astype(numpy.float32),
                      target_opset=15)
        text = onnx_text_plot_io(onx, verbose=False)
        print(text)
    """
    lines = []
    if hasattr(model, 'opset_import'):
        lines.extend(
            f"opset: domain={opset.domain!r} version={opset.version!r}"
            for opset in model.opset_import)
    # unwrap ModelProto into its GraphProto
    graph = model.graph if hasattr(model, 'graph') else model
    # inputs
    for tensor in graph.input:
        lines.append("input: name=%r type=%r shape=%r" % (
            tensor.name, _get_type(tensor), _get_shape(tensor)))
    # initializer
    for tensor in graph.initializer:
        lines.append("init: name=%r type=%r shape=%r" % (
            tensor.name, _get_type(tensor), _get_shape(tensor)))
    # outputs
    for tensor in graph.output:
        lines.append("output: name=%r type=%r shape=%r" % (
            tensor.name, _get_type(tensor), _get_shape(tensor)))
    return "\n".join(lines)