Coverage for mlprodict/onnx_tools/onnx_manipulations.py: 94%
950 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3# pylint: disable=E1101, C0302
5"""
6@file
7@brief Implements a class able to compute the predictions
8from on an :epkg:`ONNX` model.
9"""
10import hashlib
11from collections import Counter
12import pprint
13from onnx import (
14 shape_inference, ModelProto, FunctionProto, GraphProto,
15 AttributeProto)
16from onnx.helper import (
17 make_tensor_value_info, ValueInfoProto, set_model_props,
18 make_graph, make_function, make_model, make_node,
19 make_operatorsetid, make_attribute, make_value_info)
20from .onnx2py_helper import (
21 guess_proto_dtype, from_array, get_tensor_shape,
22 get_tensor_elem_type)
23from .optim import onnx_remove_node_unused
24from .onnx_tools import enumerate_onnx_names, enumerate_onnx_nodes
25from ..onnx_tools.onnx2py_helper import _var_as_dict, from_array
def enumerate_model_node_outputs(model, add_node=False, order=False):
    """
    Enumerates all the outputs of the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order
    :return: enumerator
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            f"Parameter model is not an ONNX model but {type(model)}")

    if not order:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
        return

    # Ranks are keyed by (0, result name) for results and by
    # (1, node position) for nodes: node names may be empty or duplicated
    # and cannot identify a node.
    nodes = list(model.graph.node)
    rank = {}
    edges = []
    producer = {}
    for inp in model.graph.input:
        rank[0, inp.name] = 0
    for index, node in enumerate(nodes):
        rank[1, index] = 0
        for name in node.input:
            edges.append(('in', name, index))
        for name in node.output:
            edges.append(('out', name, index))
            producer[name] = node
            rank[0, name] = 0

    # Propagates ranks along the edges until nothing changes
    # (bounded by the number of nodes).
    modif = 1
    n_iter = 0
    while modif > 0 and n_iter <= len(nodes):
        modif = 0
        n_iter += 1
        for kind, data_name, index in edges:
            if kind == 'in':
                if (0, data_name) not in rank:
                    # initializer or unknown result: no constraint
                    continue
                if rank[0, data_name] + 1 > rank[1, index]:
                    modif += 1
                    rank[1, index] = rank[0, data_name] + 1
            elif rank[1, index] + 1 > rank[0, data_name]:
                modif += 1
                rank[0, data_name] = rank[1, index] + 1

    # Keys (0, str) and (1, int) never compare str with int because
    # their first element differs.
    for _, key in sorted((v, k) for k, v in rank.items()):
        if key[0] == 1 or key[1] not in producer:
            continue
        out = key[1]
        yield (out, producer[out]) if add_node else out
def get_opsets(model, include_functions=True, exc=True):
    """
    Returns the opsets imported by a model or a function.

    :param model: :epkg:`ModelProto` or :epkg:`FunctionProto`
    :param include_functions: for a ModelProto, also merges the opsets
        imported by its local functions
    :param exc: raise an exception if a domain is imported twice or if a
        function imports a domain with a version different from the main
        graph, otherwise the highest version is kept
    :return: dictionary `{domain: version}`
    """
    is_model = isinstance(model, ModelProto)

    opsets = {}
    for imp in model.opset_import:
        if exc and imp.domain in opsets:
            raise ValueError(  # pragma: no cover
                f"Domain {imp.domain!r} appears multiple times.")
        opsets[imp.domain] = imp.version

    if not (is_model and include_functions):
        return opsets

    # Merges the opsets of every local function into the main ones.
    for fct in model.functions:
        for dom, ver in get_opsets(fct, exc=exc).items():
            if dom not in opsets:
                opsets[dom] = ver
                continue
            if opsets[dom] != ver and exc:
                raise ValueError(  # pragma: no cover
                    "Domain %r has different version in "
                    "main graph (%d) and function %r "
                    "(%d)." % (dom, opsets[dom], fct.name, ver))
            opsets[dom] = max(opsets[dom], ver)
    return opsets
def get_hidden_inputs(nodes):
    """
    Returns the names of the results a list of nodes consumes without
    producing them, including the results of the outer scope referenced
    by the subgraphs stored in their attributes.

    Names defined locally by a subgraph (its inputs and initializers)
    are not hidden inputs and are excluded.

    :param nodes: list of nodes
    :return: set of names
    """
    inputs = set()
    outputs = set()
    for node in nodes:
        inputs |= set(node.input)
        outputs |= set(node.output)
        for att in node.attribute:
            if (att.type != AttributeProto.GRAPH or  # pylint: disable=E1101
                    not hasattr(att, 'g') or att.g is None):
                continue
            sub = att.g
            hidden = get_hidden_inputs(sub.node)
            # Protobuf messages are not hashable, names are compared.
            local = {i.name for i in sub.input}
            local |= {i.name for i in sub.initializer}
            local |= {i.values.name
                      for i in getattr(sub, 'sparse_initializer', [])}
            inputs |= hidden - local
    return inputs - outputs
def _select_value_info(name, overwrite, known_shapes):
    """
    Builds the ValueInfoProto of a new input or output for
    :func:`select_model_inputs_outputs`.

    :param name: result name
    :param overwrite: None or dictionary `{'name': (numpy dtype, shape)}`,
        it takes precedence over the known shapes
    :param known_shapes: dictionary `{name: TypeProto}`
    :return: ValueInfoProto, only the name is set if the type is unknown
    """
    if overwrite is not None and name in overwrite:
        dtype, shape = overwrite[name]
        proto_dtype = guess_proto_dtype(dtype)
        return make_tensor_value_info(name, proto_dtype, shape)
    if name in known_shapes:
        proto_dtype = known_shapes[name].tensor_type.elem_type
        if proto_dtype != 0:
            shape = get_tensor_shape(known_shapes[name])
            return make_tensor_value_info(name, proto_dtype, shape)
    # Unknown or non tensor type: only the name is kept.
    value_info = ValueInfoProto()
    value_info.name = name
    return value_info


def select_model_inputs_outputs(model, outputs=None, inputs=None,
                                infer_shapes=False, overwrite=None,
                                remove_unused=True,
                                verbose=0, fLOG=None):
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
    :param fLOG: logging function
    :return: modified model

    The function removes unneeded nodes.

    .. exref::
        :title: Change ONNX model inputs

        The following example shows how to change the inputs of a model
        to bypass the first nodes. Shape inference fails to determine
        the new inputs type. They need to be overwritten.
        `verbose=1, fLOG=print` shows the number of deleted nodes.

        ::

            import numpy
            import onnx
            from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs

            onx = onnx.load(path)
            onx2 = select_model_inputs_outputs(
                onx, inputs=["SentenceTokenizer/SentencepieceTokenizeOp:0",
                             "SentenceTokenizer/SentencepieceTokenizeOp:1"],
                infer_shapes=True, verbose=1, fLOG=print,
                overwrite={'SentenceTokenizer/SentencepieceTokenizeOp:0': (numpy.int32, None),
                           'SentenceTokenizer/SentencepieceTokenizeOp:1': (numpy.int64, None)})
            onnx.save(onx2, path2)
    """
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]
    input_names = set(inputs)

    # mark_var[name] == 1 means the result is needed.
    mark_var = {}
    for out in enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(  # pragma: no cover
                f"Output '{out}' not found in model.")
        mark_var[out] = 1

    nodes = list(model.graph.node[::-1])
    mark_op = {id(node): 0 for node in nodes}

    # We mark all the nodes we need to keep, walking backward from the
    # requested outputs until nothing changes anymore.
    nb = 1
    while nb > 0:
        nb = 0
        for node in nodes:
            if mark_op[id(node)] == 1:
                continue
            if not any(mark_var[out] == 1 for out in node.output):
                continue
            mark_op[id(node)] = 1
            nb += 1

            # Results used by subgraphs are needed as well.
            hidden = get_hidden_inputs([node])
            for inp in list(node.input) + list(hidden):
                if inp in input_names or mark_var.get(inp, 0) == 1:
                    continue
                mark_var[inp] = 1
                nb += 1

    # Every kept node verifies mark_op[id(node)] == 1.
    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]

    if verbose > 1 and fLOG is not None:  # pragma: no cover
        for node in nodes:
            s = "+" if mark_op[id(node)] == 1 else "-"
            fLOG("[select_model_inputs_outputs] %s %s (%s) -> %s [%s]" % (
                s, node.op_type, ", ".join(node.input),
                ', '.join(node.output), node.name))

    known_shapes = {}
    if infer_shapes:
        shapes = shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type

    var_in = [_select_value_info(name, overwrite, known_shapes)
              for name in inputs]
    var_out = [_select_value_info(name, overwrite, known_shapes)
               for name in outputs]

    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[select_model_inputs_outputs] nodes %r --> %r" % (
            len(model.graph.node), len(keep_nodes)))
        fLOG("[select_model_inputs_outputs] inputs: %r" %
             [_.name for _ in var_in])
        fLOG("[select_model_inputs_outputs] outputs: %r" %
             [_.name for _ in var_out])

    graph = make_graph(keep_nodes, model.graph.name, var_in,
                       var_out, model.graph.initializer,
                       sparse_initializer=model.graph.sparse_initializer)
    onnx_model = make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # remove unused nodes
    if remove_unused:
        onnx_model = onnx_remove_node_unused(onnx_model, recursive=False)

    return onnx_model
