Coverage for mlprodict/plotting/text_plot.py: 92%

518 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# pylint: disable=R0912,R0914,C0302 

2""" 

3@file 

4@brief Text representations of graphs. 

5""" 

6import pprint 

7from collections import OrderedDict 

8import numpy 

9from onnx import TensorProto, AttributeProto 

10from onnx.numpy_helper import to_array 

11from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

12from ..tools.graphs import onnx2bigraph 

13from ..onnx_tools.onnx2py_helper import _var_as_dict, get_tensor_shape 

14 

15 

16def onnx_text_plot(model_onnx, recursive=False, graph_type='basic', 

17 grid=5, distance=5): 

18 """ 

19 Uses @see fn onnx2bigraph to convert the ONNX graph 

20 into text. 

21 

22 :param model_onnx: onnx representation 

23 :param recursive: @see fn onnx2bigraph 

24 :param graph_type: @see fn onnx2bigraph 

25 :param grid: @see me display_structure 

26 :param distance: @see me display_structure 

27 :return: text 

28 

29 .. runpython:: 

30 :showcode: 

31 :warningout: DeprecationWarning 

32 

33 import numpy 

34 from mlprodict.onnx_conv import to_onnx 

35 from mlprodict import __max_supported_opset__ as opv 

36 from mlprodict.plotting.plotting import onnx_text_plot 

37 from mlprodict.npy.xop import loadop 

38 

39 OnnxAdd, OnnxSub = loadop('Add', 'Sub') 

40 

41 idi = numpy.identity(2).astype(numpy.float32) 

42 A = OnnxAdd('X', idi, op_version=opv) 

43 B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv) 

44 onx = B.to_onnx({'X': idi, 'W': idi}) 

45 print(onnx_text_plot(onx)) 

46 """ 

47 bigraph = onnx2bigraph(model_onnx, recursive=recursive, graph_type=graph_type) 

48 graph = bigraph.display_structure(grid=grid, distance=distance) 

49 return graph.to_text() 

50 

51 

52def onnx_text_plot_tree(node): 

53 """ 

54 Gives a textual representation of a tree ensemble. 

55 

56 :param node: `TreeEnsemble*` 

57 :return: text 

58 

59 .. runpython:: 

60 :showcode: 

61 :warningout: DeprecationWarning 

62 

63 import numpy 

64 from sklearn.datasets import load_iris 

65 from sklearn.tree import DecisionTreeRegressor 

66 from mlprodict.onnx_conv import to_onnx 

67 from mlprodict.plotting.plotting import onnx_text_plot_tree 

68 

69 iris = load_iris() 

70 X, y = iris.data.astype(numpy.float32), iris.target 

71 clr = DecisionTreeRegressor(max_depth=3) 

72 clr.fit(X, y) 

73 onx = to_onnx(clr, X) 

74 res = onnx_text_plot_tree(onx.graph.node[0]) 

75 print(res) 

76 """ 

77 def rule(r): 

78 if r == b'BRANCH_LEQ': 

79 return '<=' 

80 if r == b'BRANCH_LT': # pragma: no cover 

81 return '<' 

82 if r == b'BRANCH_GEQ': # pragma: no cover 

83 return '>=' 

84 if r == b'BRANCH_GT': # pragma: no cover 

85 return '>' 

86 if r == b'BRANCH_EQ': # pragma: no cover 

87 return '==' 

88 if r == b'BRANCH_NEQ': # pragma: no cover 

89 return '!=' 

90 raise ValueError( # pragma: no cover 

91 f"Unexpected rule {rule!r}.") 

92 

93 class Node: 

94 "Node representation." 

95 

96 def __init__(self, i, atts): 

97 self.nodes_hitrates = None 

98 self.nodes_missing_value_tracks_true = None 

99 for k, v in atts.items(): 

100 if k.startswith('nodes'): 

101 setattr(self, k, v[i]) 

102 self.depth = 0 

103 self.true_false = '' 

104 

105 def process_node(self): 

106 "node to string" 

107 if self.nodes_modes == b'LEAF': # pylint: disable=E1101 

108 text = "%s y=%r f=%r i=%r" % ( 

109 self.true_false, 

110 self.target_weights, self.target_ids, # pylint: disable=E1101 

111 self.target_nodeids) # pylint: disable=E1101 

112 else: 

113 text = "%s X%d %s %r" % ( 

114 self.true_false, 

115 self.nodes_featureids, # pylint: disable=E1101 

116 rule(self.nodes_modes), # pylint: disable=E1101 

117 self.nodes_values) # pylint: disable=E1101 

118 if self.nodes_hitrates and self.nodes_hitrates != 1: 

119 text += f" hi={self.nodes_hitrates!r}" 

120 if self.nodes_missing_value_tracks_true: 

121 text += f" miss={self.nodes_missing_value_tracks_true!r}" 

122 return f"{' ' * self.depth}{text}" 

123 

124 def process_tree(atts, treeid): 

125 "tree to string" 

126 rows = [f'treeid={treeid!r}'] 

127 if 'base_values' in atts: 

128 if treeid < len(atts['base_values']): 

129 rows.append(f"base_value={atts['base_values'][treeid]!r}") 

130 

131 short = {} 

132 for prefix in ['nodes', 'target', 'class']: 

133 if f'{prefix}_treeids' not in atts: 

134 continue 

135 idx = [i for i in range(len(atts[f'{prefix}_treeids'])) 

136 if atts[f'{prefix}_treeids'][i] == treeid] 

137 for k, v in atts.items(): 

138 if k.startswith(prefix): 

139 if 'classlabels' in k: 

140 short[k] = list(v) 

141 else: 

142 short[k] = [v[i] for i in idx] 

143 

144 nodes = OrderedDict() 

145 for i in range(len(short['nodes_treeids'])): 

146 nodes[i] = Node(i, short) 

147 prefix = 'target' if 'target_treeids' in short else 'class' 

148 for i in range(len(short[f'{prefix}_treeids'])): 

149 idn = short[f'{prefix}_nodeids'][i] 

150 node = nodes[idn] 

151 node.target_nodeids = idn 

152 node.target_ids = short[f'{prefix}_ids'][i] 

153 node.target_weights = short[f'{prefix}_weights'][i] 

154 

155 def iterate(nodes, node, depth=0, true_false=''): 

156 node.depth = depth 

157 node.true_false = true_false 

158 yield node 

159 if node.nodes_falsenodeids > 0: 

160 for n in iterate(nodes, nodes[node.nodes_falsenodeids], 

161 depth=depth + 1, true_false='F'): 

162 yield n 

163 for n in iterate(nodes, nodes[node.nodes_truenodeids], 

164 depth=depth + 1, true_false='T'): 

165 yield n 

166 

167 for node in iterate(nodes, nodes[0]): 

168 rows.append(node.process_node()) 

169 return rows 

170 

171 if node.op_type in ("TreeEnsembleRegressor", "TreeEnsembleClassifier"): 

172 d = {k: v['value'] for k, v in _var_as_dict(node)['atts'].items()} 

173 atts = {} 

174 for k, v in d.items(): 

175 atts[k] = v if isinstance(v, int) else list(v) 

176 trees = list(sorted(set(atts['nodes_treeids']))) 

177 if 'n_targets' in atts: 

178 rows = [f"n_targets={atts['n_targets']!r}"] 

179 else: 

180 rows = ['n_classes=%r' % len( 

181 atts.get('classlabels_int64s', 

182 atts.get('classlabels_strings', [])))] 

183 rows.append(f'n_trees={len(trees)!r}') 

184 for tree in trees: 

185 r = process_tree(atts, tree) 

186 rows.append('----') 

187 rows.extend(r) 

188 return "\n".join(rows) 

189 

190 raise NotImplementedError( # pragma: no cover 

191 f"Type {node.op_type!r} cannot be displayed.") 

192 

193 

194def _append_succ_pred(subgraphs, successors, predecessors, node_map, node, prefix="", 

195 parent_node_name=None): 

196 node_name = prefix + node.name + "#" + "|".join(node.output) 

197 node_map[node_name] = node 

198 successors[node_name] = [] 

199 predecessors[node_name] = [] 

200 for name in node.input: 

201 predecessors[node_name].append(name) 

202 if name not in successors: 

203 successors[name] = [] 

204 successors[name].append(node_name) 

205 for name in node.output: 

206 successors[node_name].append(name) 

207 predecessors[name] = [node_name] 

208 if node.op_type in {'If', 'Scan', 'Loop', 'Expression'}: 

209 for att in node.attribute: 

210 if (att.type != AttributeProto.GRAPH or # pylint: disable=E1101 

211 not hasattr(att, 'g') or att.g is None): 

212 continue 

213 subgraphs.append((node, att.name, att.g)) 

214 _append_succ_pred_s(subgraphs, successors, predecessors, node_map, 

215 att.g.node, prefix=node_name + ":/:", 

216 parent_node_name=node_name, 

217 parent_graph=att.g) 

218 

219 

220def _append_succ_pred_s(subgraphs, successors, predecessors, node_map, nodes, prefix="", 

221 parent_node_name=None, parent_graph=None): 

222 for node in nodes: 

223 _append_succ_pred(subgraphs, successors, predecessors, node_map, node, 

224 prefix=prefix, parent_node_name=parent_node_name) 

225 if parent_node_name is not None: 

226 unknown = set() 

227 known = {} 

228 for i in parent_graph.initializer: 

229 known[i.name] = None 

230 for i in parent_graph.input: 

231 known[i.name] = None 

232 for n in parent_graph.node: 

233 for i in n.input: 

234 if i not in known: 

235 unknown.add(i) 

236 for i in n.output: 

237 known[i] = n 

238 if len(unknown) > 0: 

239 # These inputs come from the enclosing graph (hidden inputs of the subgraph). 

240 for name in unknown: 

241 successors[name].append(parent_node_name) 

242 predecessors[parent_node_name].append(name) 

243 

244 

245def graph_predecessors_and_successors(graph): 

246 """ 

247 Returns the successors and the predecessors within an ONNX graph. 

248 """ 

249 node_map = {} 

250 successors = {} 

251 predecessors = {} 

252 subgraphs = [] 

253 _append_succ_pred_s(subgraphs, successors, 

254 predecessors, node_map, graph.node) 

255 return subgraphs, predecessors, successors, node_map 
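
# Editor's sketch (not part of the original module): a minimal usage example of
# graph_predecessors_and_successors on a hand-built two-node graph. The helper
# name _demo_graph_predecessors_and_successors is hypothetical and only meant
# for illustration; it relies on onnx.helper and the function above.
def _demo_graph_predecessors_and_successors():  # pragma: no cover
    "Shows the four returned mappings on a tiny Abs -> Neg graph."
    from onnx import helper
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    n1 = helper.make_node('Abs', ['X'], ['T'], name='n1')
    n2 = helper.make_node('Neg', ['T'], ['Y'], name='n2')
    graph = helper.make_graph([n1, n2], 'demo', [X], [Y])
    subgraphs, predecessors, successors, node_map = \
        graph_predecessors_and_successors(graph)
    # Node keys look like 'n1#T' (name + '#' + joined outputs); the mappings
    # mix node keys and result names.
    print(subgraphs)      # [] because there is no If/Scan/Loop node
    print(predecessors)   # e.g. {'n1#T': ['X'], 'T': ['n1#T'], 'n2#Y': ['T'], ...}
    print(successors)
    print(list(node_map))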

256 

257 

258def get_hidden_inputs(nodes): 

259 """ 

260 Returns the list of hidden inputs used by subgraphs. 

261 

262 :param nodes: list of nodes 

263 :return: list of names 

264 """ 

265 inputs = set() 

266 outputs = set() 

267 for node in nodes: 

268 inputs |= set(node.input) 

269 outputs |= set(node.output) 

270 for att in node.attribute: 

271 if (att.type != AttributeProto.GRAPH or # pylint: disable=E1101 

272 not hasattr(att, 'g') or att.g is None): 

273 continue 

274 hidden = get_hidden_inputs(att.g.node) 

275 inits = set(i.name for i in att.g.initializer) 

276 inits |= set(i.name for i in att.g.sparse_initializer) 

277 inputs |= hidden - (inits & hidden) 

278 return inputs - (outputs & inputs) 
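
# Editor's sketch (not part of the original module): get_hidden_inputs collects
# names a subgraph reads from its enclosing scope. The demo function name is
# hypothetical; it only relies on onnx.helper and the function above.
def _demo_get_hidden_inputs():  # pragma: no cover
    "An If node whose branches read 'outer', produced outside the subgraphs."
    from onnx import helper
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])
    then_g = helper.make_graph(
        [helper.make_node('Identity', ['outer'], ['Y'])], 'then', [], [Y])
    else_g = helper.make_graph(
        [helper.make_node('Neg', ['outer'], ['Y'])], 'else', [], [Y])
    if_node = helper.make_node('If', ['cond'], ['Y'], name='if',
                               then_branch=then_g, else_branch=else_g)
    # 'outer' is hidden inside the branches; 'cond' is simply not produced by
    # the given node list, so both names are reported.
    print(get_hidden_inputs([if_node]))  # expected elements: 'cond', 'outer'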

279 

280 

281def reorder_nodes_for_display(nodes, verbose=False): 

282 """ 

283 Reorders the nodes with a breadth first search (BFS). 

284 

285 :param nodes: list of ONNX nodes 

286 :param verbose: display intermediate information 

287 :return: reordered list of nodes 

288 """ 

289 class temp: 

290 "Fake GraphProto." 

291 

292 def __init__(self, nodes): 

293 self.node = nodes 

294 

295 _, predecessors, successors, dnodes = graph_predecessors_and_successors( 

296 temp(nodes)) 

297 local_variables = get_hidden_inputs(nodes) 

298 

299 all_outputs = set() 

300 all_inputs = set(local_variables) 

301 for node in nodes: 

302 all_outputs |= set(node.output) 

303 all_inputs |= set(node.input) 

304 common = all_outputs & all_inputs 

305 

306 successors = {k: set(v) for k, v in successors.items()} 

307 predecessors = {k: set(v) for k, v in predecessors.items()} 

308 if verbose: 

309 pprint.pprint( # pragma: no cover 

310 ["[reorder_nodes_for_display]", "predecessors", 

311 predecessors, "successors", successors]) 

312 

313 known = all_inputs - common 

314 new_nodes = [] 

315 done = set() 

316 

317 def _find_sequence(node_name, known, done): 

318 inputs = dnodes[node_name].input 

319 if any(map(lambda i: i not in known, inputs)): 

320 return [] 

321 

322 res = [node_name] 

323 while res[-1] in successors: 

324 next_names = successors[res[-1]] 

325 if res[-1] not in dnodes: 

326 next_names = set(v for v in next_names if v not in known) 

327 if len(next_names) == 1: 

328 next_name = next_names.pop() 

329 inputs = dnodes[next_name].input 

330 if any(map(lambda i: i not in known, inputs)): 

331 break 

332 res.append(next_name) 

333 else: 

334 break 

335 else: 

336 next_names = set(v for v in next_names if v not in done) 

337 if len(next_names) == 1: 

338 next_name = next_names.pop() 

339 res.append(next_name) 

340 else: 

341 break 

342 

343 return [r for r in res if r in dnodes and r not in done] 

344 

345 while len(done) < len(nodes): 

346 # possible 

347 possibles = OrderedDict() 

348 for k, v in dnodes.items(): 

349 if k in done: 

350 continue 

351 if ':/:' in k: 

352 # node part of a sub graph (assuming :/: is never used in a node name) 

353 continue 

354 if predecessors[k] <= known: 

355 possibles[k] = v 

356 

357 sequences = OrderedDict() 

358 for k, v in possibles.items(): 

359 if k in done: 

360 continue 

361 sequences[k] = _find_sequence(k, known, done) 

362 if verbose: 

363 print("[reorder_nodes_for_display] * sequence(%s)=%s - %r" % ( 

364 k, ",".join(sequences[k]), list(sequences))) 

365 

366 if len(sequences) == 0: 

367 raise RuntimeError( # pragma: no cover 

368 "Unexpected empty sequence (len(possibles)=%d, " 

369 "len(done)=%d, len(nodes)=%d). This is usually due to " 

370 "a name used both as result name and node node. " 

371 "known=%r." % (len(possibles), len(done), len(nodes), known)) 

372 

373 # find the best sequence 

374 best = None 

375 for k, v in sequences.items(): 

376 if best is None or len(v) > len(sequences[best]): 

377 # if the sequence of successors is longer 

378 best = k 

379 elif len(v) == len(sequences[best]): 

380 if len(new_nodes) > 0: 

381 # then choose the next successor sharing input with 

382 # previous output 

383 so = set(new_nodes[-1].output) 

384 first1 = dnodes[sequences[best][0]] 

385 first2 = dnodes[v[0]] 

386 if len(set(first1.input) & so) < len(set(first2.input) & so): 

387 best = k 

388 else: 

389 first1 = dnodes[sequences[best][0]] 

390 first2 = dnodes[v[0]] 

391 if first1.op_type > first2.op_type: 

392 best = k 

393 elif (first1.op_type == first2.op_type and 

394 first1.name > first2.name): 

395 best = k 

396 

397 if best is None: 

398 raise RuntimeError( # pragma: no cover 

399 f"Wrong implementation (len(sequence)={len(sequences)}).") 

400 if verbose: 

401 print("[reorder_nodes_for_display] BEST: sequence(%s)=%s" % ( 

402 best, ",".join(sequences[best]))) 

403 

404 # process the sequence 

405 for k in sequences[best]: 

406 v = dnodes[k] 

407 new_nodes.append(v) 

408 if verbose: 

409 print( 

410 f"[reorder_nodes_for_display] + {v.name!r} ({v.op_type!r})") 

411 done.add(k) 

412 known |= set(v.output) 

413 

414 if len(new_nodes) != len(nodes): 

415 raise RuntimeError( # pragma: no cover 

416 "The returned new nodes are different. " 

417 "len(nodes=%d) != %d=len(new_nodes). done=\n%r" 

418 "\n%s\n----------\n%s" % ( 

419 len(nodes), len(new_nodes), done, 

420 "\n".join("%d - %s - %s - %s" % ( 

421 (n.name + "".join(n.output)) in done, 

422 n.op_type, n.name, n.name + "".join(n.output)) 

423 for n in nodes), 

424 "\n".join("%d - %s - %s - %s" % ( 

425 (n.name + "".join(n.output)) in done, 

426 n.op_type, n.name, n.name + "".join(n.output)) 

427 for n in new_nodes))) 

428 n0s = set(n.name for n in nodes) 

429 n1s = set(n.name for n in new_nodes) 

430 if n0s != n1s: 

431 raise RuntimeError( # pragma: no cover 

432 "The returned new nodes are different.\n" 

433 "%r !=\n%r\ndone=\n%r" 

434 "\n----------\n%s\n----------\n%s" % ( 

435 n0s, n1s, done, 

436 "\n".join("%d - %s - %s - %s" % ( 

437 (n.name + "".join(n.output)) in done, 

438 n.op_type, n.name, n.name + "".join(n.output)) 

439 for n in nodes), 

440 "\n".join("%d - %s - %s - %s" % ( 

441 (n.name + "".join(n.output)) in done, 

442 n.op_type, n.name, n.name + "".join(n.output)) 

443 for n in new_nodes))) 

444 return new_nodes 
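
# Editor's sketch (not part of the original module): reorder_nodes_for_display
# reschedules nodes so that each one appears after the results it consumes,
# preferring a follower that chains on the previous output. The demo name is
# hypothetical and for illustration only.
def _demo_reorder_nodes_for_display():  # pragma: no cover
    "Nodes given in an arbitrary order come back in a displayable order."
    from onnx import helper
    add = helper.make_node('Add', ['X', 'X'], ['S'], name='add')
    mul = helper.make_node('Mul', ['S', 'X'], ['P'], name='mul')  # needs S
    neg = helper.make_node('Neg', ['X'], ['N'], name='neg')
    new_nodes = reorder_nodes_for_display([mul, add, neg])
    print([n.name for n in new_nodes])  # expected: ['add', 'mul', 'neg']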

445 

446 

447def _get_type(obj0): 

448 obj = obj0 

449 if hasattr(obj, 'data_type'): 

450 if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 

451 hasattr(obj, 'float_data')): 

452 return TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT] # pylint: disable=E1101 

453 if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 

454 hasattr(obj, 'double_data')): 

455 return TENSOR_TYPE_TO_NP_TYPE[TensorProto.DOUBLE] # pylint: disable=E1101 

456 if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 

457 hasattr(obj, 'int64_data')): 

458 return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT64] # pylint: disable=E1101 

459 if (obj.data_type == TensorProto.INT32 and # pylint: disable=E1101 

460 hasattr(obj, 'int32_data')): 

461 return TENSOR_TYPE_TO_NP_TYPE[TensorProto.INT32] # pylint: disable=E1101 

462 if hasattr(obj, 'raw_data') and len(obj.raw_data) > 0: 

463 arr = to_array(obj) 

464 return arr.dtype 

465 raise RuntimeError( # pragma: no cover 

466 f"Unable to guess type from {obj0!r}.") 

467 if hasattr(obj, 'type'): 

468 obj = obj.type 

469 if hasattr(obj, 'tensor_type'): 

470 obj = obj.tensor_type 

471 if hasattr(obj, 'elem_type'): 

472 return TENSOR_TYPE_TO_NP_TYPE.get(obj.elem_type, '?') 

473 raise RuntimeError( # pragma: no cover 

474 f"Unable to guess type from {obj0!r}.") 

475 

476 

477def _get_shape(obj): 

478 try: 

479 arr = to_array(obj) 

480 return arr.shape 

481 except Exception: # pylint: disable=W0703 

482 pass 

483 obj0 = obj 

484 if hasattr(obj, 'data_type'): 

485 if (obj.data_type == TensorProto.FLOAT and # pylint: disable=E1101 

486 hasattr(obj, 'float_data')): 

487 return (len(obj.float_data), ) 

488 if (obj.data_type == TensorProto.DOUBLE and # pylint: disable=E1101 

489 hasattr(obj, 'double_data')): 

490 return (len(obj.double_data), ) 

491 if (obj.data_type == TensorProto.INT64 and # pylint: disable=E1101 

492 hasattr(obj, 'int64_data')): 

493 return (len(obj.int64_data), ) 

494 if (obj.data_type == TensorProto.INT32 and # pylint: disable=E1101 

495 hasattr(obj, 'int32_data')): 

496 return (len(obj.int32_data), ) 

497 if hasattr(obj, 'raw_data') and len(obj.raw_data) > 0: 

498 arr = to_array(obj) 

499 return arr.shape 

500 raise RuntimeError( # pragma: no cover 

501 f"Unable to guess type from {obj0!r}, " 

502 f"data_type is {obj.data_type!r}.") 

503 if hasattr(obj, 'type'): 

504 obj = obj.type 

505 if hasattr(obj, 'tensor_type'): 

506 return get_tensor_shape(obj) 

507 raise RuntimeError( # pragma: no cover 

508 f"Unable to guess type from {obj0!r}.") 

509 

510 

511def onnx_simple_text_plot(model, verbose=False, att_display=None, # pylint: disable=R0915 

512 add_links=False, recursive=False, functions=True, 

513 raise_exc=True, sub_graphs_names=None, 

514 level=1, indent=True): 

515 """ 

516 Displays an ONNX graph into text. 

517 

518 :param model: ONNX graph 

519 :param verbose: display debugging information 

520 :param att_display: list of attributes to display, if None, 

521 a default list is used 

522 :param add_links: displays links on the right side 

523 :param recursive: display subgraphs as well 

524 :param functions: display functions as well 

525 :param raise_exc: raises an exception if the model is not valid, 

526 otherwise tries to continue 

527 :param sub_graphs_names: list of sub-graphs names 

528 :param level: sub-graph level 

529 :param indent: use indentation or not 

530 :return: str 

531 

532 An ONNX graph is printed the following way: 

533 

534 .. runpython:: 

535 :showcode: 

536 :warningout: DeprecationWarning 

537 

538 import numpy 

539 from sklearn.cluster import KMeans 

540 from mlprodict.plotting.plotting import onnx_simple_text_plot 

541 from mlprodict.onnx_conv import to_onnx 

542 

543 x = numpy.random.randn(10, 3) 

544 y = numpy.random.randn(10) 

545 model = KMeans(3) 

546 model.fit(x, y) 

547 onx = to_onnx(model, x.astype(numpy.float32), 

548 target_opset=15) 

549 text = onnx_simple_text_plot(onx, verbose=False) 

550 print(text) 

551 

552 The same graph with links. 

553 

554 .. runpython:: 

555 :showcode: 

556 :warningout: DeprecationWarning 

557 

558 import numpy 

559 from sklearn.cluster import KMeans 

560 from mlprodict.plotting.plotting import onnx_simple_text_plot 

561 from mlprodict.onnx_conv import to_onnx 

562 

563 x = numpy.random.randn(10, 3) 

564 y = numpy.random.randn(10) 

565 model = KMeans(3) 

566 model.fit(x, y) 

567 onx = to_onnx(model, x.astype(numpy.float32), 

568 target_opset=15) 

569 text = onnx_simple_text_plot(onx, verbose=False, add_links=True) 

570 print(text) 

571 

572 Visually, it looks like the following: 

573 

574 .. gdot:: 

575 :script: DOT-SECTION 

576 

577 import numpy 

578 from sklearn.cluster import KMeans 

579 from mlprodict.onnxrt import OnnxInference 

580 from mlprodict.onnx_conv import to_onnx 

581 

582 x = numpy.random.randn(10, 3) 

583 y = numpy.random.randn(10) 

584 model = KMeans(3) 

585 model.fit(x, y) 

586 model_onnx = to_onnx(model, x.astype(numpy.float32), 

587 target_opset=15) 

588 oinf = OnnxInference(model_onnx, inplace=False) 

589 

590 print("DOT-SECTION", oinf.to_dot()) 

591 """ 

592 use_indentation = indent 

593 if att_display is None: 

594 att_display = [ 

595 'activations', 

596 'align_corners', 

597 'allowzero', 

598 'alpha', 

599 'auto_pad', 

600 'axis', 

601 'axes', 

602 'batch_axis', 

603 'batch_dims', 

604 'beta', 

605 'bias', 

606 'blocksize', 

607 'case_change_action', 

608 'ceil_mode', 

609 'center_point_box', 

610 'clip', 

611 'coordinate_transformation_mode', 

612 'count_include_pad', 

613 'cubic_coeff_a', 

614 'decay_factor', 

615 'detect_negative', 

616 'detect_positive', 

617 'dilation', 

618 'dilations', 

619 'direction', 

620 'dtype', 

621 'end', 

622 'epsilon', 

623 'equation', 

624 'exclusive', 

625 'exclude_outside', 

626 'extrapolation_value', 

627 'fmod', 

628 'gamma', 

629 'group', 

630 'hidden_size', 

631 'high', 

632 'ignore_index', 

633 'input_forget', 

634 'is_case_sensitive', 

635 'k', 

636 'keepdims', 

637 'kernel_shape', 

638 'lambd', 

639 'largest', 

640 'layout', 

641 'linear_before_reset', 

642 'locale', 

643 'low', 

644 'max_gram_length', 

645 'max_skip_count', 

646 'mean', 

647 'min_gram_length', 

648 'mode', 

649 'momentum', 

650 'nearest_mode', 

651 'ngram_counts', 

652 'ngram_indexes', 

653 'noop_with_empty_axes', 

654 'norm_coefficient', 

655 'norm_coefficient_post', 

656 'num_scan_inputs', 

657 'output_height', 

658 'output_padding', 

659 'output_shape', 

660 'output_width', 

661 'p', 

662 'padding_mode', 

663 'pads', 

664 'perm', 

665 'pooled_shape', 

666 'reduction', 

667 'reverse', 

668 'sample_size', 

669 'sampling_ratio', 

670 'scale', 

671 'scan_input_axes', 

672 'scan_input_directions', 

673 'scan_output_axes', 

674 'scan_output_directions', 

675 'seed', 

676 'select_last_index', 

677 'size', 

678 'sorted', 

679 'spatial_scale', 

680 'start', 

681 'storage_order', 

682 'strides', 

683 'time_axis', 

684 'to', 

685 'training_mode', 

686 'transA', 

687 'transB', 

688 'type', 

689 'upper', 

690 'xs', 

691 'y', 

692 'zs', 

693 ] 

694 

695 if sub_graphs_names is None: 

696 sub_graphs_names = {} 

697 

698 def _get_subgraph_name(idg): 

699 if idg in sub_graphs_names: 

700 return sub_graphs_names[idg] 

701 g = "G%d" % (len(sub_graphs_names) + 1) 

702 sub_graphs_names[idg] = g 

703 return g 

704 

705 def str_node(indent, node): 

706 atts = [] 

707 if hasattr(node, 'attribute'): 

708 for att in node.attribute: 

709 done = True 

710 if hasattr(att, "ref_attr_name") and att.ref_attr_name: 

711 atts.append(f"{att.name}=${att.ref_attr_name}") 

712 continue 

713 if att.name in att_display: 

714 if att.type == AttributeProto.INT: # pylint: disable=E1101 

715 atts.append("%s=%d" % (att.name, att.i)) 

716 elif att.type == AttributeProto.FLOAT: # pylint: disable=E1101 

717 atts.append(f"{att.name}={att.f:1.2f}") 

718 elif att.type == AttributeProto.INTS: # pylint: disable=E1101 

719 atts.append("%s=%s" % (att.name, str( 

720 list(att.ints)).replace(" ", ""))) 

721 else: 

722 done = False 

723 elif (att.type == AttributeProto.GRAPH and # pylint: disable=E1101 

724 hasattr(att, 'g') and att.g is not None): 

725 atts.append(f"{att.name}={_get_subgraph_name(id(att.g))}") 

726 else: 

727 done = False 

728 if done: 

729 continue 

730 if att.type in (AttributeProto.TENSOR, # pylint: disable=E1101 

731 AttributeProto.TENSORS, # pylint: disable=E1101 

732 AttributeProto.SPARSE_TENSOR, # pylint: disable=E1101 

733 AttributeProto.SPARSE_TENSORS): # pylint: disable=E1101 

734 try: 

735 val = str(to_array(att.t).tolist()) 

736 except TypeError as e: # pragma: no cover 

737 raise TypeError( 

738 "Unable to display tensor type %r.\n%s" % ( 

739 att.type, str(att))) from e 

740 if "\n" in val: 

741 val = val.split("\n", maxsplit=1)[0] + "..." 

742 if len(val) > 10: 

743 val = val[:10] + "..." 

744 elif att.type == AttributeProto.STRING: # pylint: disable=E1101 

745 val = str(att.s) 

746 elif att.type == AttributeProto.STRINGS: # pylint: disable=E1101 

747 n_val = list(att.strings) 

748 if len(n_val) < 5: 

749 val = ",".join(map(str, n_val)) 

750 else: 

751 val = "%d:[%s...%s]" % ( 

752 len(n_val), 

753 ",".join(map(str, n_val[:2])), 

754 ",".join(map(str, n_val[-2:]))) 

755 elif att.type == AttributeProto.INT: # pylint: disable=E1101 

756 val = str(att.i) 

757 elif att.type == AttributeProto.FLOAT: # pylint: disable=E1101 

758 val = str(att.f) 

759 elif att.type == AttributeProto.INTS: # pylint: disable=E1101 

760 n_val = list(att.ints) 

761 if len(n_val) < 6: 

762 val = f"[{','.join(map(str, n_val))}]" 

763 else: 

764 val = "%d:[%s...%s]" % ( 

765 len(n_val), 

766 ",".join(map(str, n_val[:3])), 

767 ",".join(map(str, n_val[-3:]))) 

768 elif att.type == AttributeProto.FLOATS: # pylint: disable=E1101 

769 n_val = list(att.floats) 

770 if len(n_val) < 5: 

771 val = f"[{','.join(map(str, n_val))}]" 

772 else: 

773 val = "%d:[%s...%s]" % ( 

774 len(n_val), 

775 ",".join(map(str, n_val[:2])), 

776 ",".join(map(str, n_val[-2:]))) 

777 else: 

778 val = '.%d' % att.type 

779 atts.append(f"{att.name}={val}") 

780 inputs = list(node.input) 

781 if len(atts) > 0: 

782 inputs.extend(atts) 

783 if node.domain in ('', 'ai.onnx.ml'): 

784 domain = '' 

785 else: 

786 domain = f'[{node.domain}]' 

787 return "%s%s%s(%s) -> %s" % ( 

788 " " * indent, node.op_type, domain, 

789 ", ".join(inputs), ", ".join(node.output)) 

790 

791 rows = [] 

792 if hasattr(model, 'opset_import'): 

793 for opset in model.opset_import: 

794 rows.append( 

795 f"opset: domain={opset.domain!r} version={opset.version!r}") 

796 if hasattr(model, 'graph'): 

797 if model.doc_string: 

798 rows.append(f'doc_string: {model.doc_string}') 

799 main_model = model 

800 model = model.graph 

801 else: 

802 main_model = None 

803 

804 # inputs 

805 line_name_new = {} 

806 line_name_in = {} 

807 if level == 0: 

808 rows.append("----- input ----") 

809 for inp in model.input: 

810 if isinstance(inp, str): 

811 rows.append(f"input: {inp!r}") 

812 else: 

813 line_name_new[inp.name] = len(rows) 

814 rows.append("input: name=%r type=%r shape=%r" % ( 

815 inp.name, _get_type(inp), _get_shape(inp))) 

816 if hasattr(model, 'attribute'): 

817 for att in model.attribute: 

818 if isinstance(att, str): 

819 rows.append(f"attribute: {att!r}") 

820 else: 

821 raise NotImplementedError( # pragma: no cover 

822 "Not yet introduced in onnx.") 

823 

824 # initializer 

825 if hasattr(model, 'initializer'): 

826 if len(model.initializer) and level == 0: 

827 rows.append("----- initializer ----") 

828 for init in model.initializer: 

829 if numpy.prod(_get_shape(init)) < 5: 

830 content = f" -- {to_array(init).ravel()!r}" 

831 else: 

832 content = "" 

833 line_name_new[init.name] = len(rows) 

834 rows.append("init: name=%r type=%r shape=%r%s" % ( 

835 init.name, _get_type(init), _get_shape(init), content)) 

836 if level == 0: 

837 rows.append("----- main graph ----") 

838 

839 # successors, predecessors, it needs to support subgraphs 

840 subgraphs = graph_predecessors_and_successors(model)[0] 

841 

842 # walk through nodes 

843 init_names = set() 

844 indents = {} 

845 for inp in model.input: 

846 if isinstance(inp, str): 

847 indents[inp] = 0 

848 init_names.add(inp) 

849 else: 

850 indents[inp.name] = 0 

851 init_names.add(inp.name) 

852 if hasattr(model, 'initializer'): 

853 for init in model.initializer: 

854 indents[init.name] = 0 

855 init_names.add(init.name) 

856 

857 try: 

858 nodes = reorder_nodes_for_display(model.node, verbose=verbose) 

859 except RuntimeError as e: # pragma: no cover 

860 if raise_exc: 

861 raise e 

862 else: 

863 rows.append(f"ERROR: {e}") 

864 nodes = model.node 

865 

866 previous_indent = None 

867 previous_out = None 

868 previous_in = None 

869 for node in nodes: 

870 add_break = False 

871 name = node.name + "#" + "|".join(node.output) 

872 if name in indents: 

873 indent = indents[name] 

874 if previous_indent is not None and indent < previous_indent: 

875 if verbose: 

876 print(f"[onnx_simple_text_plot] break1 {node.op_type}") 

877 add_break = True 

878 elif previous_in is not None and set(node.input) == previous_in: 

879 indent = previous_indent 

880 else: 

881 inds = [indents.get(i, 0) 

882 for i in node.input if i not in init_names] 

883 if len(inds) == 0: 

884 indent = 0 

885 else: 

886 mi = min(inds) 

887 indent = mi 

888 if previous_indent is not None and indent < previous_indent: 

889 if verbose: 

890 print( # pragma: no cover 

891 f"[onnx_simple_text_plot] break2 {node.op_type}") 

892 add_break = True 

893 if not add_break and previous_out is not None: 

894 if len(set(node.input) & previous_out) == 0: 

895 if verbose: 

896 print(f"[onnx_simple_text_plot] break3 {node.op_type}") 

897 add_break = True 

898 indent = 0 

899 

900 if add_break and verbose: 

901 print("[onnx_simple_text_plot] add break") 

902 for n in node.input: 

903 if n in line_name_in: 

904 line_name_in[n].append(len(rows)) 

905 else: 

906 line_name_in[n] = [len(rows)] 

907 for n in node.output: 

908 line_name_new[n] = len(rows) 

909 rows.append(str_node(indent if use_indentation else 0, node)) 

910 indents[name] = indent 

911 

912 for o in node.output: 

913 indents[o] = indent + 1 

914 

915 previous_indent = indents[name] 

916 previous_out = set(node.output) 

917 previous_in = set(node.input) 

918 

919 # outputs 

920 if level == 0: 

921 rows.append("----- output ----") 

922 for out in model.output: 

923 if isinstance(out, str): 

924 if out in line_name_in: 

925 line_name_in[out].append(len(rows)) 

926 else: 

927 line_name_in[out] = [len(rows)] 

928 rows.append(f"output: name={out!r} type={'?'} shape={'?'}") 

929 else: 

930 if out.name in line_name_in: 

931 line_name_in[out.name].append(len(rows)) 

932 else: 

933 line_name_in[out.name] = [len(rows)] 

934 rows.append("output: name=%r type=%r shape=%r" % ( 

935 out.name, _get_type(out), _get_shape(out))) 

936 

937 if add_links: 

938 

939 def _mark_link(rows, lengths, r1, r2, d): 

940 maxl = max(lengths[r1], lengths[r2]) + d * 2 

941 maxl = max(maxl, max(len(rows[r]) for r in range(r1, r2 + 1))) + 2 

942 

943 if rows[r1][-1] == '|': 

944 p1, p2 = rows[r1][:lengths[r1] + 2], rows[r1][lengths[r1] + 2:] 

945 rows[r1] = p1 + p2.replace(' ', '-') 

946 rows[r1] += ("-" * (maxl - len(rows[r1]) - 1)) + "+" 

947 

948 if rows[r2][-1] == " ": 

949 rows[r2] += "<" 

950 elif rows[r2][-1] == '|': 

951 if "<" not in rows[r2]: 

952 p = lengths[r2] 

953 rows[r2] = rows[r2][:p] + '<' + rows[r2][p + 1:] 

954 p1, p2 = rows[r2][:lengths[r2] + 2], rows[r2][lengths[r2] + 2:] 

955 rows[r2] = p1 + p2.replace(' ', '-') 

956 rows[r2] += ("-" * (maxl - len(rows[r2]) - 1)) + "+" 

957 

958 for r in range(r1 + 1, r2): 

959 if len(rows[r]) < maxl: 

960 rows[r] += " " * (maxl - len(rows[r]) - 1) 

961 rows[r] += "|" 

962 

963 diffs = [] 

964 for n, r1 in line_name_new.items(): 

965 if n not in line_name_in: 

966 continue 

967 r2s = line_name_in[n] 

968 for r2 in r2s: 

969 if r1 >= r2: 

970 continue 

971 diffs.append((r2 - r1, (n, r1, r2))) 

972 diffs.sort() 

973 for i in range(len(rows)): # pylint: disable=C0200 

974 rows[i] += " " 

975 lengths = [len(r) for r in rows] 

976 

977 for d, (n, r1, r2) in diffs: 

978 if d == 1 and len(line_name_in[n]) == 1: 

979 # no line for link to the next node 

980 continue 

981 _mark_link(rows, lengths, r1, r2, d) 

982 

983 # subgraphs 

984 if recursive: 

985 for node, name, g in subgraphs: 

986 rows.append('----- subgraph ---- %s - %s - att.%s=%s -- level=%d -- %s -> %s' % ( 

987 node.op_type, node.name, name, _get_subgraph_name(id(g)), 

988 level, ','.join(i.name for i in g.input), 

989 ','.join(i.name for i in g.output))) 

990 res = onnx_simple_text_plot( 

991 g, verbose=verbose, att_display=att_display, 

992 add_links=add_links, recursive=recursive, 

993 sub_graphs_names=sub_graphs_names, level=level + 1, 

994 raise_exc=raise_exc) 

995 rows.append(res) 

996 

997 # functions 

998 if functions and main_model is not None: 

999 for fct in main_model.functions: 

1000 rows.append(f'----- function name={fct.name} domain={fct.domain}') 

1001 if fct.doc_string: 

1002 rows.append(f'----- doc_string: {fct.doc_string}') 

1003 res = onnx_simple_text_plot( 

1004 fct, verbose=verbose, att_display=att_display, 

1005 add_links=add_links, recursive=recursive, 

1006 functions=False, sub_graphs_names=sub_graphs_names, 

1007 level=1) 

1008 rows.append(res) 

1009 

1010 return "\n".join(rows) 
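
# Editor's sketch (not part of the original module): the docstring examples
# above rely on scikit-learn; this smaller, hypothetical demo builds a model
# containing an If node directly with onnx.helper and shows recursive=True
# printing the subgraphs as well.
def _demo_onnx_simple_text_plot_recursive():  # pragma: no cover
    "Displays a graph and the branches of its If node."
    from onnx import helper
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1])
    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])
    then_g = helper.make_graph(
        [helper.make_node('Identity', ['A'], ['Y'])], 'then', [], [Y])
    else_g = helper.make_graph(
        [helper.make_node('Neg', ['A'], ['Y'])], 'else', [], [Y])
    nodes = [
        helper.make_node('Greater', ['X', 'A'], ['cond'], name='gr'),
        helper.make_node('If', ['cond'], ['Y'], name='if',
                         then_branch=then_g, else_branch=else_g)]
    graph = helper.make_graph(nodes, 'demo_if', [X, A], [Y])
    model = helper.make_model(
        graph, opset_imports=[helper.make_opsetid('', 15)])
    print(onnx_simple_text_plot(model, recursive=True))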

1011 

1012 

1013def onnx_text_plot_io(model, verbose=False, att_display=None): 

1014 """ 

1015 Displays information about input and output types. 

1016 

1017 :param model: ONNX graph 

1018 :param verbose: display debugging information 

1019 :return: str 

1020 

1021 An ONNX graph is printed the following way: 

1022 

1023 .. runpython:: 

1024 :showcode: 

1025 :warningout: DeprecationWarning 

1026 

1027 import numpy 

1028 from sklearn.cluster import KMeans 

1029 from mlprodict.plotting.plotting import onnx_text_plot_io 

1030 from mlprodict.onnx_conv import to_onnx 

1031 

1032 x = numpy.random.randn(10, 3) 

1033 y = numpy.random.randn(10) 

1034 model = KMeans(3) 

1035 model.fit(x, y) 

1036 onx = to_onnx(model, x.astype(numpy.float32), 

1037 target_opset=15) 

1038 text = onnx_text_plot_io(onx, verbose=False) 

1039 print(text) 

1040 """ 

1041 rows = [] 

1042 if hasattr(model, 'opset_import'): 

1043 for opset in model.opset_import: 

1044 rows.append( 

1045 f"opset: domain={opset.domain!r} version={opset.version!r}") 

1046 if hasattr(model, 'graph'): 

1047 model = model.graph 

1048 

1049 # inputs 

1050 for inp in model.input: 

1051 rows.append("input: name=%r type=%r shape=%r" % ( 

1052 inp.name, _get_type(inp), _get_shape(inp))) 

1053 # initializer 

1054 for init in model.initializer: 

1055 rows.append("init: name=%r type=%r shape=%r" % ( 

1056 init.name, _get_type(init), _get_shape(init))) 

1057 # outputs 

1058 for out in model.output: 

1059 rows.append("output: name=%r type=%r shape=%r" % ( 

1060 out.name, _get_type(out), _get_shape(out))) 

1061 return "\n".join(rows)