Coverage for mlprodict/onnx_tools/onnx_manipulations.py: 94%

950 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# Copyright (c) Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT License. 

3# pylint: disable=E1101, C0302 

4 

5""" 

6@file 

7@brief Implements a class able to compute the predictions 

8from an :epkg:`ONNX` model.

9""" 

10import hashlib 

11from collections import Counter 

12import pprint 

13from onnx import ( 

14 shape_inference, ModelProto, FunctionProto, GraphProto, 

15 AttributeProto) 

16from onnx.helper import ( 

17 make_tensor_value_info, ValueInfoProto, set_model_props, 

18 make_graph, make_function, make_model, make_node, 

19 make_operatorsetid, make_attribute, make_value_info) 

20from .onnx2py_helper import ( 

21 guess_proto_dtype, from_array, get_tensor_shape, 

22 get_tensor_elem_type) 

23from .optim import onnx_remove_node_unused 

24from .onnx_tools import enumerate_onnx_names, enumerate_onnx_nodes 

25from ..onnx_tools.onnx2py_helper import _var_as_dict, from_array 

26 

27 

def enumerate_model_node_outputs(model, add_node=False, order=False):
    """
    Enumerates all the node outputs of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order
    :return: enumerator

    When *order* is True, every result receives a rank: graph inputs
    start at 0, a node ranks one above its highest ranked input and its
    outputs one above the node. Outputs are then enumerated by increasing
    rank, ties are broken by name. Nodes are identified by their position
    in the graph because node names are optional in ONNX and are often
    empty or duplicated.
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            f"Parameter model is not an ONNX model but {type(model)}")

    if not order:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
        return

    nodes = list(model.graph.node)
    result_rank = {}
    node_rank = [0] * len(nodes)
    producers = {}
    for inp in model.graph.input:
        result_rank[inp.name] = 0
    for node in nodes:
        for o in node.output:
            producers[o] = node
            result_rank[o] = 0

    # Propagates ranks until nothing changes; a path in a DAG never
    # contains more than len(nodes) nodes, which bounds the number of passes.
    modif = 1
    n_iter = 0
    while modif > 0 and n_iter <= len(nodes):
        modif = 0
        n_iter += 1
        for index, node in enumerate(nodes):
            for name in node.input:
                # initializers or unknown names do not constrain the order
                if name not in result_rank:
                    continue
                if result_rank[name] + 1 > node_rank[index]:
                    modif += 1
                    node_rank[index] = result_rank[name] + 1
            for name in node.output:
                if node_rank[index] + 1 > result_rank[name]:
                    modif += 1
                    result_rank[name] = node_rank[index] + 1

    for _, out in sorted((rank, name) for name, rank in result_rank.items()):
        if out not in producers:
            # graph inputs are not node outputs
            continue
        yield (out, producers[out]) if add_node else out

88 

89 

def get_opsets(model, include_functions=True, exc=True):
    """
    Enumerates all opsets used in a model.

    :param model: :epkg:`ModelProto` or :epkg:`FunctionProto`
    :param include_functions: include opsets used in functions
        (only applies when *model* is a :epkg:`ModelProto`)
    :param exc: raise an exception if conflicts are detected
    :return: dictionary `{domain: version}`

    If a domain is declared by the main graph and by a function with
    different versions and *exc* is False, the highest version is kept.
    """
    def _read(proto):
        # domain -> version as declared in proto.opset_import
        versions = {}
        for opset in proto.opset_import:
            if exc and opset.domain in versions:
                raise ValueError(  # pragma: no cover
                    f"Domain {opset.domain!r} appears multiple times.")
            versions[opset.domain] = opset.version
        return versions

    is_model = isinstance(model, ModelProto)
    res = _read(model)
    if not (is_model and include_functions):
        return res

    for fct in model.functions:
        for domain, version in get_opsets(fct, exc=exc).items():
            if domain not in res:
                res[domain] = version
            elif res[domain] != version:
                if exc:
                    raise ValueError(  # pragma: no cover
                        "Domain %r has different version in "
                        "main graph (%d) and function %r "
                        "(%d)." % (domain, res[domain], fct.name, version))
                res[domain] = max(res[domain], version)
    return res

129 

130 

def get_hidden_inputs(nodes):
    """
    Returns the set of hidden inputs used by subgraphs.

    The result contains every name consumed by a node in *nodes*, or by
    a node inside one of their graph attributes, which no node in *nodes*
    produces. Names bound by a subgraph itself (its formal inputs and its
    initializers) are local to that subgraph and are not returned.

    :param nodes: list of nodes
    :return: set of names
    """
    inputs = set()
    outputs = set()
    for node in nodes:
        inputs |= set(node.input)
        outputs |= set(node.output)
        for att in node.attribute:
            if (att.type != AttributeProto.GRAPH or  # pylint: disable=E1101
                    not hasattr(att, 'g') or att.g is None):
                continue
            hidden = get_hidden_inputs(att.g.node)
            # Compare names: initializers are TensorProto messages, which
            # are not hashable and never equal to a string.
            local = {init.name for init in att.g.initializer}
            local |= {inp.name for inp in att.g.input}
            inputs |= hidden - local
    return inputs - outputs

151 

152 

def _select_io_value_info(name, overwrite, known_shapes):
    """
    Returns the declaration (a *ValueInfoProto*) of input or output
    *name* for :func:`select_model_inputs_outputs`.

    :param name: result name
    :param overwrite: None or dictionary `{'name': (numpy dtype, shape)}`,
        it has priority over *known_shapes*
    :param known_shapes: dictionary `{name: TypeProto}`
    :return: ValueInfoProto, only its name is set when the element
        type is unknown
    """
    if overwrite is not None and name in overwrite:
        dtype, shape = overwrite[name]
        return make_tensor_value_info(name, guess_proto_dtype(dtype), shape)
    if name in known_shapes:
        proto_dtype = known_shapes[name].tensor_type.elem_type
        if proto_dtype != 0:
            shape = get_tensor_shape(known_shapes[name])
            return make_tensor_value_info(name, proto_dtype, shape)
    # element type unknown (0 is TensorProto.UNDEFINED) or name not
    # found: only the name is declared
    value_info = ValueInfoProto()
    value_info.name = name
    return value_info


def select_model_inputs_outputs(model, outputs=None, inputs=None,
                                infer_shapes=False, overwrite=None,
                                remove_unused=True,
                                verbose=0, fLOG=None):
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
    :param fLOG: logging function
    :return: modified model

    The function removes unneeded nodes.

    .. exref::
        :title: Change ONNX model inputs

        The following example shows how to change the inputs of a model
        to bypass the first nodes. Shape inference fails to determine
        the new input types, so they need to be overwritten.
        `verbose=1, fLOG=print` shows the number of deleted nodes.

        ::

            import onnx
            from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs

            onx = onnx.load(path)
            onx2 = select_model_inputs_outputs(
                onx, inputs=["SentenceTokenizer/SentencepieceTokenizeOp:0",
                             "SentenceTokenizer/SentencepieceTokenizeOp:1"],
                infer_shapes=True, verbose=1, fLOG=print,
                overwrite={'SentenceTokenizer/SentencepieceTokenizeOp:0': (numpy.int32, None),
                           'SentenceTokenizer/SentencepieceTokenizeOp:1': (numpy.int64, None)})
            onnx.save(onx2, path2)
    """
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]

    # mark_var[name] == 1 means result *name* is needed to compute the outputs
    mark_var = {}
    for out in enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(  # pragma: no cover
                f"Output '{out}' not found in model.")
        mark_var[out] = 1

    # Nodes are visited from last to first, presumably so that marks
    # travel from the outputs to the inputs in fewer passes.
    nodes = list(model.graph.node[::-1])
    mark_op = {id(node): 0 for node in nodes}
    new_inputs = set(inputs)

    # We mark all the nodes we need to keep: a node is kept when one of
    # its outputs is needed; its inputs (and the hidden inputs of its
    # subgraphs) then become needed unless they are new inputs.
    nb = 1
    while nb > 0:
        nb = 0
        for node in nodes:
            if mark_op[id(node)] == 1:
                continue
            if not any(mark_var[out] == 1 for out in node.output):
                continue
            mark_op[id(node)] = 1
            nb += 1

            hidden = get_hidden_inputs([node])
            node_inputs = list(node.input) + list(hidden)
            for inp in node_inputs:
                if inp in new_inputs or mark_var.get(inp, 0) == 1:
                    continue
                mark_var[inp] = 1
                nb += 1

    # All nodes verifies mark_op[id(node)] == 1
    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]

    if verbose > 1 and fLOG is not None:  # pragma: no cover
        for node in nodes:
            s = "+" if mark_op[id(node)] == 1 else "-"
            fLOG("[select_model_inputs_outputs] %s %s (%s) -> %s [%s]" % (
                s, node.op_type, ", ".join(node.input),
                ', '.join(node.output), node.name))

    known_shapes = {}
    if infer_shapes:
        shapes = shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type

    var_in = [_select_io_value_info(name, overwrite, known_shapes)
              for name in inputs]
    var_out = [_select_io_value_info(name, overwrite, known_shapes)
               for name in outputs]

    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[select_model_inputs_outputs] nodes %r --> %r" % (
            len(model.graph.node), len(keep_nodes)))
        fLOG("[select_model_inputs_outputs] inputs: %r" %
             [_.name for _ in var_in])
        fLOG("[select_model_inputs_outputs] outputs: %r" %
             [_.name for _ in var_out])

    graph = make_graph(keep_nodes, model.graph.name, var_in,
                       var_out, model.graph.initializer,
                       sparse_initializer=model.graph.sparse_initializer)
    onnx_model = make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # remove unused nodes
    if remove_unused:
        onnx_model = onnx_remove_node_unused(onnx_model, recursive=False)

    return onnx_model