def change_input_type(onx, changes):
    """
    Changes the element type of some inputs of a model.

    Every input whose name is a key of *changes* is replaced by a tensor
    description with the new element type and no shape. Nodes, outputs,
    initializers, functions, metadata and opsets are carried over.

    :param onx: ONNX model
    :param changes: dictionary `{ name: new proto element type }`
    :return: new onx
    """
    source = onx.graph
    retyped = [
        make_tensor_value_info(inp.name, changes[inp.name], None)
        if inp.name in changes else inp
        for inp in source.input]

    graph = make_graph(list(source.node), source.name, retyped,
                       list(source.output), source.initializer,
                       sparse_initializer=source.sparse_initializer)
    new_model = make_model(graph, functions=onx.functions)
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(new_model, field, getattr(onx, field))
    if len(onx.metadata_props) > 0:  # pragma: no cover
        set_model_props(
            new_model, {p.key: p.value for p in onx.metadata_props})

    # The opsets of the original model replace the default ones.
    del new_model.opset_import[:]  # pylint: disable=E1101
    for imp in onx.opset_import:
        added = new_model.opset_import.add()  # pylint: disable=E1101
        added.domain = imp.domain
        added.version = imp.version
    return new_model
395def _change_subgraph_io_type_shape_list(io_list, type_changes, shape_changes):
396 ms = False
397 new_inputs = []
398 for inp in io_list:
399 m = False
400 if isinstance(shape_changes, dict):
401 if inp.name in shape_changes:
402 shape = shape_changes[inp.name]
403 m = True
404 else:
405 shape = get_tensor_shape(inp)
406 else:
407 shape = shape_changes(inp)
408 m = True
410 if isinstance(type_changes, dict):
411 if inp.name in type_changes:
412 ntype = type_changes[inp.name]
413 m = True
414 else:
415 ntype = get_tensor_elem_type(inp)
416 else:
417 ntype = type_changes(inp)
418 m = True
420 if m:
421 ms = True
422 value_info = make_tensor_value_info(inp.name, ntype, shape)
423 new_inputs.append(value_info)
424 else:
425 new_inputs.append(inp)
426 return new_inputs if ms else None
def change_subgraph_io_type_shape(onx, type_changes=None, shape_changes=None,
                                  recursive=True):
    """
    Changes the type or the shape of inputs and outputs of a graph and,
    if *recursive* is True, of every subgraph held by a node attribute.

    :param onx: ModelProto, GraphProto
    :param type_changes: dictionary `{ name: new proto element type }`
        or function `f(input) -> type`
    :param shape_changes: dictionary `{ name: new shape }`
        or function `f(input) -> shape`
    :param recursive: also changes subgraphs
    :return: new onx
    """
    if isinstance(onx, ModelProto):
        graph = change_subgraph_io_type_shape(
            onx.graph, type_changes, shape_changes, recursive)
        onnx_model = make_model(graph, functions=onx.functions)
        onnx_model.ir_version = onx.ir_version
        onnx_model.producer_name = onx.producer_name
        onnx_model.producer_version = onx.producer_version
        onnx_model.domain = onx.domain
        onnx_model.model_version = onx.model_version
        onnx_model.doc_string = onx.doc_string
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(onnx_model, values)

        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return onnx_model

    graph = onx
    new_inputs = _change_subgraph_io_type_shape_list(
        graph.input, type_changes or {}, shape_changes or {})
    if new_inputs is None:
        new_inputs = graph.input

    new_outputs = _change_subgraph_io_type_shape_list(
        graph.output, type_changes or {}, shape_changes or {})
    if new_outputs is None:
        new_outputs = graph.output

    # recursive
    if recursive:
        new_nodes = []
        for node in list(graph.node):
            modified = False
            atts = []
            for att in node.attribute:
                if (att.type == AttributeProto.GRAPH and
                        hasattr(att, 'g') and att.g is not None):
                    modified = True
                    g = change_subgraph_io_type_shape(
                        att.g, type_changes, shape_changes,
                        recursive=recursive)
                    att = make_attribute(att.name, g)
                atts.append(att)
            if modified:
                # The rebuilt node keeps the name, domain and doc_string
                # of the original one, only its attributes change.
                node = make_node(node.op_type, node.input, node.output,
                                 name=node.name, domain=node.domain,
                                 doc_string=node.doc_string)
                node.attribute.extend(atts)
            new_nodes.append(node)
    else:
        new_nodes = list(graph.node)

    # final
    graph = make_graph(new_nodes, graph.name, new_inputs, new_outputs,
                       graph.initializer,
                       sparse_initializer=graph.sparse_initializer)
    return graph
def overwrite_opset(model, new_opset):
    """
    Overwrites the main opset in an ONNX file.
    Does not change any node definition.

    The main opset is the one imported for the default domain, declared
    either as `''` or as its alias `'ai.onnx'`. Every other domain keeps
    its version. The graph `value_info` and `doc_string` are kept.

    :param model: ONNX model
    :param new_opset: new opset
    :return: ONNX model
    """
    graph = make_graph(
        model.graph.node, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer,
        doc_string=model.graph.doc_string,
        value_info=model.graph.value_info,
        sparse_initializer=model.graph.sparse_initializer)
    onnx_model = make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = (
            new_opset if oimp.domain in ('', 'ai.onnx') else oimp.version)
    return onnx_model
def hash_onnx_object(obj, max_size):
    """
    Computes a short uppercase hexadecimal SHA-256 digest of an ONNX object
    with module :mod:`hashlib`.

    For a node (any object with an `op_type`), the digest covers the
    operator type, the number of inputs, the number of outputs and every
    attribute, but not the input or output names. For any other object
    (typically an initializer), the serialized object is hashed with its
    `name` and `doc_string` temporarily blanked and restored afterwards.

    :param obj: node or object with a method `SerializeToString`
    :param max_size: maximum number of characters kept from the digest
    :return: uppercase digest
    :raises RuntimeError: if the object cannot be serialized
    """
    digest = hashlib.sha256()
    if not hasattr(obj, 'op_type'):
        # An initializer: names must not change the hash.
        saved_name = obj.name
        saved_doc = obj.doc_string
        obj.name = ''
        obj.doc_string = ''
        try:
            digest.update(obj.SerializeToString())
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError(
                f"Unable to hash object type {type(obj)!r}, value={obj!r}.") from e
        finally:
            obj.name = saved_name
            obj.doc_string = saved_doc
    else:
        # An operator: input and output names are deliberately ignored.
        digest.update(obj.op_type.encode('ascii'))
        digest.update(str(len(obj.input)).encode('ascii'))
        digest.update(str(len(obj.output)).encode('ascii'))
        for att in getattr(obj, 'attribute', []):
            digest.update(att.name.encode('ascii'))
            digest.update(att.SerializeToString())

    text = digest.hexdigest()
    if len(text) > max_size:
        text = text[:max_size]
    return text.upper()
