Coverage for mlprodict/onnx_tools/onnx_tools.py: 92%
138 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Functions to manipulate ONNX file.
4"""
5from onnx import helper, AttributeProto
def find_node_name(model, name):
    """
    Looks up a node of the main graph by its name.

    :param model: onnx model, it must expose a ``graph`` attribute
    :param name: name of the node to look for
    :return: the first node whose name is *name*, None if there is none
    :raises TypeError: if *model* has no ``graph`` attribute
    """
    if hasattr(model, "graph"):
        return next(
            (candidate for candidate in model.graph.node
             if candidate.name == name),
            None)
    raise TypeError(  # pragma: no cover
        f"Parameter model is not an ONNX model but {type(model)}")
def find_node_input_name(node, name):
    """
    Finds the position of a node input given the input name.

    :param node: onnx node, ``node.input`` is the sequence of its
        input names
    :param name: input name to look for
    :return: index of the first input named *name*, -1 if not found
    """
    # node.input holds plain strings (result names), not nodes
    for i, input_name in enumerate(node.input):
        if input_name == name:
            return i
    return -1
def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
    """
    Inserts a node before one node input.

    The new node of type *op_type* takes the selected input of the
    targeted node as its only input, and its single output replaces
    that input in the targeted node.

    :param model: onnx model
    :param op_type: operator type of the inserted node
    :param node: node or node name (looked up in ``model.graph.node``)
    :param input_index: input index or input name
    :param new_name: name of the output of the inserted node; if None,
        ``op_type.lower()`` is used. A suffix ``_<i>`` is appended
        while the name is already used somewhere in the model.
    :param attrs: node attributes
    :return: new model
    :raises RuntimeError: if the node name or the input name cannot be found

    .. note::
        The targeted node is modified in place (one of its inputs is
        renamed) before the new model is built.
    """
    if isinstance(node, str):
        inode = find_node_name(model, node)
        if inode is None:
            raise RuntimeError(
                f"Unable to find node {node!r} in the model.")
    else:
        inode = node
    if isinstance(input_index, str):
        # the lookup must happen on the resolved node, not on its name
        input_index_ = find_node_input_name(inode, input_index)
        if input_index_ == -1:
            raise RuntimeError(  # pragma: no cover
                "Unable to find input_index %r in node %r." % (
                    input_index, inode.name))
        input_index = input_index_

    # guess a new name not used anywhere in the model
    # (inputs, outputs, initializers, nodes, subgraphs)
    names = set(enumerate_onnx_names(model))
    if new_name is None:
        new_name = op_type.lower()
    root_name = new_name
    i = 0
    while new_name in names:
        new_name = "%s_%d" % (root_name, i)
        i += 1

    new_node = helper.make_node(
        op_type, [inode.input[input_index]], [new_name], **attrs)
    # the targeted node now consumes the output of the inserted node
    inode.input[input_index] = new_name
    keep_nodes = list(model.graph.node)
    keep_nodes.append(new_node)
    keep_nodes = ensure_topological_order(
        model.graph.input, model.graph.initializer, keep_nodes)

    graph = helper.make_graph(
        keep_nodes, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # copy the opsets of the original model
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model
def ensure_topological_order(inputs, initializers, nodes):
    """
    Ensures and modifies the order of nodes to have
    a topological order (every node in the list
    can only be an input for a node later in this list).
    The function raises an exception if a cycle is detected.

    Empty names in ``node.input`` / ``node.output`` denote omitted
    optional inputs or outputs and are ignored.

    :param inputs: graph inputs (objects with a ``name`` attribute)
    :param initializers: graph initializers (objects with a ``name``
        attribute)
    :param nodes: graph nodes (objects with ``input``, ``output``,
        ``name`` and ``op_type`` attributes)
    :return: list of ordered nodes; nodes at the same depth keep
        their relative order from *nodes*
    :raises RuntimeError: if an output is produced twice, or if some
        inputs can never be resolved (cycle or disconnected graph)
    """
    # ``order`` maps a result name, or ``id(node)`` for a node,
    # to its depth. Graph inputs and initializers are known upfront.
    order = {}
    for inp in inputs:
        order[inp.name] = 0
    for inp in initializers:
        order[inp.name] = 0

    # Defined before the loop so that an empty *nodes* list
    # does not leave them unbound.
    missing_names = set()
    missing_ops = []
    n_iter = 0
    while n_iter < len(nodes) * 2:
        n_iter += 1
        missing_names = set()
        missing_ops = []
        for node in nodes:
            maxi = 0
            for name in node.input:
                if not name:
                    # omitted optional input, nothing to wait for
                    continue
                if name in order:
                    maxi = max(maxi, order[name])
                else:
                    maxi = None
                    missing_names.add(name)
                    break
            if maxi is None:
                missing_ops.append(node)
                continue
            key = id(node)
            if key in order:
                # already placed during a previous pass
                continue
            maxi += 1
            order[key] = maxi
            maxi += 1
            for name in node.output:
                if not name:
                    # omitted optional output
                    continue
                if name in order:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to sort a node (cycle). An output was "
                        "already ordered %r (iteration=%r)." % (
                            name, n_iter))
                order[name] = maxi
        if len(missing_names) == 0:
            # every node is placed, another pass would change nothing
            break

    if len(missing_ops) > 0:  # pragma: no cover
        def nstr(name):
            if name in order:
                return "%s#%d" % (name, order[name])
            return name
        rows = ["%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in missing_ops]
        rows.insert(0, "")
        rows.append("--")
        rows.append("--all-nodes--")
        rows.append("--")
        rows.extend("%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in nodes)
        raise RuntimeError(
            "After %d iterations for %d nodes, still unable "
            "to sort names %r. The graph may be disconnected. "
            "List of operators: %s" % (
                n_iter, len(nodes), missing_names,
                "\n".join(rows)))

    # Stable sort on the depth: ties keep the original relative order
    # (sorting on str(id(node)) made the result nondeterministic).
    positions = sorted(range(len(nodes)), key=lambda i: order[id(nodes[i])])
    return [nodes[i] for i in positions]
def enumerate_onnx_names(onx):
    """
    Enumerates all existing names in one ONNX graph
    (:epkg:`ModelProto`, :epkg:`FunctionProto`, :epkg:`GraphProto`).
    The function is recursive: names used inside subgraphs held by
    node attributes (a single graph or a list of graphs) are
    enumerated as well. A name may be yielded several times.

    :param onx: one onnx object
    :return: iterator on names
    """
    if hasattr(onx, 'graph'):
        graph = onx.graph
    elif hasattr(onx, 'initializer'):
        graph = onx
    else:
        graph = None

    if graph is not None:
        # ModelProto or GraphProto: inputs/outputs are ValueInfoProto
        for i in graph.initializer:
            yield i.name
        for i in graph.input:
            yield i.name
        for i in graph.output:
            yield i.name
        nodes = graph.node
    else:
        # FunctionProto: inputs/outputs are plain strings
        if hasattr(onx, 'input'):
            yield from onx.input
        if hasattr(onx, 'output'):
            yield from onx.output
        nodes = onx.node

    for node in nodes:
        yield from node.input
        yield from node.output
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:  # pylint: disable=E0611,E1101
                if hasattr(att, 'g') and att.g is not None:
                    yield from enumerate_onnx_names(att.g)
            elif att.type == AttributeProto.GRAPHS:  # pylint: disable=E0611,E1101
                for sub in att.graphs:
                    yield from enumerate_onnx_names(sub)
def enumerate_onnx_nodes(onx):
    """
    Enumerates all nodes in one ONNX graph
    (:epkg:`ModelProto`, :epkg:`FunctionProto`, :epkg:`GraphProto`)
    or in a list of nodes.
    The function is recursive: every node is followed by the nodes
    of the subgraphs held by its attributes (a single graph or a
    list of graphs).

    :param onx: one onnx object or a list of nodes
    :return: iterator on nodes
    """
    if isinstance(onx, list):
        nodes = onx
    elif hasattr(onx, 'graph'):
        nodes = onx.graph.node
    else:
        nodes = onx.node
    for node in nodes:
        yield node
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:  # pylint: disable=E0611,E1101
                if hasattr(att, 'g') and att.g is not None:
                    yield from enumerate_onnx_nodes(att.g)
            elif att.type == AttributeProto.GRAPHS:  # pylint: disable=E0611,E1101
                for sub in att.graphs:
                    yield from enumerate_onnx_nodes(sub)