351 

352 

def change_input_type(onx, changes):
    """
    Changes the element type of some inputs of a model.

    :param onx: ONNX model
    :param changes: dictionary `{ name: new proto element type }`
    :return: new onx

    Every modified input is declared again with
    :func:`make_tensor_value_info` and no shape.
    NOTE(review): the previous shape of a modified input is dropped,
    confirm this is intended.
    """
    def _retype(value):
        # inputs missing from *changes* are kept as they are
        if value.name not in changes:
            return value
        return make_tensor_value_info(value.name, changes[value.name], None)

    new_inputs = [_retype(inp) for inp in onx.graph.input]

    # final
    graph = make_graph(list(onx.graph.node),
                       onx.graph.name, new_inputs,
                       list(onx.graph.output),
                       onx.graph.initializer,
                       sparse_initializer=onx.graph.sparse_initializer)
    onnx_model = make_model(graph, functions=onx.functions)
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(onnx_model, field, getattr(onx, field))
    if len(onx.metadata_props) > 0:  # pragma: no cover
        set_model_props(
            onnx_model, {p.key: p.value for p in onx.metadata_props})

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onx.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    return onnx_model

393 

394 

395def _change_subgraph_io_type_shape_list(io_list, type_changes, shape_changes): 

396 ms = False 

397 new_inputs = [] 

398 for inp in io_list: 

399 m = False 

400 if isinstance(shape_changes, dict): 

401 if inp.name in shape_changes: 

402 shape = shape_changes[inp.name] 

403 m = True 

404 else: 

405 shape = get_tensor_shape(inp) 

406 else: 

407 shape = shape_changes(inp) 

408 m = True 

409 

410 if isinstance(type_changes, dict): 

411 if inp.name in type_changes: 

412 ntype = type_changes[inp.name] 

413 m = True 

414 else: 

415 ntype = get_tensor_elem_type(inp) 

416 else: 

417 ntype = type_changes(inp) 

418 m = True 

419 

420 if m: 

421 ms = True 

422 value_info = make_tensor_value_info(inp.name, ntype, shape) 

423 new_inputs.append(value_info) 

424 else: 

425 new_inputs.append(inp) 

426 return new_inputs if ms else None 

427 

428 

def change_subgraph_io_type_shape(onx, type_changes=None, shape_changes=None,
                                  recursive=True):
    """
    Changes the type of an input or an output of a subgraph.

    :param onx: ModelProto, GraphProto
    :param type_changes: dictionary `{ name: new proto element type }`
        or function `f(input) -> type`
    :param shape_changes: dictionary `{ name: new shape }`
        or function `f(input) -> shape`
    :param recursive: also applies the changes to the subgraphs
        stored in node attributes
    :return: new onx
    """
    if isinstance(onx, ModelProto):
        graph = change_subgraph_io_type_shape(
            onx.graph, type_changes, shape_changes, recursive)
        onnx_model = make_model(graph, functions=onx.functions)
        onnx_model.ir_version = onx.ir_version
        onnx_model.producer_name = onx.producer_name
        onnx_model.producer_version = onx.producer_version
        onnx_model.domain = onx.domain
        onnx_model.model_version = onx.model_version
        onnx_model.doc_string = onx.doc_string
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(onnx_model, values)

        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return onnx_model

    graph = onx
    new_inputs = _change_subgraph_io_type_shape_list(
        graph.input, type_changes or {}, shape_changes or {})
    if new_inputs is None:
        new_inputs = graph.input

    new_outputs = _change_subgraph_io_type_shape_list(
        graph.output, type_changes or {}, shape_changes or {})
    if new_outputs is None:
        new_outputs = graph.output

    # recursive
    if recursive:
        new_nodes = []
        for node in list(graph.node):
            modified = False
            atts = []
            for att in node.attribute:
                if (att.type == AttributeProto.GRAPH and
                        hasattr(att, 'g') and att.g is not None):
                    modified = True
                    g = change_subgraph_io_type_shape(
                        att.g, type_changes, shape_changes,
                        recursive=recursive)
                    att = make_attribute(att.name, g)
                atts.append(att)
            if modified:
                # domain, name and doc_string must be kept, a node from a
                # custom domain would otherwise move to the default domain
                node = make_node(node.op_type, node.input, node.output,
                                 name=node.name, doc_string=node.doc_string,
                                 domain=node.domain)
                node.attribute.extend(atts)
            new_nodes.append(node)
    else:
        new_nodes = list(graph.node)

    # final
    graph = make_graph(new_nodes, graph.name, new_inputs, new_outputs,
                       graph.initializer,
                       sparse_initializer=graph.sparse_initializer)
    return graph

501 

502 

def overwrite_opset(model, new_opset):
    """
    Overwrites the main opset in an ONNX file.
    Does not change any node definition.

    :param model: ONNX model
    :param new_opset: new opset
    :return: ONNX model

    Only the version of the default domain `''` is replaced,
    every other domain keeps its version.
    """
    src = model.graph
    onnx_model = make_model(
        make_graph(src.node, src.name, src.input, src.output,
                   src.initializer,
                   sparse_initializer=src.sparse_initializer),
        functions=model.functions)
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(onnx_model, field, getattr(model, field))
    if len(model.metadata_props) > 0:  # pragma: no cover
        set_model_props(
            onnx_model, {p.key: p.value for p in model.metadata_props})

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        if oimp.domain == '':
            op_set.version = new_opset
        else:
            op_set.version = oimp.version
    return onnx_model

533 

534 

def hash_onnx_object(obj, max_size):
    """
    Hashes the content of an object.
    It uses module :mod:`hashlib` (sha256).

    :param obj: onnx node or initializer (it must have a method
        `SerializeToString`)
    :param max_size: size of the hash (the full hexadecimal digest has
        64 characters, a larger value returns the full digest)
    :return: hash, upper case hexadecimal string

    For a node, the hash depends on its type, its number of inputs and
    outputs and its attributes, not on its input and output names.
    For an initializer, the hash depends on its content; its name and
    doc_string are temporarily cleared and then restored.
    """
    m = hashlib.sha256()
    if hasattr(obj, 'op_type'):
        # An operator. utf-8 accepts any name and gives the same bytes
        # as ascii for ascii names.
        m.update(obj.op_type.encode('utf-8'))
        m.update(str(len(obj.input)).encode('ascii'))
        m.update(str(len(obj.output)).encode('ascii'))
        if hasattr(obj, 'attribute'):
            for att in obj.attribute:
                m.update(att.name.encode('utf-8'))
                m.update(att.SerializeToString())
    else:
        # An initializer.
        name = obj.name
        docf = obj.doc_string
        obj.name = ''
        obj.doc_string = ''
        try:
            m.update(obj.SerializeToString())
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError(
                f"Unable to hash object type {type(obj)!r}, value={obj!r}.") from e
        finally:
            obj.name = name
            obj.doc_string = docf

    return m.hexdigest()[:max_size].upper()