def onnx_rename_names(model, strategy='simple', recursive=True,
                      verbose=0, fLOG=print,
                      counts=None, replace=None, taken=None):
    """
    Renames all names except the inputs and outputs.

    :param model: onnx model (or graph)
    :param strategy: two strategies are implemented, see below
    :param recursive: walk through subgraphs
    :param verbose: verbose, if positive, reports on all changed names
    :param fLOG: logging function
    :param counts: used for recursion
    :param replace: used for recursion, it can be also used to
        to fix some replacements
    :param taken: used for recursion
    :return: onnx model (the model is modified in place)

    Strategies:

    * `'simple'`: use a letter `n` for node, `r`, `i` for initializer,
      this letter is followed by a number
    * `'type'`: the name depends on the node type and content,
      the hash is kept as small as possible

    Empty names (omitted optional inputs or outputs) are left untouched.
    A sparse initializer is renamed through `values.name`, the field
    holding its name.
    """
    # `is None` tests (not `or`) so that empty containers given by a
    # caller are still shared through the recursion.
    if counts is None:
        counts = {'init': 0, 'node': 0, 'result': 0}
    if replace is None:
        replace = {}
    if taken is None:
        taken = set()
    graph = model.graph if hasattr(model, 'graph') else model

    for obj in graph.input:
        replace[obj.name] = obj.name
    for obj in graph.output:
        replace[obj.name] = obj.name

    def _check_name_simple(prefix):
        if prefix not in replace and prefix not in taken:
            taken.add(prefix)
            return prefix
        c = 1
        final = "%s_%d" % (prefix, c)
        while final in taken or final in replace:
            c += 1
            final = "%s_%d" % (prefix, c)
        taken.add(final)
        return final

    def _check_name_type(obj, prefix):
        # The shortest unused prefix of the digest (by steps of two
        # characters) is used.
        digest = hash_onnx_object(obj, 64)
        for c in range(2, len(digest) + 1, 2):
            final = f"{prefix}_{digest[:c]}"
            if final not in taken:
                taken.add(final)
                return final
        # Identical objects share the same digest, a counter disambiguates
        # them instead of looping forever.
        c = 1
        final = f"{prefix}_{digest}_{c}"
        while final in taken:
            c += 1
            final = f"{prefix}_{digest}_{c}"
        taken.add(final)
        return final

    def get_name_init(init):
        if init.name in replace:
            return replace[init.name]
        if strategy == 'simple':
            name = _check_name_simple('i%d' % counts['init'])
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] init: {init.name!r} -> {name!r}')
            return name
        if strategy == 'type':
            name = _check_name_type(init, 'i')
            counts['init'] += 1
            replace[init.name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] init: {init.name!r} -> {name!r}')
            return name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_node(node):
        node_name = 'node_%s_%d' % (node.name, id(node))
        if node_name in replace:
            return replace[node_name]
        if strategy == 'simple':
            name = _check_name_simple('n%d' % counts['node'])
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] node: {node_name!r} -> {name!r}')
            return name
        if strategy == 'type':
            name = _check_name_type(node, 'n')
            counts['node'] += 1
            replace[node_name] = name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] node: {node_name!r} -> {name!r}')
            return name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_result(node, i, name, suffix):
        if not name:
            # An empty name stands for an omitted optional input or output.
            return name
        if name in replace:
            return replace[name]
        if strategy == 'simple':
            new_name = _check_name_simple('r%d' % counts['result'])
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] result: {name!r} -> {new_name!r}')
            return new_name
        if strategy == 'type':
            new_name = _check_name_type(node, 'r%s%d' % (suffix, i))
            counts['result'] += 1
            replace[name] = new_name
            if verbose > 0 and fLOG is not None:
                fLOG(f'[onnx_rename_names] result: {name!r} -> {new_name!r}')
            return new_name
        raise ValueError(  # pragma: no cover
            f"Unknown strategy {strategy!r}.")

    def get_name_input(node, i):
        return get_name_result(node, i, node.input[i], 'i')

    def get_name_output(node, i):
        return get_name_result(node, i, node.output[i], 'o')

    for init in graph.initializer:
        init.name = get_name_init(init)
    for init in graph.sparse_initializer:
        # SparseTensorProto has no name field, its name is values.name.
        init.values.name = get_name_init(init.values)

    for node in list(graph.node):
        node.name = get_name_node(node)
        for i in range(len(node.input)):  # pylint: disable=C0200
            node.input[i] = get_name_input(node, i)
        for i in range(len(node.output)):  # pylint: disable=C0200
            node.output[i] = get_name_output(node, i)
        if not recursive or node.op_type not in {'Scan', 'If', 'Loop'}:
            continue
        # recursion
        for att in node.attribute:
            if att.name not in {'if_branch', 'else_branch', 'body'}:
                continue
            onnx_rename_names(
                att.g, strategy=strategy, recursive=recursive,
                fLOG=fLOG, verbose=verbose,
                counts=counts, replace=replace, taken=taken)

    return model
def onnx_rename_inputs_outputs(onx, rename):
    """
    Renames inputs, outputs or any result of a graph.

    Every occurrence of a name listed in *rename* is replaced: graph inputs
    and outputs, initializers, node inputs and outputs. The same renaming
    is applied recursively to every subgraph attribute. The initializers
    of *onx* are copied before being renamed, *onx* is left untouched.

    :param onx: GraphProto, ModelProto
    :param rename: dictionary `{old_name: new_name}`
    :return: new onx
    """
    if isinstance(onx, ModelProto):
        graph = onnx_rename_inputs_outputs(onx.graph, rename)
        onnx_model = make_model(graph, functions=onx.functions)
        onnx_model.ir_version = onx.ir_version
        onnx_model.producer_name = onx.producer_name
        onnx_model.producer_version = onx.producer_version
        onnx_model.domain = onx.domain
        onnx_model.model_version = onx.model_version
        onnx_model.doc_string = onx.doc_string
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(onnx_model, values)

        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for oimp in onx.opset_import:
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return onnx_model

    graph = onx
    new_inputs = []
    for inp in graph.input:
        if inp.name not in rename:
            new_inputs.append(inp)
            continue
        value_info = make_tensor_value_info(
            rename[inp.name], get_tensor_elem_type(inp), get_tensor_shape(inp))
        new_inputs.append(value_info)

    new_outputs = []
    for inp in graph.output:
        if inp.name not in rename:
            new_outputs.append(inp)
            continue
        value_info = make_tensor_value_info(
            rename[inp.name], get_tensor_elem_type(inp), get_tensor_shape(inp))
        new_outputs.append(value_info)

    new_inits = []
    for init in graph.initializer:
        if init.name in rename:
            # A copy is renamed, the input graph must not be modified.
            copy = type(init)()
            copy.CopyFrom(init)
            copy.name = rename[init.name]
            init = copy
        new_inits.append(init)

    new_sparse_inits = []
    for init in graph.sparse_initializer:
        # SparseTensorProto has no name field, its name is values.name.
        if init.values.name in rename:
            copy = type(init)()
            copy.CopyFrom(init)
            copy.values.name = rename[init.values.name]
            init = copy
        new_sparse_inits.append(init)

    new_nodes = []
    for node in list(graph.node):
        modified = False
        atts = []
        for att in node.attribute:
            if (att.type == AttributeProto.GRAPH and
                    hasattr(att, 'g') and att.g is not None):
                modified = True
                g = onnx_rename_inputs_outputs(att.g, rename)
                att = make_attribute(att.name, g)
            atts.append(att)

        inp = [rename.get(i, i) for i in node.input]
        out = [rename.get(i, i) for i in node.output]
        if (not modified and inp == list(node.input) and
                out == list(node.output)):
            new_nodes.append(node)
            continue

        # The rebuilt node keeps the name, domain and doc_string
        # of the original one.
        new_node = make_node(node.op_type, inp, out, domain=node.domain,
                             name=node.name, doc_string=node.doc_string)
        new_node.attribute.extend(atts)
        new_nodes.append(new_node)

    # final
    graph = make_graph(new_nodes, graph.name, new_inputs, new_outputs,
                       new_inits, sparse_initializer=new_sparse_inits)
    return graph
