Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_identity.py: 100%
112 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import logging
6from onnx import FunctionProto, AttributeProto
7from onnx.helper import make_graph, make_function
8from ._onnx_optimisation_common import ( # pylint: disable=E0611
9 _rename_node_input,
10 _rename_node_output,
11 _apply_optimisation_on_graph,
12 _apply_remove_node_fct_node)
# Logger of this module; the optimisation functions below report
# their progress through it at DEBUG level.
logger = logging.getLogger(name='onnx:optim')
def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node and subgraphs if
    *recursive* is True for identity node. Unless such a
    node directly connects one input to one output, it will
    be removed and every other node gets its inputs or
    outputs accordingly renamed.

    An *Identity* node is also kept when its input is not produced by
    any node of the current graph (for example a name captured from an
    enclosing scope inside a subgraph) and its output must be preserved.
    In that case there is no local producer to rename, and removing the
    node would leave the output undefined.

    :param onnx_model: object exposing a ``graph`` attribute (delegated
        to the common helper), a graph, or a ``FunctionProto``
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: additional options (unused)
    :return: new onnx _model
    """
    model_type = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    debug_info = ([] if debug_info is None else list(debug_info)) + [model_type]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model
    logger.debug("onnx_remove_node_identity:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)

    if is_function:
        # Function inputs/outputs are plain names and a function has no
        # initializers. inputs_inits must still be defined because the
        # loop below relies on it.
        inputs = set(graph.input)
        outputs = set(graph.output)
        inputs_inits = inputs
    else:
        inputs = set(i.name for i in graph.input)
        inits = set(i.name for i in graph.initializer)
        inputs_inits = inputs.union(inits)
        outputs = set(o.name for o in graph.output)

    def retrieve_idnodes(existing_nodes):
        # (index, node, input name, output name) for every remaining
        # Identity node.
        return [(i, node, node.input[0], node.output[0])
                for i, node in enumerate(existing_nodes)
                if node is not None and node.op_type == 'Identity']

    def subgraphs_of(node):
        # Every subgraph stored in a node attribute (single or list).
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:  # pylint: disable=E1101
                yield att.g
            elif att.type == AttributeProto.GRAPHS:  # pylint: disable=E1101
                yield from att.graphs

    # add to output the list of local variables in subgraphs
    def append_local_variable(body, subgraph=True):
        # Names consumed by *body* (when it is a subgraph) or by any nested
        # subgraph that are not defined by *body*'s inputs, its
        # initializers, or one of its earlier nodes, i.e. names captured
        # from an enclosing scope.
        if isinstance(body, FunctionProto):
            known = set(body.input)
        else:
            known = set(i.name for i in body.input)
            known |= set(i.name for i in body.initializer)
        local_var = set()
        for node in body.node:
            if subgraph:
                local_var.update(i for i in node.input if i not in known)
            known.update(node.output)
            for sub in subgraphs_of(node):
                local_var |= append_local_variable(sub)
        return local_var

    local_vars = append_local_variable(graph, subgraph=False)
    logger.debug('onnx_remove_node_identity:local_vars:%r', local_vars)
    # Names which must keep their name. The renaming below only rewrites
    # the nodes of this graph, never the names referenced inside subgraphs.
    ext_outputs = outputs | local_vars

    nodes = list(graph.node)
    rem = 1
    while rem > 0:
        rem = 0
        restart = False
        for i, _, inp, out in retrieve_idnodes(nodes):
            if restart:
                # Another Identity node was modified: the cached (inp, out)
                # pairs may be stale, collect them again.
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover
            if inp in inputs_inits and out in ext_outputs:
                # Cannot be removed.
                continue
            if out not in ext_outputs:
                # We cannot change an output name: every consumer of
                # out reads inp instead.
                for j, node in enumerate(nodes):
                    if node is None or out not in node.input:
                        continue
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'out=%r:inp=%r',
                                 node.op_type, node.input,
                                 node.output, out, inp)
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                logger.debug('onnx_remove_node_identity:1:remove:%s:%r->%r:',
                             nodes[i].op_type, nodes[i].input, nodes[i].output)
                nodes[i] = None
                rem += 1
                continue
            if inp in inputs_inits or inp in ext_outputs:
                # We cannot change an input name or an output name.
                continue
            if not any(node is not None and inp in node.output
                       for node in nodes):
                # inp is not produced in this graph (captured from an
                # enclosing scope): there is no producer to rename, and
                # removing the node would leave out undefined.
                continue
            # The producer of inp now produces out directly, and every
            # consumer of inp reads out. The Identity node itself (j == i)
            # is skipped since it is removed just after.
            for j in range(len(nodes)):  # pylint: disable=C0200
                if j == i or nodes[j] is None:
                    continue
                if inp in nodes[j].output:
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_output:%s:%r->%r:'
                                 'inp=%r:out=%r',
                                 nodes[j].op_type, nodes[j].input,
                                 nodes[j].output, inp, out)
                    nodes[j] = _rename_node_output(nodes[j], inp, out)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                if inp in nodes[j].input:
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'inp=%r:out=%r',
                                 nodes[j].op_type, nodes[j].input,
                                 nodes[j].output, inp, out)
                    nodes[j] = _rename_node_input(nodes[j], inp, out)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True
            logger.debug('onnx_remove_node_identity:2:remove:%s:%r->%r:',
                         nodes[i].op_type, nodes[i].input, nodes[i].output)
            nodes[i] = None
            rem += 1

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    if not nodes:
        # something went wrong
        nodes = list(graph.node)
    if is_function:
        logger.debug("onnx_remove_node_identity:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    logger.debug("onnx_remove_node_identity: end graph with %d nodes.",
                 len(nodes))
    return graph