573 

574 

def onnx_rename_names(model, strategy='simple', recursive=True,
                      verbose=0, fLOG=print,
                      counts=None, replace=None, taken=None):
    """
    Renames all names except the inputs and outputs.

    :param model: onnx model
    :param strategy: two strategies are implemented, see below
    :param recursive: walk through subgraphs
    :param verbose: verbose, if positive, reports on all changed names
    :param fLOG: logging function
    :param counts: used for recursion
    :param replace: used for recursion, it can be also used to
        to fix some replacements
    :param taken: used for recursion
    :return: onnx model (the model is modified in place)

    Strategies:

    * `'simple'`: use a letter `n` for node, `r` for result,
      `i` for initializer, this letter is followed by a number
    * `'type'`: the name depends on the node type and content,
      the hash is kept as small as possible
    """
    counts = counts or {'init': 0, 'node': 0, 'result': 0}
    # `is None` rather than `or`: an empty set or dictionary received from
    # the caller (recursion) must be shared, not replaced by a new one.
    replace = {} if replace is None else replace
    taken = set() if taken is None else taken
    graph = model.graph if hasattr(model, 'graph') else model

    for obj in graph.input:
        replace[obj.name] = obj.name
    for obj in graph.output:
        replace[obj.name] = obj.name

    def _is_taken(name):
        # a new name must differ from every name already produced (taken)
        # and from every name kept or replaced so far (keys of replace)
        return name in taken or name in replace

    def _check_name_simple(prefix):
        if not _is_taken(prefix):
            taken.add(prefix)
            return prefix
        c = 1
        final = "%s_%d" % (prefix, c)
        while _is_taken(final):
            c += 1
            final = "%s_%d" % (prefix, c)
        taken.add(final)
        return final

    def _check_name_type(obj, prefix):
        # Uses the shortest prefix of the hash (growing by 2 characters)
        # which is not taken. Identical objects share the same full hash,
        # a counter is appended once the full digest is exhausted
        # (otherwise the loop would never end).
        digest = hash_onnx_object(obj, 64)
        size = 2
        final = f"{prefix}_{digest[:size]}"
        while _is_taken(final) and size < len(digest):
            size += 2
            final = f"{prefix}_{digest[:size]}"
        base = final
        c = 1
        while _is_taken(final):
            final = f"{base}_{c}"
            c += 1
        taken.add(final)
        return final

    def get_name_init(init):
        if init.name in replace:
            return replace[init.name]
        if strategy == 'simple':
            name = _check_name_simple('i%d' % counts['init'])
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] init: {init.name!r} -> {name!r}')
            return name
        if strategy == 'type':
            name = _check_name_type(init, 'i')
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] init: {init.name!r} -> {name!r}')
            return name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_node(node):
        node_name = 'node_%s_%d' % (node.name, id(node))
        if node_name in replace:
            return replace[node_name]
        if strategy == 'simple':
            name = _check_name_simple('n%d' % counts['node'])
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] node: {node_name!r} -> {name!r}')
            return name
        if strategy == 'type':
            name = _check_name_type(node, 'n')
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] node: {node_name!r} -> {name!r}')
            return name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_result(node, i, name, suffix):
        if name in replace:
            return replace[name]
        if strategy == 'simple':
            new_name = _check_name_simple('r%d' % counts['result'])
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] result: {name!r} -> {new_name!r}')
            return new_name
        if strategy == 'type':
            new_name = _check_name_type(node, 'r%s%d' % (suffix, i))
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] result: {name!r} -> {new_name!r}')
            return new_name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_input(node, i):
        return get_name_result(node, i, node.input[i], 'i')

    def get_name_output(node, i):
        return get_name_result(node, i, node.output[i], 'o')

    for init in graph.initializer:
        init.name = get_name_init(init)
    for init in graph.sparse_initializer:
        # a SparseTensorProto has no name, it is the name of its values
        init.values.name = get_name_init(init.values)

    for node in list(graph.node):
        node.name = get_name_node(node)
        for i in range(len(node.input)):  # pylint: disable=C0200
            node.input[i] = get_name_input(node, i)
        for i in range(len(node.output)):  # pylint: disable=C0200
            node.output[i] = get_name_output(node, i)
        if not recursive or node.op_type not in {'Scan', 'If', 'Loop'}:
            continue
        # recursion
        for att in node.attribute:
            if att.name not in {'if_branch', 'else_branch', 'body'}:
                continue
            onnx_rename_names(
                att.g, strategy=strategy, fLOG=fLOG, verbose=verbose,
                counts=counts, replace=replace, taken=taken)

    return model

720 

721 

def onnx_rename_inputs_outputs(onx, rename):
    """
    Renames input or outputs names.

    :param onx: GraphProto, ModelProto
    :param rename: dictionary `{old_name: new_name}`
    :return: new onx

    Graph inputs, graph outputs, initializers and node inputs and outputs
    are renamed, subgraphs stored in node attributes are processed
    recursively. The input graph is left unchanged: renamed initializers
    are copies.
    """
    if isinstance(onx, ModelProto):
        graph = onnx_rename_inputs_outputs(onx.graph, rename)
        onnx_model = make_model(graph, functions=onx.functions)
        onnx_model.ir_version = onx.ir_version
        onnx_model.producer_name = onx.producer_name
        onnx_model.producer_version = onx.producer_version
        onnx_model.domain = onx.domain
        onnx_model.model_version = onx.model_version
        onnx_model.doc_string = onx.doc_string
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(onnx_model, values)

        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return onnx_model

    graph = onx
    new_inputs = []
    for inp in graph.input:
        if inp.name not in rename:
            new_inputs.append(inp)
            continue
        value_info = make_tensor_value_info(
            rename[inp.name], get_tensor_elem_type(inp), get_tensor_shape(inp))
        new_inputs.append(value_info)

    new_outputs = []
    for inp in graph.output:
        if inp.name not in rename:
            new_outputs.append(inp)
            continue
        value_info = make_tensor_value_info(
            rename[inp.name], get_tensor_elem_type(inp), get_tensor_shape(inp))
        new_outputs.append(value_info)

    new_inits = []
    for init in graph.initializer:
        if init.name in rename:
            # copy: the initializer of the input graph must not be modified
            renamed = type(init)()
            renamed.CopyFrom(init)
            renamed.name = rename[init.name]
            init = renamed
        new_inits.append(init)

    new_sparse_inits = []
    for init in graph.sparse_initializer:
        # a SparseTensorProto has no name, it is the name of its values
        if init.values.name in rename:
            renamed = type(init)()
            renamed.CopyFrom(init)
            renamed.values.name = rename[init.values.name]
            init = renamed
        new_sparse_inits.append(init)

    new_nodes = []
    for node in list(graph.node):
        modified = False
        atts = []
        for att in node.attribute:
            if (att.type == AttributeProto.GRAPH and
                    hasattr(att, 'g') and att.g is not None):
                modified = True
                g = onnx_rename_inputs_outputs(att.g, rename)
                att = make_attribute(att.name, g)
            atts.append(att)

        inp = [rename.get(i, i) for i in node.input]
        out = [rename.get(i, i) for i in node.output]
        if (not modified and inp == list(node.input) and
                out == list(node.output)):
            new_nodes.append(node)
            continue

        # domain, name and doc_string are kept, a node from a custom
        # domain would otherwise move to the default domain
        new_node = make_node(node.op_type, inp, out, domain=node.domain,
                             name=node.name, doc_string=node.doc_string)
        new_node.attribute.extend(atts)
        new_nodes.append(new_node)

    # final
    graph = make_graph(new_nodes, graph.name, new_inputs, new_outputs,
                       new_inits, sparse_initializer=new_sparse_inits)
    return graph