def onnx_replace_functions(model, replace):
    """
    Replaces some of the local functions of a model.

    :param model: *ModelProto*
    :param replace: dictionary `{ (domain, name): FunctionProto }`
    :return: new model, or *model* itself if no function was replaced
    :raises TypeError: if *model* or a replacement has an unexpected type
    :raises ValueError: if a replacement does not have the same number
        of inputs or outputs as the function it replaces
    """
    if not isinstance(model, ModelProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(model)!r}.")

    functions = []
    changed = False
    for old in model.functions:
        key = old.domain, old.name
        if key not in replace:
            functions.append(old)
            continue
        changed = True
        new_fct = replace[key]
        if not isinstance(new_fct, FunctionProto):
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(new_fct)!r} for function {key!r} in replace.")
        if len(new_fct.input) != len(old.input):
            raise ValueError(  # pragma: no cover
                f"Input mismatches {new_fct.input!r} != {old.input!r} (expected).")
        if len(new_fct.output) != len(old.output):
            raise ValueError(  # pragma: no cover
                f"Output mismatches {new_fct.output!r} != {old.output!r} (expected).")
        functions.append(new_fct)

    if not changed:
        return model

    opsets = [make_operatorsetid(imp.domain, imp.version)
              for imp in model.opset_import]
    new_model = make_model(
        model.graph, opset_imports=opsets, functions=functions)
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(new_model, field, getattr(model, field))
    if len(model.metadata_props) > 0:  # pragma: no cover
        set_model_props(
            new_model, {p.key: p.value for p in model.metadata_props})
    return new_model