811 

812 

def onnx_replace_functions(model, replace):
    """
    Replaces some of the functions of a model.

    :param model: *ModelProto*
    :param replace: dictionary `{ (domain, name): FunctionProto }`
    :return: new model, or *model* itself if no function was replaced

    A replacing function must have the same number of inputs and
    outputs as the function it replaces.
    """
    if not isinstance(model, ModelProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(model)!r}.")

    def _pick(fct):
        # returns (function to keep, True if fct was replaced)
        key = fct.domain, fct.name
        if key not in replace:
            return fct, False
        f = replace[key]
        if not isinstance(f, FunctionProto):
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(f)!r} for function {key!r} in replace.")
        if len(f.input) != len(fct.input):
            raise ValueError(  # pragma: no cover
                f"Input mismatches {f.input!r} != {fct.input!r} (expected).")
        if len(f.output) != len(fct.output):
            raise ValueError(  # pragma: no cover
                f"Output mismatches {f.output!r} != {fct.output!r} (expected).")
        return f, True

    picked = [_pick(fct) for fct in model.functions]
    if not any(flag for _, flag in picked):
        return model
    new_functions = [f for f, _ in picked]

    opsets = [make_operatorsetid(op.domain, op.version)
              for op in model.opset_import]
    onnx_model = make_model(
        model.graph, opset_imports=opsets, functions=new_functions)
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(onnx_model, field, getattr(model, field))
    if len(model.metadata_props) > 0:  # pragma: no cover
        set_model_props(
            onnx_model, {p.key: p.value for p in model.metadata_props})
    return onnx_model

859 

860 

861def insert_results_into_onnx(model, results, as_parameter=True, suffix='_DBG', 

862 param_name=None, node_type='DEBUG', 

863 domain='DEBUG', domain_opset=1): 

864 """ 

865 Inserts results into an ONNX graph to produce an extended 

866 ONNX graph. It can be saved and looked into with a tool such as 

867 :epkg:`netron`. 

868 

869 :param model: ONNX graph 

870 :param results: results to be added in a dictionary 

871 :param as_parameter: add new nodes with results as one parameter 

872 (True) or as initializer (False) 

873 :param suffix: suffix to add to new results 

874 :param param_name: name of the parameter to add 

875 (by default the result name), it can be a function 

876 `param_name(reult_name) -> parameter_name` 

877 :param node_type: type of the new node 

878 :param domain: domain the new node 

879 :param domain_opset: opset for *domain* 

880 :return: new ONNX graph 

881 

882 See method :meth:`OnnxInference.run2onnx 

883 <mlprodict.onnxrt.onnx_inference.OnnxInference.run2onnx>` 

884 to see a graph this function produces. 

885 

886 .. image:: debug.png 

887 

888 .. versionadded:: 0.7 

889 """ 

890 inputs = list(model.graph.input) 

891 outputs = list(model.graph.output) 

892 inits = list(model.graph.initializer) 

893 inits_sparse = list(model.graph.sparse_initializer) 

894 node_list = list(model.graph.node) 

895 nodes = {id(n): n for n in node_list} 

896 order = {id(n): i for i, n in enumerate(node_list)} 

897 nodes_copy = {} 

898 

899 names_init = (set(init.name for init in inits) | 

900 set(init.name for init in inits_sparse)) 

901 names_input = set(init.name for init in inputs) 

902 names_output = {} 

903 for node in nodes.values(): 

904 for i, o in enumerate(node.output): 

905 names_output[o] = (i, node) 

906 

907 for k, v in results.items(): 

908 if k in names_init: 

909 # initializer are not inserted again 

910 continue 

911 if k in names_input: 

912 # inputs are added as 

913 raise NotImplementedError( 

914 f"Unable to add debug information on input {k!r}.") 

915 

916 if k not in names_output: 

917 raise RuntimeError( 

918 "Unable to find result %r in the ONNX graph. Available=" 

919 "[%s]." % (k, ", ".join(sorted(names_output)))) 

920 

921 index, node = names_output[k] 

922 new_name = k + suffix 

923 

924 if id(node) not in nodes_copy: 

925 new_node = make_node( 

926 node.op_type, list(node.input), list(node.output), 

927 domain=node.domain if node.domain else None, 

928 name=node.name + suffix) 

929 new_node.attribute.extend(node.attribute) # pylint: disable=E1101 

930 nodes_copy[id(node)] = new_node 

931 order[id(new_node)] = order[id(node)] 

932 new_node = nodes_copy[id(node)] 

933 new_node.output[index] = new_name 

934 

935 if as_parameter: 

936 pname = k if param_name is None else param_name(k) 

937 atts = {pname: from_array(v, name=pname)} 

938 inserted_node = make_node( 

939 node_type, [new_name], [k], domain=domain, 

940 **atts) 

941 else: 

942 pname = k if param_name is None else param_name(k) 

943 pname += suffix + 'i' 

944 inserted_node = make_node( 

945 node_type, [new_name, pname], [k], domain=domain) 

946 inits.append(from_array(v, name=pname)) 

947 

948 order[id(inserted_node)] = order[id(node)] + 1. / (index + 2) 

949 nodes[id(inserted_node)] = inserted_node 

950 

951 new_nodes = [(order[id(n)], n) 

952 for n in nodes.values() if id(n) not in nodes_copy] 

953 new_nodes.extend((order[id(n)], n) for n in nodes_copy.values()) 

954 new_nodes = [n[1] for n in sorted(new_nodes)] 

955 

956 graph = make_graph(new_nodes, model.graph.name, inputs, outputs, 

957 inits, sparse_initializer=inits_sparse) 

958 onnx_model = make_model(graph, functions=model.functions) 

959 onnx_model.ir_version = model.ir_version 

960 onnx_model.producer_name = model.producer_name 

961 onnx_model.producer_version = model.producer_version 

962 onnx_model.domain = model.domain 

963 onnx_model.model_version = model.model_version 

964 onnx_model.doc_string = model.doc_string 

965 if len(model.metadata_props) > 0: # pragma: no cover 

966 values = {p.key: p.value for p in model.metadata_props} 

967 set_model_props(onnx_model, values) 

968 

969 del onnx_model.opset_import[:] # pylint: disable=E1101 

970 for oimp in model.opset_import: 

971 op_set = onnx_model.opset_import.add() # pylint: disable=E1101 

972 op_set.domain = oimp.domain 

973 op_set.version = oimp.version 

974 op_set = onnx_model.opset_import.add() # pylint: disable=E1101 

975 op_set.domain = domain 

976 op_set.version = domain_opset 

977 return onnx_model 

978 

979 

def onnx_model_to_function(onx, name=None, domain="custom",
                           opset_imports=None, doc_string=None,
                           inputs2par=None):
    """
    Turns an ONNX model (or its graph) into a :epkg:`FunctionProto`.
    The returned function has no attribute unless *inputs2par* moves
    some inputs into attributes.

    :param onx: :epkg:`ModelProto` or :epkg:`GraphProto`
    :param name: function name, the graph name if None
    :param domain: function domain
    :param opset_imports: opsets to import, either a dictionary
        `{domain: version}` or a list of operator set ids; for a model,
        the model opsets are used when None
    :param doc_string: doc string, the model doc string when None
    :param inputs2par: dictionary `{ name: None or default value }`,
        every input whose name is a key becomes a function attribute
        instead of a function input (only the keys are used)
    :return: tuple (function, list of other functions), the second
        element holds the local functions of the model

    .. warning::
        :epkg:`FunctionProto` does not support default values yet.
        They are ignored.
    """
    if isinstance(onx, ModelProto):
        if opset_imports is None:
            opset_imports = {op.domain: op.version for op in onx.opset_import}
        if doc_string is None:
            doc_string = onx.doc_string
        function, others = onnx_model_to_function(
            onx.graph, name=name, domain=domain,
            opset_imports=opset_imports, doc_string=doc_string,
            inputs2par=inputs2par)
        return function, others + list(onx.functions)

    if not isinstance(onx, GraphProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(onx)!r} for onx.")

    if name is None:
        name = onx.name

    def _is_attribute(input_name):
        # True if this input must become an attribute of the function.
        return inputs2par is not None and input_name in inputs2par

    inputs = [i.name for i in onx.input if not _is_attribute(i.name)]
    attributes = [i.name for i in onx.input if _is_attribute(i.name)]
    outputs = [o.name for o in onx.output]

    # A FunctionProto has no initializer: every dense or sparse
    # initializer is replaced by a Constant node producing the same name.
    nodes = [
        make_node('Constant', [], [init.name],
                  value=from_array(_var_as_dict(init)['value']))
        for init in onx.initializer]
    nodes.extend(
        make_node('Constant', [], [init.name],
                  sparse_value=from_array(_var_as_dict(init)['sparse_value']))
        for init in onx.sparse_initializer)
    nodes.extend(onx.node)

    if isinstance(opset_imports, dict):
        opset_imports = [make_operatorsetid(dom, ver)
                         for dom, ver in opset_imports.items()]
    function = make_function(
        domain, name, inputs, outputs, nodes,
        opset_imports=opset_imports, doc_string=doc_string or '',
        attributes=attributes)
    return function, []

1059 

1060 

def _onnx_function_to_model_convert_io(ens, type_info, shape_fct):
    """
    Builds one typed tensor ValueInfoProto per name in *ens*.

    :param ens: iterable of result names (function inputs or outputs)
    :param type_info: dictionary `{name: type}` or callable
        `type_info(name) -> type`; a type is either an integer
        (ONNX element type) or anything :func:`guess_proto_dtype` accepts
    :param shape_fct: callable `shape_fct(name, proto_dtype)` returning
        the shape given to :func:`make_tensor_value_info`
    :return: list of value infos, in the same order as *ens*
    :raises TypeError: if *type_info* is neither a dictionary nor a
        callable (only raised when *ens* is not empty)
    """
    def _element_type(result_name):
        # Resolves the element type of one result.
        if isinstance(type_info, dict):
            found = type_info[result_name]
        elif callable(type_info):
            found = type_info(result_name)
        else:
            raise TypeError(  # pragma: no cover
                "type_info is not a callable or a dictionary, "
                "unable to guess type for name=%r with "
                "type(type_info)=%r." % (result_name, type(type_info)))
        return found if isinstance(found, int) else guess_proto_dtype(found)

    def _typed(result_name):
        dtype = _element_type(result_name)
        return make_tensor_value_info(
            result_name, dtype, shape_fct(result_name, dtype))

    return [_typed(result_name) for result_name in ens]

1081 

1082 

def onnx_function_to_model(onx, functions=None, type_info=None,
                           as_function=False, shape_fct=None):
    """
    Wraps an ONNX :epkg:`FunctionProto` into a :epkg:`ModelProto`.
    Functions with attributes are not supported yet.

    :param onx: onnx function
    :param functions: additional functions to append to the model,
        a list or a dictionary (its values are used)
    :param type_info: dictionary or callable which returns the type of
        inputs or outputs if it cannot be guessed
    :param as_function: if True, the function is added to the model
        local functions and the graph holds a single node calling it,
        otherwise the graph holds the nodes of the function
    :param shape_fct: function to specify the shapes,
        signature: `shape_fct(name, proto_type) -> list`,
        shapes are left unknown when None
    :return: model
    """
    if not isinstance(onx, FunctionProto):
        raise TypeError(  # pragma: no cover
            f"onx must be a FunctionProto not {type(onx)!r}.")
    if len(onx.attribute) > 0:
        raise NotImplementedError(  # pragma: no cover
            "The function has attributes, it is not implemented yet.")

    if functions is None:
        added_functions = []
    elif isinstance(functions, dict):
        added_functions = list(functions.values())
    elif isinstance(functions, list):
        added_functions = list(functions)
    else:
        raise TypeError(  # pragma: no cover
            f"Unexpected type for functions {type(functions)!r}.")

    def _unknown_shape(name, dtype):  # pylint: disable=W0613
        "Default shape function, the shape stays unknown."
        return None

    if shape_fct is None:
        shape_fct = _unknown_shape

    inputs, outputs = (
        _onnx_function_to_model_convert_io(names, type_info, shape_fct=shape_fct)
        for names in (onx.input, onx.output))

    if not as_function:
        nodes = list(onx.node)
        opsets = [make_operatorsetid(op.domain, op.version)
                  for op in onx.opset_import]
    else:
        call = make_node(onx.name,
                         [vi.name for vi in inputs],
                         [vi.name for vi in outputs],
                         domain=onx.domain)
        nodes = [call]
        added_functions.append(onx)
        opsets = [make_operatorsetid(onx.domain, 1)]

    graph = make_graph(nodes, onx.name, inputs, outputs,
                       [], doc_string=onx.doc_string)
    return make_model(graph, functions=added_functions,
                      opset_imports=opsets,
                      doc_string=onx.doc_string,
                      model_version=1,
                      domain=onx.domain)

1142 

1143 

1144def _get_new_name(prefix, name, existing_names): 

1145 opt = f"{prefix}_{name}_0" 

1146 i = 0 

1147 while opt in existing_names: 

1148 i += 1 

1149 opt = "%s_%s_%d" % (prefix, name, i) 

1150 existing_names.add(opt) 

1151 return opt 

1152 

1153 

def onnx_subgraphs_level(obj):
    """
    Returns the depth of the graph: 1 for a graph without any
    sub-graph, one more than the deepest sub-graph otherwise.
    Sub-graphs are node attributes holding one graph (type ``GRAPH``,
    used by operators such as If, Loop or Scan) or a list of graphs
    (type ``GRAPHS``).

    :param obj: onnx object, a :epkg:`ModelProto` or any object with
        a field ``node`` (:epkg:`GraphProto`, :epkg:`FunctionProto`)
    :return: integer
    """
    if isinstance(obj, ModelProto):
        return onnx_subgraphs_level(obj.graph)
    best = 0
    for node in obj.node:
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                subgraphs = ([att.g] if getattr(att, 'g', None) is not None
                             else [])
            elif att.type == AttributeProto.GRAPHS:
                # A list of graphs was ignored before, its depth counts too.
                subgraphs = list(att.graphs)
            else:
                continue
            for sub in subgraphs:
                best = max(best, onnx_subgraphs_level(sub))
    return best + 1

1172 

1173 

1174class _inline_mapping(dict): 

1175 """ 

1176 Overwrites class dictionary to debug more easily. 

1177 

1178 :param verbose: verbosity 

1179 :param fLOG: logging function 

1180 :param level: sub graph level 

1181 """ 

1182 

1183 def __init__(self, verbose, fLOG, level): 

1184 dict.__init__(self) 

1185 self._verbose = verbose 

1186 self._fLOG = fLOG 

1187 self._level = level 

1188 

1189 def __setitem__(self, key, value): 

1190 "Adds a value." 

1191 if self._verbose > 3: 

1192 self._fLOG("[_inline_mapping-dict-addkv] %s + %r: %r" % 

1193 (" " * self._level, key, value)) 

1194 if key in self: 

1195 raise RuntimeError( # pragma: no cover 

1196 "Key %r was already added (with value %r, new one is %r)." 

1197 "" % (key, self[key], value)) 

1198 dict.__setitem__(self, key, value) 

1199 

1200 def update(self, d): 

1201 "Updates many values." 

1202 for k, v in d.items(): 

1203 self[k] = v 

1204 

1205 def copy(self): 

1206 "Returns a copy." 

1207 m = _inline_mapping(self._verbose, self._fLOG, self._level) 