def insert_results_into_onnx(model, results, as_parameter=True, suffix='_DBG',
                             param_name=None, node_type='DEBUG',
                             domain='DEBUG', domain_opset=1):
    """
    Inserts results into an ONNX graph to produce an extended
    ONNX graph. It can be saved and looked into with a tool such as
    :epkg:`netron`.

    :param model: ONNX graph
    :param results: results to be added in a dictionary
    :param as_parameter: add new nodes with results as one parameter
        (True) or as initializer (False)
    :param suffix: suffix to add to new results
    :param param_name: name of the parameter to add
        (by default the result name), it can be a function
        `param_name(result_name) -> parameter_name`
    :param node_type: type of the new node
    :param domain: domain the new node
    :param domain_opset: opset for *domain*, the domain is not declared
        again if the model already imports it
    :return: new ONNX graph

    See method :meth:`OnnxInference.run2onnx
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run2onnx>`
    to see a graph this function produces.

    .. image:: debug.png

    .. versionadded:: 0.7
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)
    inits_sparse = list(model.graph.sparse_initializer)
    node_list = list(model.graph.node)
    nodes = {id(n): n for n in node_list}
    order = {id(n): i for i, n in enumerate(node_list)}
    nodes_copy = {}

    # A SparseTensorProto has no name field, its name is values.name.
    names_init = ({init.name for init in inits} |
                  {init.values.name for init in inits_sparse})
    names_input = {inp.name for inp in inputs}
    names_output = {}
    for node in node_list:
        for i, o in enumerate(node.output):
            names_output[o] = (i, node)

    for k, v in results.items():
        if k in names_init:
            # initializer are not inserted again
            continue
        if k in names_input:
            # inputs cannot be replaced by the output of a debug node
            raise NotImplementedError(
                f"Unable to add debug information on input {k!r}.")

        if k not in names_output:
            raise RuntimeError(
                "Unable to find result %r in the ONNX graph. Available="
                "[%s]." % (k, ", ".join(sorted(names_output))))

        index, node = names_output[k]
        new_name = k + suffix

        # The producing node is copied once and its output renamed,
        # the debug node then produces the original name.
        if id(node) not in nodes_copy:
            new_node = make_node(
                node.op_type, list(node.input), list(node.output),
                domain=node.domain if node.domain else None,
                name=node.name + suffix)
            new_node.attribute.extend(node.attribute)  # pylint: disable=E1101
            nodes_copy[id(node)] = new_node
            order[id(new_node)] = order[id(node)]
        new_node = nodes_copy[id(node)]
        new_node.output[index] = new_name

        if as_parameter:
            pname = k if param_name is None else param_name(k)
            atts = {pname: from_array(v, name=pname)}
            inserted_node = make_node(
                node_type, [new_name], [k], domain=domain,
                **atts)
        else:
            pname = k if param_name is None else param_name(k)
            pname += suffix + 'i'
            inserted_node = make_node(
                node_type, [new_name, pname], [k], domain=domain)
            inits.append(from_array(v, name=pname))

        # Placed right after the producing node.
        order[id(inserted_node)] = order[id(node)] + 1. / (index + 2)
        nodes[id(inserted_node)] = inserted_node

    new_nodes = [(order[id(n)], n)
                 for n in nodes.values() if id(n) not in nodes_copy]
    new_nodes.extend((order[id(n)], n) for n in nodes_copy.values())
    # Sorts on the position only, nodes themselves are not comparable.
    new_nodes = [n for _, n in sorted(new_nodes, key=lambda t: t[0])]

    graph = make_graph(new_nodes, model.graph.name, inputs, outputs,
                       inits, sparse_initializer=inits_sparse)
    onnx_model = make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    # A domain must be imported only once (the model may already
    # contain debug nodes from a previous call).
    if all(oimp.domain != domain for oimp in model.opset_import):
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = domain
        op_set.version = domain_opset
    return onnx_model
def onnx_model_to_function(onx, name=None, domain="custom",
                           opset_imports=None, doc_string=None,
                           inputs2par=None):
    """
    Converts an ONNX model (or graph) into a :epkg:`FunctionProto`.
    The returned function has no attribute unless *inputs2par*
    is specified.

    :param onx: onnx model (:epkg:`ModelProto`) or graph (:epkg:`GraphProto`)
    :param name: function name, the graph name if None
    :param domain: function domain
    :param opset_imports: opset to import as a dictionary
        `{domain: version}`, any other value is given as it is
        to `make_function`; if None and *onx* is a model,
        the model opsets are used
    :param doc_string: doc string, if None and *onx* is a model,
        the model doc string is used
    :param inputs2par: dictionary to move some inputs as attributes
        `{ name: None or default value }`
    :return: function, other functions (the local functions of the model)
    :raises TypeError: if *onx* is neither a model nor a graph

    Every initializer (dense or sparse) is converted into a `Constant`
    node placed before the nodes of the graph.

    .. warning::
        :epkg:`FunctionProto` does not support default values yet.
        They are ignored. Only the input names are moved into the
        attribute list, the nodes consuming these inputs are left
        unchanged.
    """
    if isinstance(onx, ModelProto):
        # A model is reduced to its graph, opsets and documentation
        # default to the model's own values.
        if opset_imports is None:
            opset_imports = {op.domain: op.version for op in onx.opset_import}
        if doc_string is None:
            doc_string = onx.doc_string
        function_proto, local_functions = onnx_model_to_function(
            onx.graph, name=name, domain=domain,
            opset_imports=opset_imports, doc_string=doc_string,
            inputs2par=inputs2par)
        return function_proto, local_functions + list(onx.functions)

    if not isinstance(onx, GraphProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(onx)!r} for onx.")

    function_name = onx.name if name is None else name
    promoted = {} if inputs2par is None else inputs2par
    input_names = [i.name for i in onx.input if i.name not in promoted]
    attribute_names = [i.name for i in onx.input if i.name in promoted]
    output_names = [o.name for o in onx.output]

    # A function has no initializer, every initializer becomes a Constant.
    body = [
        make_node('Constant', [], [init.name],
                  value=from_array(_var_as_dict(init)['value']))
        for init in onx.initializer]
    body.extend(
        make_node('Constant', [], [init.name],
                  sparse_value=from_array(_var_as_dict(init)['sparse_value']))
        for init in onx.sparse_initializer)
    body.extend(onx.node)

    if isinstance(opset_imports, dict):
        opset_imports = [make_operatorsetid(dom, ver)
                         for dom, ver in opset_imports.items()]
    function_proto = make_function(
        domain, function_name, input_names, output_names, body,
        opset_imports=opset_imports, doc_string=doc_string or '',
        attributes=attribute_names)
    return function_proto, []
def _onnx_function_to_model_convert_io(ens, type_info, shape_fct):
    """
    Creates one tensor :epkg:`ValueInfoProto` for every name in *ens*.

    :param ens: names of the inputs or outputs
    :param type_info: dictionary or callable returning, for a name,
        either an onnx element type (an integer) or any value
        given to `guess_proto_dtype`
    :param shape_fct: callable `shape_fct(name, proto_dtype) -> shape`
    :return: list of ValueInfoProto, in the same order as *ens*
    :raises TypeError: if *type_info* is neither a dictionary
        nor a callable
    """
    def _lookup(name):
        if isinstance(type_info, dict):
            return type_info[name]
        if callable(type_info):
            return type_info(name)
        raise TypeError(  # pragma: no cover
            "type_info is not a callable or a dictionary, "
            "unable to guess type for name=%r with "
            "type(type_info)=%r." % (name, type(type_info)))

    def _to_value_info(name):
        found = _lookup(name)
        # An integer is already an onnx element type.
        dtype = found if isinstance(found, int) else guess_proto_dtype(found)
        return make_tensor_value_info(name, dtype, shape_fct(name, dtype))

    return [_to_value_info(name) for name in ens]
def onnx_function_to_model(onx, functions=None, type_info=None,
                           as_function=False, shape_fct=None):
    """
    Converts an ONNX :epkg:`FunctionProto` into a :epkg:`ModelProto`.
    Functions with attributes are not supported yet.

    :param onx: onnx function
    :param functions: additional functions to append to the model,
        a list or a dictionary (only its values are used)
    :param type_info: dictionary or callable which returns the type of
        inputs or outputs if it cannot be guessed
    :param as_function: if True, the function stays as a function and a
        single node calling it is created (with opset 1 for the
        function domain), otherwise the function nodes become the
        graph nodes
    :param shape_fct: function to specify the shapes,
        signature: `shape_fct(name, proto_type) -> list`,
        shapes are left undefined if None
    :return: model
    :raises TypeError: if *onx* is not a FunctionProto or *functions*
        has an unexpected type
    :raises NotImplementedError: if the function has attributes
    """
    if not isinstance(onx, FunctionProto):
        raise TypeError(  # pragma: no cover
            f"onx must be a FunctionProto not {type(onx)!r}.")
    if len(onx.attribute) > 0:
        raise NotImplementedError(  # pragma: no cover
            "The function has attributes, it is not implemented yet.")

    if functions is None:
        added_functions = []
    elif isinstance(functions, list):
        added_functions = functions.copy()
    elif isinstance(functions, dict):
        added_functions = list(functions.values())
    else:
        raise TypeError(  # pragma: no cover
            f"Unexpected type for functions {type(functions)!r}.")

    def _undefined_shape(name, dtype):  # pylint: disable=W0613
        "Default shape function, the shape remains unknown."
        return None

    shape_of = _undefined_shape if shape_fct is None else shape_fct
    inputs = _onnx_function_to_model_convert_io(
        onx.input, type_info, shape_fct=shape_of)
    outputs = _onnx_function_to_model_convert_io(
        onx.output, type_info, shape_fct=shape_of)

    if not as_function:
        nodes = list(onx.node)
        opsets = [make_operatorsetid(op.domain, op.version)
                  for op in onx.opset_import]
    else:
        call = make_node(onx.name,
                         [vi.name for vi in inputs],
                         [vi.name for vi in outputs],
                         domain=onx.domain)
        nodes = [call]
        added_functions.append(onx)
        opsets = [make_operatorsetid(onx.domain, 1)]

    graph = make_graph(nodes, onx.name, inputs, outputs,
                       [], doc_string=onx.doc_string)
    return make_model(graph, functions=added_functions,
                      opset_imports=opsets,
                      doc_string=onx.doc_string,
                      model_version=1,
                      domain=onx.domain)
def _get_new_name(prefix, name, existing_names):
    """
    Returns the first name `<prefix>_<name>_<i>` (i = 0, 1, 2, ...)
    not in *existing_names* and adds it to that set.

    :param prefix: prefix
    :param name: original name
    :param existing_names: set of names already taken, modified in place
    :return: new unique name
    """
    suffix = 0
    candidate = f"{prefix}_{name}_{suffix}"
    while candidate in existing_names:
        suffix += 1
        candidate = f"{prefix}_{name}_{suffix}"
    existing_names.add(candidate)
    return candidate
def onnx_subgraphs_level(obj):
    """
    Returns the depth of the graph: 1 for a graph without any subgraph,
    1 + the maximum depth of its subgraphs otherwise. Only attributes of
    type `AttributeProto.GRAPH` are considered.

    :param obj: onnx object (:epkg:`ModelProto`, or any object with
        a `node` attribute such as :epkg:`GraphProto`)
    :return: integer
    """
    if isinstance(obj, ModelProto):
        return onnx_subgraphs_level(obj.graph)
    sub_levels = [
        onnx_subgraphs_level(att.g)
        for node in obj.node
        for att in node.attribute
        if (att.type == AttributeProto.GRAPH and
            hasattr(att, 'g') and att.g is not None)]
    return max(sub_levels, default=0) + 1
class _inline_mapping(dict):
    """
    Dictionary which refuses to overwrite an existing key and can log
    every insertion, to debug the inlining more easily.

    :param verbose: verbosity, every insertion is logged if > 3
    :param fLOG: logging function
    :param level: sub graph level, used to indent the log messages
    """

    def __init__(self, verbose, fLOG, level):
        super().__init__()
        self._verbose = verbose
        self._fLOG = fLOG
        self._level = level

    def __setitem__(self, key, value):
        "Adds a value, raises RuntimeError if the key already exists."
        if self._verbose > 3:
            indent = " " * self._level
            self._fLOG("[_inline_mapping-dict-addkv] %s + %r: %r" %
                       (indent, key, value))
        if key in self:
            raise RuntimeError(  # pragma: no cover
                "Key %r was already added (with value %r, new one is %r)."
                "" % (key, self[key], value))
        super().__setitem__(key, value)

    def update(self, d):
        "Updates many values, every one of them goes through __setitem__."
        for key, value in d.items():
            self[key] = value

    def copy(self):
        "Returns a copy with the same settings (verbosity, logging, level)."
        clone = _inline_mapping(self._verbose, self._fLOG, self._level)
        clone.update(self)
        return clone

    def remove(self, o):
        "Removes one element, raises KeyError if it is missing."
        if o not in self:
            raise KeyError(  # pragma: no cover
                f"Cannot remove a key {o!r}.")
        del self[o]
def _onnx_inline_function_graph(graph, protos, existing_names, mapping,
                                verbose, fLOG, rename, level):
    """
    Processes a subgraph (the value of a GRAPH attribute) while
    functions are being inlined.

    A first step rewrites every node: inputs are renamed through
    *mapping*, outputs receive a new name when *rename* is True (except
    the graph outputs) or when they overwrite, as a node input, a name
    already known in *mapping*. Nested subgraphs are processed
    recursively with the same *rename* flag. If no node was modified and
    *rename* is False, a second step replaces every node calling a
    function of *protos* by the function body
    (`_onnx_inline_function_node`).

    In this module, *rename* is True when the subgraph belongs to the
    body of an inlined function (call from `_onnx_inline_function_node`)
    and False when it is called from `onnx_inline_function`.

    :param graph: GraphProto to process
    :param protos: dictionary `{ (domain, type): FunctionProto }`
    :param existing_names: set of names already taken, every new name
        is added to it
    :param mapping: `_inline_mapping` from the names visible from the
        outer scope to their current names, it is copied and never
        modified here
    :param verbose: verbosity
    :param fLOG: logging function
    :param rename: see above
    :param level: depth of the subgraph, used to indent the logs
    :return: new graph (or *graph* itself if nothing changed),
        list of modified nodes
    """
    if len(graph.node) == 0:
        # Outputs have still to be renamed.
        # Without any node, the graph only forwards names (inputs,
        # initializers or outer values): only its outputs may change.
        graph0 = graph
        if verbose > 1:
            fLOG(  # pragma: no cover
                "[onnx_inline_function-graph] %s visit0 graph=%d rename=%r "
                "len(mapping)=%d begin" % (
                    " " * level, id(graph), rename, len(mapping)))
        if rename:
            modified_nodes = []
            mapping = mapping.copy()
            for i in graph.input:
                mapping[i.name] = i.name
            for i in graph.initializer:
                mapping[i.name] = i.name
            for i in graph.sparse_initializer:
                mapping[i.name] = i.name
            outputs = []
            for o in graph.output:
                no = make_value_info(mapping[o.name], o.type)
                if no.name != o.name:
                    # NOTE: in this branch, modified_nodes receives
                    # ValueInfoProto (renamed outputs), not NodeProto.
                    modified_nodes.append(o)
                    outputs.append(no)
                else:
                    outputs.append(o)
            if len(modified_nodes) > 0:
                graph = make_graph(
                    [], graph.name, graph.input, outputs,
                    graph.initializer, doc_string=graph.doc_string,
                    sparse_initializer=list(graph.sparse_initializer))
        else:
            modified_nodes = []

        if verbose > 1:
            fLOG(  # pragma: no cover
                "[onnx_inline_function-graph] %s visit graph=%d end "
                "changed=%r len(modified_nodes)=%d" % (
                    " " * level, id(graph0), id(graph0) != id(graph),
                    len(modified_nodes)))

        return graph, modified_nodes

    graph0 = graph
    # Works on a private copy of the outer scope mapping.
    mapping = mapping.copy()
    init = list(graph.initializer)
    init_sparse = list(graph.sparse_initializer)
    inputs = list(graph.input)
    modified_nodes = []
    # NOTE(review): graph outputs are never renamed in this branch, an
    # output directly forwarding a renamed input or outer value would
    # keep its old name -- confirm this case cannot happen.
    outputs = list(graph.output)

    if verbose > 1:
        fLOG("[onnx_inline_function-graph] %s >visit graph=%d rename=%r "
             "len(mapping)=%d begin" % (
                 " " * level, id(graph), rename, len(mapping)))

    output_names = [o.name for o in outputs]
    # Names defined by the subgraph itself map to themselves.
    # NOTE(review): _inline_mapping raises RuntimeError if one of these
    # names is already in the outer mapping (name shadowing).
    for i in init:
        mapping[i.name] = i.name
    for i in init_sparse:
        mapping[i.name] = i.name
    for i in inputs:
        mapping[i.name] = i.name

    # first step, replace names
    nodes = []
    for node in list(graph.node):
        mod = 0
        inp = []
        for i in node.input:
            if i in mapping:
                inp.append(mapping[i])
                if mapping[i] != i:
                    mod += 1
            else:
                # NOTE(review): an empty name (omitted optional input) is
                # not in mapping either and also ends up here.
                raise RuntimeError(  # pragma: no cover
                    "Cannot find input %r in %s for node (level=%d)\n%r." % (
                        i, pprint.pformat(mapping), level, node))
        out = []
        for o in node.output:
            new_o = o
            if rename:
                # Graph outputs keep their names, every other result
                # receives a new unique name.
                if o not in output_names:
                    new_o = _get_new_name('_inl', o, existing_names)
                if o in mapping:
                    # See below.
                    mapping.remove(o)
            elif o in mapping:
                # That means the main contains a result node but is overwritten by
                # the subgraph. The local variable cannot be reached anymore,
                # we remove it.
                mapping.remove(o)
                if o in node.input:
                    new_o = _get_new_name('_inl', o, existing_names)
                if verbose > 3:
                    fLOG(
                        "[onnx_inline_function-renam] %s node %r(%r): %r -> %r "
                        "overwrite result (%r -> %r)." % (
                            " " * level, node.op_type, node.name, node.input,
                            node.output, o, new_o))
            out.append(new_o)
            mapping[o] = new_o
            if o != new_o:
                # The new name is also a valid key for the next nodes.
                mapping[new_o] = new_o
                mod += 1

        if verbose > 3:
            fLOG("[onnx_inline_function-renam] %s rep node %r(%r): %r -> %r" % (
                " " * level, node.op_type, node.name, node.input, node.output))
        new_node = make_node(node.op_type, inp, out, domain=node.domain,
                             name=_get_new_name('_inln', node.name, existing_names))
        for att in node.attribute:
            if (att.type == AttributeProto.GRAPH and
                    hasattr(att, 'g') and att.g is not None):
                # Nested subgraph, processed with the current mapping.
                g, m = _onnx_inline_function_graph(
                    att.g, protos, existing_names=existing_names,
                    verbose=verbose, fLOG=fLOG, mapping=mapping,
                    rename=rename, level=level + 1)
                if len(m) > 0:
                    att = make_attribute(att.name, g)
                    mod += len(m)
                else:
                    att = make_attribute(att.name, att.g)
            new_node.attribute.append(att)
        if mod > 0:
            if verbose > 2:
                fLOG("[onnx_inline_function-renam] %s add node %r(%r): %r -> %r" % (
                    " " * level,
                    new_node.op_type, new_node.name,
                    new_node.input, new_node.output))
            nodes.append(new_node)
            modified_nodes.append(node)
        else:
            # Unchanged node, the original is kept.
            nodes.append(node)

    if len(modified_nodes) > 0:
        if verbose > 1:
            fLOG("[onnx_inline_function-graph] %s -1 graph=%d "
                 "len(modified_nodes)=%d" % (
                     " " * level, id(graph), len(modified_nodes)))

        graph = make_graph(
            nodes, graph.name, inputs, outputs,
            init, doc_string=graph.doc_string,
            sparse_initializer=list(graph.sparse_initializer))
    elif not rename:
        # no modification, let's check the node hiding a functions
        new_nodes = []
        for node in nodes:
            nnodes, m = _onnx_inline_function_node(
                node, protos, existing_names, verbose, fLOG,
                level=level)
            if len(m) > 0:
                if verbose > 0:
                    fLOG("[onnx_inline_function-subgr] %s replaced node %r (%r) "
                         "with %d nodes (id=%r) -- %r -> %r" % (
                             " " * level,
                             node.name, node.op_type, len(nnodes), id(node),
                             node.input, node.output))
                new_nodes.extend(nnodes)
                modified_nodes.extend(m)
            else:
                new_nodes.append(node)
        if len(modified_nodes) > 0:
            if verbose > 1:
                fLOG("[onnx_inline_function-graph] %s -2 graph=%d "
                     "len(modified_nodes)=%d" % (
                         " " * level, id(graph), len(modified_nodes)))

            nodes = new_nodes
            graph = make_graph(
                nodes, graph.name, inputs, outputs,
                init, doc_string=graph.doc_string,
                sparse_initializer=list(graph.sparse_initializer))

    if verbose > 1:
        fLOG("[onnx_inline_function-graph] %s <visit graph=%d end "
             "changed=%r len(modified_nodes)=%d" % (
                 " " * level, id(graph0), id(graph0) != id(graph),
                 len(modified_nodes)))

    return graph, modified_nodes
def _onnx_inline_function_node(node, protos, existing_names, verbose,
                               fLOG, level):
    """
    Replaces *node* by the body of the function it calls if
    `(node.domain, node.op_type)` is a key of *protos*.

    The node inputs and outputs keep their names: `Identity` nodes bind
    the node inputs to the function inputs and the function outputs to
    the node outputs. Every intermediate result receives a new unique
    name (prefix `_inl`). Empty names, used by ONNX for omitted optional
    inputs or outputs, are kept empty.

    :param node: NodeProto
    :param protos: dictionary `{ (domain, type): FunctionProto }`
    :param existing_names: set of names already taken, new names are
        added to it
    :param verbose: verbosity
    :param fLOG: logging function
    :param level: subgraph level, used to indent the logs
    :return: list of nodes replacing *node* (`[node]` if no function
        matches), list of modified nodes (`[node]` or `[]`)
    :raises TypeError: if the matching prototype is not a FunctionProto
    """
    key = node.domain, node.op_type
    if key not in protos:
        return [node], []

    proto = protos[key]
    if not isinstance(proto, FunctionProto):
        raise TypeError(  # pragma: no cover
            "Prototype for key=%r must be a Function Proto, not %r." % (
                key, type(proto)))
    new_nodes = []
    mapping = _inline_mapping(verbose, fLOG, level)
    prefix = "_inl"

    # Binds the function inputs to the node inputs. An empty name or a
    # missing trailing input means the optional input is omitted, the
    # function body then sees an empty name as well.
    for index, to in enumerate(proto.input):
        fr = node.input[index] if index < len(node.input) else ''
        if not fr:
            mapping[to] = ''
            continue
        n = make_node('Identity', [fr],
                      [_get_new_name(prefix, to, existing_names)])
        if verbose > 2:
            fLOG("[onnx_inline_function-ninpu] %s add node %r(%r): %r -> %r" % (
                " " * level, n.op_type, n.name, n.input, n.output))
        mapping[to] = n.output[0]
        if to != n.output[0]:
            mapping[n.output[0]] = n.output[0]
        new_nodes.append(n)

    for nn in proto.node:
        # Empty names (omitted optional inputs or outputs) must stay
        # empty: they are not results and must never enter the mapping,
        # otherwise a later '' input would be wired to a produced result.
        new_input = [mapping[i] if i else '' for i in nn.input]
        new_output = [_get_new_name(prefix, o, existing_names) if o else ''
                      for o in nn.output]
        mapping.update(
            {o: oo for o, oo in zip(nn.output, new_output) if o})
        mapping.update({oo: oo for oo in new_output if oo})
        new_node = make_node(
            nn.op_type, new_input, new_output,
            domain=nn.domain, name=_get_new_name(
                prefix, nn.name, existing_names))
        if verbose > 3:
            fLOG("[onnx_inline_function-nnode] %s rep node %r(%r): %r -> %r" % (
                " " * level, nn.op_type, nn.name, nn.input, nn.output))
        if verbose > 2:
            fLOG("[onnx_inline_function-nnode] %s add node %r(%r): %r -> %r" % (
                " " * level,
                new_node.op_type, new_node.name,
                new_node.input, new_node.output))
        for att in nn.attribute:
            if (att.type == AttributeProto.GRAPH and
                    hasattr(att, 'g') and att.g is not None):
                if verbose > 1:
                    fLOG("[onnx_inline_function-funct] %s fct=%r graph=%d node=%d" % (
                        " " * level, key, id(att.g), id(new_node)))

                # Subgraphs of the body must follow the renaming.
                g, m = _onnx_inline_function_graph(
                    att.g, protos, existing_names=existing_names,
                    verbose=verbose, fLOG=fLOG, mapping=mapping,
                    rename=True, level=level + 1)
                if len(m) > 0:
                    att = make_attribute(att.name, g)
                else:
                    att = make_attribute(att.name, att.g)
            # NOTE(review): other attributes are copied as they are,
            # references to function attributes (ref_attr_name) are
            # not resolved.
            new_node.attribute.append(att)
        new_nodes.append(new_node)

    for fr, to in zip(proto.output, node.output):
        if not to:
            # The caller does not use this optional output.
            continue
        n = make_node('Identity', [mapping[fr]], [to])
        if verbose > 2:
            fLOG("[onnx_inline_function-noutt] %s add node %r(%r): %r -> %r" % (
                " " * level, n.op_type, n.name, n.input, n.output))
        new_nodes.append(n)
    return new_nodes, [node]
def _onnx_inline_select_functions(model, selection):
    """
    Converts a list given to :func:`onnx_inline_function` for a model
    into the dictionary `{ (domain, name): FunctionProto }`.

    A string selects every local function of *model* with that name,
    a :epkg:`FunctionProto` is used as it is.
    """
    names = {p for p in selection if isinstance(p, str)}
    selected = {(f.domain, f.name): f
                for f in model.functions if f.name in names}
    selected.update({(f.domain, f.name): f
                     for f in selection if isinstance(f, FunctionProto)})
    return selected


def _onnx_inline_kept_functions(functions, protos, distri):
    """
    Returns the local functions to keep once the functions in *protos*
    were inlined into the main graph. A function is kept if it was not
    inlined or if a kept function still calls it.

    :param functions: local functions of the original model
    :param protos: dictionary `{ (domain, name): FunctionProto }`
    :param distri: Counter of `(domain, op_type)` in the new main graph
    :return: list of functions in the original order
    :raises RuntimeError: if an inlined function still appears in the
        main graph
    """
    for f in functions:
        key = f.domain, f.name
        if key in protos and key in distri:
            raise RuntimeError(  # pragma: no cover
                "Function %r still appears in the graph, "
                "distibution=%s." % (key, pprint.pformat(distri)))
    by_key = {(f.domain, f.name): f for f in functions}
    kept = {key for key in by_key if key not in protos}
    # Transitive closure: an inlined function called by a kept one
    # must remain available.
    stack = list(kept)
    while stack:
        current = stack.pop()
        for n in enumerate_onnx_nodes(by_key[current]):
            called = n.domain, n.op_type
            if called in by_key and called not in kept:
                kept.add(called)
                stack.append(called)
    return [f for f in functions if (f.domain, f.name) in kept]


def onnx_inline_function(obj, protos=None, existing_names=None, verbose=0, fLOG=None):
    """
    Inlines functions in an ONNX graph.

    :param obj: onnx graph, :epkg:`FunctionProto`, :epkg:`GraphProto`,
        :epkg:`ModelProto`
    :param protos: if None, the function assumes *obj* is of type
        :epkg:`ModelProto` and the goal is to inline every function.
        If *protos* a list of strings, the function only inlines the
        functions in that list (a :epkg:`FunctionProto` in the list is
        used as it is). If *protos* is a dictionary
        `{ (domain, type): FunctionProto }`, the function replaces every
        node `(domain, type)` by the code given in this dictionary
    :param existing_names: no new name will be taken in that set,
        the names of *obj* are always avoided as well
    :param verbose: verbosity
    :param fLOG: logging function
    :return: modified object, list of modified nodes

    For a model, the inlined functions are removed from the local
    functions unless a kept function still calls them, and the domains
    imported by the inlined functions are added to the model opsets
    if missing.

    .. versionadded:: 0.9
    """
    if verbose > 0 and fLOG is None:
        fLOG = print  # pragma: no cover
    if isinstance(obj, ModelProto):
        if verbose > 0:
            fLOG("[onnx_inline_function] type=%r graph=%d" % (
                type(obj), id(obj)))
        if protos is None:
            fct = [f.name for f in obj.functions]
            ex_names = set(enumerate_onnx_names(obj))
            if existing_names is not None:
                ex_names |= existing_names
            return onnx_inline_function(obj, fct, existing_names=ex_names,
                                        verbose=verbose, fLOG=fLOG)
        if isinstance(protos, list):
            ex_names = set(enumerate_onnx_names(obj))
            if existing_names is not None:
                ex_names |= existing_names
            protos = _onnx_inline_select_functions(obj, protos)
            return onnx_inline_function(obj, protos, existing_names=ex_names,
                                        verbose=verbose, fLOG=fLOG)
    if isinstance(protos, list):
        protos = {(f.domain, f.name): f for f in protos}
    if not isinstance(protos, dict):
        raise TypeError(  # pragma: no cover
            "obj is of type %r and protos must be a dictionary not %r." % (
                type(obj), type(protos)))

    if isinstance(obj, ModelProto):
        # The names to avoid are forwarded to the graph (as a copy so
        # that the caller's set is left untouched).
        new_graph, m = onnx_inline_function(
            obj.graph, protos, verbose=verbose, fLOG=fLOG,
            existing_names=(None if existing_names is None
                            else set(existing_names)))
        if len(new_graph.initializer) != len(obj.graph.initializer):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of initializers %d != %d." % (
                    len(new_graph.initializer), len(obj.graph.initializer)))
        if len(new_graph.sparse_initializer) != len(obj.graph.sparse_initializer):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of initializers %d != %d." % (
                    len(new_graph.sparse_initializer),
                    len(obj.graph.sparse_initializer)))
        distri = Counter(
            (n.domain, n.op_type)
            for n in enumerate_onnx_nodes(new_graph))
        new_functions = _onnx_inline_kept_functions(
            list(obj.functions), protos, distri)
        opsets = {op.domain: op.version for op in obj.opset_import}
        for f in obj.functions:
            if f.domain not in opsets:
                opsets[f.domain] = 1
        # Inlined nodes use the domains imported by the inlined functions.
        for f in protos.values():
            if isinstance(f, FunctionProto):
                for op in f.opset_import:
                    if op.domain not in opsets:
                        opsets[op.domain] = op.version
        return (
            make_model(
                new_graph,
                functions=new_functions,
                opset_imports=[
                    make_operatorsetid(k, v)
                    for k, v in opsets.items()],
                producer_name=obj.producer_name,
                producer_version=obj.producer_version,
                ir_version=obj.ir_version,
                doc_string=obj.doc_string,
                domain=obj.domain,
                model_version=obj.model_version),
            m)

    # FunctionProto, GraphProto
    # The names of obj must always be avoided, even when the caller
    # gives its own set (which is updated in place).
    obj_names = set(enumerate_onnx_names(obj))
    if existing_names is None:
        existing_names = obj_names
    else:
        existing_names |= obj_names

    if verbose > 0:
        fLOG("[onnx_inline_function] type=%r graph=%d begin" % (
            type(obj), id(obj)))
    distri = Counter((n.domain, n.op_type)
                     for n in enumerate_onnx_nodes(obj))

    new_nodes = list(obj.node)
    modified_nodes = []
    n_iter = 0
    # Every iteration inlines one level of function calls, a chain of
    # nested calls is at most len(protos) deep (assuming calls are not
    # recursive), subgraphs may need one extra iteration per level.
    max_iter = onnx_subgraphs_level(obj) + len(protos) + 1
    modified = 1
    while modified > 0 and n_iter < max_iter:
        if verbose > 0:
            fLOG(f"[onnx_inline_function] start iteration {n_iter!r}")

        # local context
        mapping = _inline_mapping(verbose, fLOG, level=0)
        if isinstance(obj, GraphProto):
            mapping.update({i.name: i.name for i in obj.initializer})
            mapping.update({i.name: i.name for i in obj.sparse_initializer})
            for i in obj.input:
                if i.name not in mapping:
                    mapping[i.name] = i.name
        elif isinstance(obj, FunctionProto):
            mapping.update({i: i for i in obj.input})
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type for obj: {type(obj)!r}.")

        # loop on nodes
        old_nodes = new_nodes
        modified = 0
        new_nodes = []
        for node in old_nodes:
            nnodes, m = _onnx_inline_function_node(
                node, protos, existing_names, verbose, fLOG, level=0)
            mapping.update({o: o for o in node.output})

            if len(m) > 0:
                if verbose > 0:
                    fLOG("[onnx_inline_function] replaced node %r (%r) "
                         "with %d nodes (id=%r) -- %r -> %r (iter=%r)" % (
                             node.name, node.op_type, len(nnodes), id(node),
                             node.input, node.output, n_iter))
                modified += len(m)
                new_nodes.extend(nnodes)
                modified_nodes.extend(m)
            else:
                # The node is not a function call but its subgraphs
                # may contain some.
                has_graph = False
                new_attributes = []
                for att in node.attribute:
                    if (att.type == AttributeProto.GRAPH and
                            hasattr(att, 'g') and att.g is not None):
                        g, m = _onnx_inline_function_graph(
                            att.g, protos, verbose=verbose, fLOG=fLOG,
                            existing_names=existing_names, mapping=mapping,
                            rename=False, level=1)
                        if len(m) > 0:
                            modified_nodes.extend(m)
                            modified_nodes.append(node)
                            modified += 1 + len(m)
                            has_graph = True
                            att = make_attribute(att.name, g)
                    new_attributes.append(att)
                if has_graph:
                    new_node = make_node(
                        node.op_type, node.input, node.output,
                        domain=node.domain, name=node.name)
                    new_node.attribute.extend(new_attributes)
                    new_nodes.append(new_node)
                else:
                    # we still need to check that this subgraph does
                    # not include a function
                    new_nodes.append(node)

        n_iter += 1
        if verbose > 0:
            total_node = len(list(enumerate_onnx_nodes(new_nodes)))
            fLOG("[onnx_inline_function] n_iter=%r/%r nodes=%r modified=%r "
                 "n_nodes=%d total=%d" % (
                     n_iter, max_iter, len(obj.node), modified,
                     len(new_nodes), total_node))

    if verbose > 0:
        fLOG("[onnx_inline_function] type=%r graph=%d end with %d "
             "modified nodes" % (
                 type(obj), id(obj), len(modified_nodes)))
        distri2 = Counter((n.domain, n.op_type)
                          for n in enumerate_onnx_nodes(new_nodes))
        if distri != distri2:
            fLOG("[onnx_inline_function] BEFORE")
            for k, v in sorted(distri.items()):
                fLOG("[onnx_inline_function] %d -- %s" % (v, k))
            fLOG("[onnx_inline_function] AFTER")
            for k, v in sorted(distri2.items()):
                fLOG("[onnx_inline_function] %d -- %s" % (v, k))

    if isinstance(obj, FunctionProto):
        return (
            make_function(
                domain=obj.domain, fname=obj.name,
                inputs=obj.input, outputs=obj.output, nodes=new_nodes,
                opset_imports=[
                    make_operatorsetid(op.domain, op.version)
                    for op in obj.opset_import],
                doc_string=obj.doc_string,
                attributes=obj.attribute),
            modified_nodes)
    if isinstance(obj, GraphProto):
        return (
            make_graph(new_nodes, obj.name, list(obj.input), list(obj.output),
                       list(obj.initializer), doc_string=obj.doc_string,
                       sparse_initializer=list(obj.sparse_initializer)),
            modified_nodes)
    raise TypeError(  # pragma: no cover
        f"Unexpected type for obj {type(obj)!r}.")