1208 for k, v in self.items(): 

1209 m[k] = v 

1210 return m 

1211 

1212 def remove(self, o): 

1213 "Removes one element." 

1214 if o not in self: 

1215 raise KeyError( # pragma: no cover 

1216 f"Cannot remove a key {o!r}.") 

1217 self.pop(o) 

1218 

1219 

# Visits a (sub)graph while functions are being inlined and returns
# a tuple (graph, modified_nodes).
#
# * graph: graph to visit
# * protos: dictionary `{(domain, op_type): FunctionProto}` forwarded
#   to _onnx_inline_function_node
# * existing_names: set of names already used, _get_new_name adds every
#   new name it creates to it
# * mapping: _inline_mapping `{old name: new name}` inherited from the
#   outer scope, it is copied before being modified
# * rename: if True, results produced by the nodes of the graph receive
#   new names (the graph outputs keep theirs); if False, nodes calling
#   a function of *protos* are replaced by the function body when the
#   renaming pass changed nothing
# * verbose, fLOG: verbosity and logging function
# * level: depth of the sub-graph, used to indent log messages and
#   forwarded to nested calls
#
# The returned graph is *graph* itself when nothing was modified.
# modified_nodes lists the replaced nodes, or the replaced outputs
# for a graph without any node.

1220def _onnx_inline_function_graph(graph, protos, existing_names, mapping, 

1221 verbose, fLOG, rename, level): 

# Graph without any node: only its outputs may need a new name.

1222 if len(graph.node) == 0: 

1223 # Outputs have still to be renamed. 

1224 graph0 = graph 

1225 if verbose > 1: 

1226 fLOG( # pragma: no cover 

1227 "[onnx_inline_function-graph] %s visit0 graph=%d rename=%r " 

1228 "len(mapping)=%d begin" % ( 

1229 " " * level, id(graph), rename, len(mapping))) 

1230 if rename: 

1231 modified_nodes = [] 

1232 mapping = mapping.copy() 

1233 for i in graph.input: 

1234 mapping[i.name] = i.name 

1235 for i in graph.initializer: 

1236 mapping[i.name] = i.name 

1237 for i in graph.sparse_initializer: 

1238 mapping[i.name] = i.name 

1239 outputs = [] 

1240 for o in graph.output: 

# NOTE(review): mapping[o.name] raises KeyError if an output is neither
# a graph input, an initializer nor a name known by the outer mapping.

1241 no = make_value_info(mapping[o.name], o.type) 

1242 if no.name != o.name: 

1243 modified_nodes.append(o) 

1244 outputs.append(no) 

1245 else: 

1246 outputs.append(o) 

1247 if len(modified_nodes) > 0: 

1248 graph = make_graph( 

1249 [], graph.name, graph.input, outputs, 

1250 graph.initializer, doc_string=graph.doc_string, 

1251 sparse_initializer=list(graph.sparse_initializer)) 

1252 else: 

1253 modified_nodes = [] 

1254 

1255 if verbose > 1: 

1256 fLOG( # pragma: no cover 

1257 "[onnx_inline_function-graph] %s visit graph=%d end " 

1258 "changed=%r len(modified_nodes)=%d" % ( 

1259 " " * level, id(graph0), id(graph0) != id(graph), 

1260 len(modified_nodes))) 

1261 

1262 return graph, modified_nodes 

1263 

# Graph with nodes: the mapping is copied so that the caller's one
# is left untouched.

1264 graph0 = graph 

1265 mapping = mapping.copy() 

1266 init = list(graph.initializer) 

1267 init_sparse = list(graph.sparse_initializer) 

1268 inputs = list(graph.input) 

1269 modified_nodes = [] 

1270 outputs = list(graph.output) 

1271 

1272 if verbose > 1: 

1273 fLOG("[onnx_inline_function-graph] %s >visit graph=%d rename=%r " 

1274 "len(mapping)=%d begin" % ( 

1275 " " * level, id(graph), rename, len(mapping))) 

1276 

1277 output_names = [o.name for o in outputs] 

1278 for i in init: 

1279 mapping[i.name] = i.name 

1280 for i in init_sparse: 

1281 mapping[i.name] = i.name 

1282 for i in inputs: 

1283 mapping[i.name] = i.name 

1284 

1285 # first step, replace names 

1286 nodes = [] 

1287 for node in list(graph.node): 

1288 mod = 0 

1289 inp = [] 

# NOTE(review): an empty input name (omitted optional input) goes through
# the same lookup and raises the RuntimeError below unless '' happens to
# be in the mapping -- confirm whether '' should be passed through.

1290 for i in node.input: 

1291 if i in mapping: 

1292 inp.append(mapping[i]) 

1293 if mapping[i] != i: 

1294 mod += 1 

1295 else: 

1296 raise RuntimeError( # pragma: no cover 

1297 "Cannot find input %r in %s for node (level=%d)\n%r." % ( 

1298 i, pprint.pformat(mapping), level, node)) 

1299 out = [] 

1300 for o in node.output: 

# new_o is the name given to this output in the rewritten node, graph
# outputs (output_names) are not renamed when rename is True.

1301 new_o = o 

1302 if rename: 

1303 if o not in output_names: 

1304 new_o = _get_new_name('_inl', o, existing_names) 

1305 if o in mapping: 

1306 # See below. 

1307 mapping.remove(o) 

1308 elif o in mapping: 

1309 # That means the main contains a result node but is overwritten by 

1310 # the subgraph. The local variable cannot be reached anymore, 

1311 # we remove it. 

1312 mapping.remove(o) 

1313 if o in node.input: 

1314 new_o = _get_new_name('_inl', o, existing_names) 

1315 if verbose > 3: 

1316 fLOG( 

1317 "[onnx_inline_function-renam] %s node %r(%r): %r -> %r " 

1318 "overwrite result (%r -> %r)." % ( 

1319 " " * level, node.op_type, node.name, node.input, 

1320 node.output, o, new_o)) 

1321 out.append(new_o) 

1322 mapping[o] = new_o 

1323 if o != new_o: 

1324 mapping[new_o] = new_o 

1325 mod += 1 

1326 

1327 if verbose > 3: 

1328 fLOG("[onnx_inline_function-renam] %s rep node %r(%r): %r -> %r" % ( 

1329 " " * level, node.op_type, node.name, node.input, node.output)) 

# A copy with the new names is always built, it only replaces the
# original node when mod > 0 (see below).

1330 new_node = make_node(node.op_type, inp, out, domain=node.domain, 

1331 name=_get_new_name('_inln', node.name, existing_names)) 

# Sub-graphs are visited recursively with the current mapping.

1332 for att in node.attribute: 

1333 if (att.type == AttributeProto.GRAPH and 

1334 hasattr(att, 'g') and att.g is not None): 

1335 g, m = _onnx_inline_function_graph( 

1336 att.g, protos, existing_names=existing_names, 

1337 verbose=verbose, fLOG=fLOG, mapping=mapping, 

1338 rename=rename, level=level + 1) 

1339 if len(m) > 0: 

1340 att = make_attribute(att.name, g) 

1341 mod += len(m) 

1342 else: 

1343 att = make_attribute(att.name, att.g) 

1344 new_node.attribute.append(att) 

1345 if mod > 0: 

1346 if verbose > 2: 

1347 fLOG("[onnx_inline_function-renam] %s add node %r(%r): %r -> %r" % ( 

1348 " " * level, 

1349 new_node.op_type, new_node.name, 

1350 new_node.input, new_node.output)) 

1351 nodes.append(new_node) 

1352 modified_nodes.append(node) 

1353 else: 

1354 nodes.append(node) 

1355 

# Some names changed: the graph is rebuilt and function calls are not
# inlined during this visit (onnx_inline_function iterates).

1356 if len(modified_nodes) > 0: 

1357 if verbose > 1: 

1358 fLOG("[onnx_inline_function-graph] %s -1 graph=%d " 

1359 "len(modified_nodes)=%d" % ( 

1360 " " * level, id(graph), len(modified_nodes))) 

1361 

# NOTE(review): `outputs` is never updated here, if a graph output was
# renamed in the loop above the graph output would keep a stale name --
# confirm this cannot happen.

1362 graph = make_graph( 

1363 nodes, graph.name, inputs, outputs, 

1364 init, doc_string=graph.doc_string, 

1365 sparse_initializer=list(graph.sparse_initializer)) 

# Nothing renamed and rename is False: nodes calling a function of
# *protos* are replaced by the function body.

1366 elif not rename: 

1367 # no modification, let's check the node hiding a functions 

1368 new_nodes = [] 

1369 for node in nodes: 

1370 nnodes, m = _onnx_inline_function_node( 

1371 node, protos, existing_names, verbose, fLOG, 

1372 level=level) 

1373 if len(m) > 0: 

1374 if verbose > 0: 

1375 fLOG("[onnx_inline_function-subgr] %s replaced node %r (%r) " 

1376 "with %d nodes (id=%r) -- %r -> %r" % ( 

1377 " " * level, 

1378 node.name, node.op_type, len(nnodes), id(node), 

1379 node.input, node.output)) 

1380 new_nodes.extend(nnodes) 

1381 modified_nodes.extend(m) 

1382 else: 

1383 new_nodes.append(node) 

1384 if len(modified_nodes) > 0: 

1385 if verbose > 1: 

1386 fLOG("[onnx_inline_function-graph] %s -2 graph=%d " 

1387 "len(modified_nodes)=%d" % ( 

1388 " " * level, id(graph), len(modified_nodes))) 

1389 

1390 nodes = new_nodes 

1391 graph = make_graph( 

1392 nodes, graph.name, inputs, outputs, 

1393 init, doc_string=graph.doc_string, 

1394 sparse_initializer=list(graph.sparse_initializer)) 

1395 

1396 if verbose > 1: 

1397 fLOG("[onnx_inline_function-graph] %s <visit graph=%d end " 

1398 "changed=%r len(modified_nodes)=%d" % ( 

1399 " " * level, id(graph0), id(graph0) != id(graph), 

1400 len(modified_nodes))) 

1401 

1402 return graph, modified_nodes 

1403 

1404 

def _onnx_inline_function_node(node, protos, existing_names, verbose,
                               fLOG, level):
    """
    Replaces one node by the body of the function it calls, if any.

    The inputs and outputs of the node keep their names. The function
    inputs are bound to the node inputs with ``Identity`` nodes. Every
    result of the function body receives a new name (see
    :func:`_get_new_name`). The function outputs are copied into the
    node outputs with ``Identity`` nodes. Sub-graphs of the body are
    renamed with :func:`_onnx_inline_function_graph` (``rename=True``);
    function calls they contain are not inlined by this call.

    Empty names (omitted optional inputs or outputs in ONNX) stay empty:

    * an omitted node input (empty name or missing trailing input)
      becomes an omitted input for every node of the body consuming it;
    * an omitted node output does not produce any ``Identity`` node;
    * an omitted output inside the body stays omitted.

    :param node: node to inline
    :param protos: dictionary `{(domain, op_type): FunctionProto}`
    :param existing_names: set of names already taken, every new name
        is added to it
    :param verbose: verbosity
    :param fLOG: logging function
    :param level: sub-graph level, used to indent logs
    :return: tuple (new nodes, modified nodes), ``([node], [])`` if the
        node does not call any function of *protos*
    """
    key = node.domain, node.op_type
    if key not in protos:
        return [node], []
    proto = protos[key]
    if not isinstance(proto, FunctionProto):
        raise TypeError(  # pragma: no cover
            "Prototype for key=%r must be a Function Proto, not %r." % (
                key, type(proto)))

    modified_nodes = [node]
    new_nodes = []
    mapping = _inline_mapping(verbose, fLOG, level)
    prefix = "_inl"

    # Binds the function inputs to the node inputs.
    node_inputs = list(node.input)
    for index, to in enumerate(proto.input):
        fr = node_inputs[index] if index < len(node_inputs) else ''
        if not fr:
            # Omitted optional input: the body sees an omitted input too.
            mapping[to] = ''
            continue
        n = make_node('Identity', [fr],
                      [_get_new_name(prefix, to, existing_names)])
        if verbose > 2:
            fLOG("[onnx_inline_function-ninpu] %s add node %r(%r): %r -> %r" % (
                " " * level, n.op_type, n.name, n.input, n.output))
        mapping[to] = n.output[0]
        if to != n.output[0]:
            mapping[n.output[0]] = n.output[0]
        new_nodes.append(n)

    # Copies the body with new result names.
    for nn in proto.node:
        new_input = [mapping[i] if i else '' for i in nn.input]
        # Empty outputs are not registered: several nodes may have one
        # and the mapping refuses duplicated keys.
        new_output = [_get_new_name(prefix, o, existing_names) if o else ''
                      for o in nn.output]
        mapping.update(
            {o: oo for o, oo in zip(nn.output, new_output) if o})
        mapping.update({oo: oo for oo in new_output if oo})
        new_node = make_node(
            nn.op_type, new_input, new_output,
            domain=nn.domain, name=_get_new_name(
                prefix, nn.name, existing_names))
        if verbose > 3:
            fLOG("[onnx_inline_function-nnode] %s rep node %r(%r): %r -> %r" % (
                " " * level, nn.op_type, nn.name, nn.input, nn.output))
        if verbose > 2:
            fLOG("[onnx_inline_function-nnode] %s add node %r(%r): %r -> %r" % (
                " " * level,
                new_node.op_type, new_node.name,
                new_node.input, new_node.output))
        for att in nn.attribute:
            if (att.type == AttributeProto.GRAPH and
                    hasattr(att, 'g') and att.g is not None):
                if verbose > 1:
                    fLOG("[onnx_inline_function-funct] %s fct=%r graph=%d node=%d" % (
                        " " * level, key, id(att.g), id(new_node)))

                g, m = _onnx_inline_function_graph(
                    att.g, protos, existing_names=existing_names,
                    verbose=verbose, fLOG=fLOG, mapping=mapping,
                    rename=True, level=level + 1)
                if len(m) > 0:
                    att = make_attribute(att.name, g)
                else:
                    att = make_attribute(att.name, att.g)
            new_node.attribute.append(att)
        new_nodes.append(new_node)

    # Copies the function outputs into the node outputs.
    for fr, to in zip(proto.output, node.output):
        if not to:
            # Omitted optional output, nothing to produce.
            continue
        n = make_node('Identity', [mapping[fr]], [to])
        if verbose > 2:
            fLOG("[onnx_inline_function-noutt] %s add node %r(%r): %r -> %r" % (
                " " * level, n.op_type, n.name, n.input, n.output))
        new_nodes.append(n)
    return new_nodes, modified_nodes

1481 

1482 

def onnx_inline_function(obj, protos=None, existing_names=None, verbose=0, fLOG=None):
    """
    Inlines functions in an ONNX graph.

    :param obj: onnx graph, :epkg:`FunctionProto`, :epkg:`GraphProto`,
        :epkg:`ModelProto`
    :param protos: if None, the function assumes *obj* is of type
        :epkg:`ModelProto` and the goal is to inline every function.
        If *protos* is a list of strings and *obj* a :epkg:`ModelProto`,
        the function only inlines the local functions whose name is
        in that list (a :epkg:`FunctionProto` in the list stands for its
        name). For other objects, a list of :epkg:`FunctionProto` is
        converted into a dictionary. If *protos* is a dictionary
        `{ (domain, type): FunctionProto }`, the function replaces every
        node `(domain, type)` by the code given in this dictionary
    :param existing_names: no new name will be taken in that set
    :param verbose: verbosity
    :param fLOG: logging function
    :return: modified object, list of modified nodes

    .. versionadded:: 0.9
    """
    if verbose > 0 and fLOG is None:
        fLOG = print  # pragma: no cover
    if isinstance(obj, ModelProto):
        if verbose > 0:
            fLOG("[onnx_inline_function] type=%r graph=%d" % (
                type(obj), id(obj)))
        if protos is None:
            fct = [f.name for f in obj.functions]
            ex_names = set(enumerate_onnx_names(obj))
            if existing_names is not None:
                ex_names |= existing_names
            return onnx_inline_function(obj, fct, existing_names=ex_names,
                                        verbose=verbose, fLOG=fLOG)
        if isinstance(protos, list):
            ex_names = set(enumerate_onnx_names(obj))
            if existing_names is not None:
                ex_names |= existing_names
            # Only the listed local functions are inlined.
            selected = set(p.name if isinstance(p, FunctionProto) else p
                           for p in protos)
            protos = {(f.domain, f.name): f for f in obj.functions
                      if f.name in selected}
            return onnx_inline_function(obj, protos, existing_names=ex_names,
                                        verbose=verbose, fLOG=fLOG)
    if isinstance(protos, list):
        protos = {(f.domain, f.name): f for f in protos}
    if not isinstance(protos, dict):
        raise TypeError(  # pragma: no cover
            "obj is of type %r and protos must be a dictionary not %r." % (
                type(obj), type(protos)))

    if isinstance(obj, ModelProto):
        # The names reserved by the caller must survive the recursive
        # call on the main graph, they are merged with the graph names.
        graph_names = set(enumerate_onnx_names(obj.graph))
        if existing_names is not None:
            graph_names |= existing_names
        new_graph, m = onnx_inline_function(
            obj.graph, protos, existing_names=graph_names,
            verbose=verbose, fLOG=fLOG)
        if len(new_graph.initializer) != len(obj.graph.initializer):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of initializers %d != %d." % (
                    len(new_graph.initializer), len(obj.graph.initializer)))
        if len(new_graph.sparse_initializer) != len(obj.graph.sparse_initializer):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of sparse initializers %d != %d." % (
                    len(new_graph.sparse_initializer),
                    len(obj.graph.sparse_initializer)))
        new_functions = []
        distri = Counter(
            (n.domain, n.op_type)
            for n in enumerate_onnx_nodes(new_graph))
        opsets = {op.domain: op.version for op in obj.opset_import}
        for f in obj.functions:
            key = f.domain, f.name
            if key not in protos:
                new_functions.append(f)
            elif key in distri:
                raise RuntimeError(  # pragma: no cover
                    "Function %r still appears in the graph, "
                    "distibution=%s." % (key, pprint.pformat(distri)))
            if f.domain not in opsets:
                opsets[f.domain] = 1
        return (
            make_model(
                new_graph,
                functions=new_functions,
                opset_imports=[
                    make_operatorsetid(k, v)
                    for k, v in opsets.items()],
                producer_name=obj.producer_name,
                producer_version=obj.producer_version,
                ir_version=obj.ir_version,
                doc_string=obj.doc_string,
                domain=obj.domain,
                model_version=obj.model_version),
            m)

    # FunctionProto, GraphProto
    if existing_names is None:
        existing_names = set(enumerate_onnx_names(obj))

    if verbose > 0:
        fLOG("[onnx_inline_function] type=%r graph=%d begin" % (
            type(obj), id(obj)))
    # Only used to log the changes at the end.
    distri = (Counter((n.domain, n.op_type)
                      for n in enumerate_onnx_nodes(obj))
              if verbose > 0 else None)

    new_nodes = list(obj.node)
    modified_nodes = []
    n_iter = 0
    # Every iteration inlines one level of function calls or of
    # sub-graphs. A function may call another function, so the bound
    # must also include the depth of every function body, otherwise
    # deeply nested calls are left in the graph. The loop still stops
    # as soon as an iteration modifies nothing.
    max_iter = onnx_subgraphs_level(obj) + 1 + sum(
        onnx_subgraphs_level(p) for p in protos.values()
        if isinstance(p, FunctionProto))
    modified = 1
    while modified > 0 and n_iter < max_iter:
        if verbose > 0:
            fLOG(f"[onnx_inline_function] start iteration {n_iter!r}")

        # local context
        mapping = _inline_mapping(verbose, fLOG, level=0)
        if isinstance(obj, GraphProto):
            mapping.update({i.name: i.name for i in obj.initializer})
            mapping.update({i.name: i.name for i in obj.sparse_initializer})
            for i in obj.input:
                if i.name not in mapping:
                    mapping[i.name] = i.name
        elif isinstance(obj, FunctionProto):
            mapping.update({i: i for i in obj.input})
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type for obj: {type(obj)!r}.")

        # loop on nodes
        old_nodes = new_nodes
        modified = 0
        new_nodes = []
        for node in old_nodes:
            nnodes, m = _onnx_inline_function_node(
                node, protos, existing_names, verbose, fLOG, level=0)
            # Empty names are omitted optional outputs, several nodes may
            # have one and the mapping refuses duplicated keys.
            mapping.update({o: o for o in node.output if o})

            if len(m) > 0:
                if verbose > 0:
                    fLOG("[onnx_inline_function] replaced node %r (%r) "
                         "with %d nodes (id=%r) -- %r -> %r (iter=%r)" % (
                             node.name, node.op_type, len(nnodes), id(node),
                             node.input, node.output, n_iter))
                modified += len(m)
                new_nodes.extend(nnodes)
                modified_nodes.extend(m)
            else:
                has_graph = False
                new_attributes = []
                for att in node.attribute:
                    if (att.type == AttributeProto.GRAPH and
                            hasattr(att, 'g') and att.g is not None):
                        g, m = _onnx_inline_function_graph(
                            att.g, protos, verbose=verbose, fLOG=fLOG,
                            existing_names=existing_names, mapping=mapping,
                            rename=False, level=1)
                        if len(m) > 0:
                            modified_nodes.extend(m)
                            modified_nodes.append(node)
                            modified += 1 + len(m)
                            has_graph = True
                            att = make_attribute(att.name, g)
                    new_attributes.append(att)
                if has_graph:
                    new_node = make_node(
                        node.op_type, node.input, node.output,
                        domain=node.domain, name=node.name)
                    new_node.attribute.extend(new_attributes)
                    new_nodes.append(new_node)
                else:
                    # Nothing to change in this node, its sub-graphs are
                    # visited again at the next iteration.
                    new_nodes.append(node)

        n_iter += 1
        if verbose > 0:
            total_node = len(list(enumerate_onnx_nodes(new_nodes)))
            fLOG("[onnx_inline_function] n_iter=%r/%r nodes=%r modified=%r "
                 "n_nodes=%d total=%d" % (
                     n_iter, max_iter, len(obj.node), modified,
                     len(new_nodes), total_node))

    if verbose > 0:
        fLOG("[onnx_inline_function] type=%r graph=%d end with %d "
             "modified nodes" % (
                 type(obj), id(obj), len(modified_nodes)))
        distri2 = Counter((n.domain, n.op_type)
                          for n in enumerate_onnx_nodes(new_nodes))
        if distri != distri2:
            fLOG("[onnx_inline_function] BEFORE")
            for k, v in sorted(distri.items()):
                fLOG("[onnx_inline_function] %d -- %s" % (v, k))
            fLOG("[onnx_inline_function] AFTER")
            for k, v in sorted(distri2.items()):
                fLOG("[onnx_inline_function] %d -- %s" % (v, k))

    if isinstance(obj, FunctionProto):
        return (
            make_function(
                domain=obj.domain, fname=obj.name,
                inputs=obj.input, outputs=obj.output, nodes=new_nodes,
                opset_imports=[
                    make_operatorsetid(op.domain, op.version)
                    for op in obj.opset_import],
                doc_string=obj.doc_string,
                attributes=obj.attribute),
            modified_nodes)
    if isinstance(obj, GraphProto):
        return (
            make_graph(new_nodes, obj.name, list(obj.input), list(obj.output),
                       list(obj.initializer), doc_string=obj.doc_string,
                       sparse_initializer=list(obj.sparse_initializer)),
            modified_nodes)
    raise TypeError(  # pragma: no cover
        f"Unexpected type for obj {type(obj)!r}.")