Coverage for mlprodict/onnx_tools/optim/_onnx_optimisation_common.py: 93%
122 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
@brief Common functions to reduce the number of
nodes of an :epkg:`ONNX` graph.
5"""
6from onnx.helper import make_graph, make_model, make_attribute
7from onnx import AttributeProto, NodeProto, ValueInfoProto
def _apply_optimisation_on_graph(fct, onnx_model, recursive=True, debug_info=None,
                                 **kwargs):
    """
    Applies an optimisation function *fct* on the graph
    of a model and not on the model itself, then builds
    a new model around the optimised graph.

    @param      fct             function to optimize like
                                @see fn onnx_remove_node_identity,
                                it is called with the graph and the keywords
                                *recursive*, *debug_info* and *kwargs*
    @param      onnx_model      onnx model
    @param      recursive       looks into subgraphs, forwarded to *fct*
    @param      debug_info      debug information (private)
    @param      kwargs          additional parameters
    @return                     new onnx model

    The function raises *TypeError* if *onnx_model* has no attribute *graph*.
    """
    if not hasattr(onnx_model, 'graph'):
        raise TypeError(  # pragma: no cover
            f"This function only works on 'ModelProto' and not on {type(onnx_model)}.")

    if debug_info is None:
        debug_info = []
    # *recursive* must be forwarded explicitly, otherwise *fct* silently
    # falls back on its own default value.
    graph = fct(
        onnx_model.graph, recursive=recursive,
        debug_info=debug_info + ['GRAPH'], **kwargs)

    new_model = make_model(graph, functions=onnx_model.functions)
    # copies the model metadata
    new_model.ir_version = onnx_model.ir_version
    new_model.producer_name = onnx_model.producer_name
    new_model.producer_version = onnx_model.producer_version
    new_model.domain = onnx_model.domain
    new_model.model_version = onnx_model.model_version
    new_model.doc_string = onnx_model.doc_string
    if hasattr(onnx_model, 'value_info'):
        graph.value_info.extend(onnx_model.value_info)  # pragma: no cover

    # Removes every opset registered on the new model so that it only
    # declares the opsets of the original model.
    del new_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onnx_model.opset_import:
        op_set = new_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    return new_model
def _apply_remove_node_fct_node(fct, node, recursive, debug_info):
    """
    Applies an optimizing function on the subgraphs of a node.

    @param      fct         optimizing function, called on every subgraph
                            with keywords *recursive* and *debug_info*
    @param      node        onnx node
    @param      recursive   does it in subgraphs as well
    @param      debug_info  debug information (private), the name of the
                            processed attribute is appended to it
    @return                 new node if at least one subgraph was processed,
                            the same node otherwise

    Only attributes named *body*, *then_branch*, *else_branch*
    are considered as subgraphs.
    """
    if not hasattr(node, 'attribute'):
        return node  # pragma: no cover
    modified = False
    new_atts = []
    for att in node.attribute:
        if att.name in ('body', 'then_branch', 'else_branch'):
            new_body = fct(
                att.g, recursive=recursive,
                debug_info=debug_info + [att.name])
            new_atts.append(_make_att_graph(att.name, new_body))
            modified = True
        else:
            new_atts.append(att)
    if not modified:
        return node
    # The domain is kept as well, like the other functions of this file
    # rebuilding a node do, otherwise the operator would move to the
    # default domain.
    return _make_node(node.op_type, node.input,
                      node.output, name=node.name,
                      domain=node.domain,
                      attributes=new_atts)
def _make_node(op_type, inputs, outputs, name=None, doc_string=None,
               domain=None, attributes=None):
    """
    Builds a NodeProto from its components.

    :param op_type: (string): name of the operator
    :param inputs: list of input names
    :param outputs: list of output names
    :param name: optional unique identifier for NodeProto,
        left untouched if empty or None
    :param doc_string: optional documentation string for NodeProto,
        left untouched if empty or None
    :param domain: optional domain for NodeProto,
        the field is not set if it is None
    :param attributes: the attributes of the node, either a dictionary
        ``{name: value}`` converted with :epkg:`make_attribute`
        (names are processed in sorted order) or a sequence of
        attributes added as they are
    :return: node
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)  # pylint: disable=E1101
    proto.output.extend(outputs)  # pylint: disable=E1101

    # optional fields
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string  # pragma: no cover
    if domain is not None:
        proto.domain = domain

    if isinstance(attributes, dict):
        if attributes:  # pragma: no cover
            ordered = sorted(attributes.items())
            proto.attribute.extend(  # pylint: disable=E1101
                [make_attribute(key, value) for key, value in ordered])
    elif attributes:
        proto.attribute.extend(attributes)  # pylint: disable=E1101
    return proto
def _replace(name, old_name, new_name):
    """
    Returns the name replacing *name*.

    @param      name        name to process
    @param      old_name    name to replace or a dictionary
                            ``{old name: new name}``
    @param      new_name    new name, None to use *old_name*
                            as a dictionary
    @return                 *new_name* if *name* is *old_name*,
                            the mapped name in the dictionary case,
                            *name* otherwise
    """
    use_mapping = new_name is None and isinstance(old_name, dict)
    if use_mapping:
        # unknown names are left unchanged
        return old_name.get(name, name)
    return new_name if name == old_name else name
def _rename_node_input(onnx_node, old_name, new_name=None):
    """
    Builds a copy of a node with one or several inputs renamed.

    @param      onnx_node       onnx_node
    @param      old_name        old name or a dictionary
                                ``{old name: new name}``
    @param      new_name        new name or None if *old_name* is a dictionary
    @return                     new node

    Outputs are left unchanged. Every attribute of type GRAPH is
    replaced by the graph returned by @see fn _rename_graph_input.
    """
    renamed_inputs = [_replace(name, old_name, new_name)
                      for name in onnx_node.input]
    same_outputs = list(onnx_node.output)

    if not hasattr(onnx_node, 'attribute'):
        atts = None  # pragma: no cover
    else:
        atts = []
        for att in onnx_node.attribute:
            holds_graph = (
                att.type == AttributeProto.GRAPH and  # pylint: disable=E1101
                hasattr(att, 'g') and att.g is not None)
            if not holds_graph:
                atts.append(att)
                continue
            # NOTE(review): old_name, new_name are forwarded unchanged,
            # including when old_name is a dictionary - confirm that
            # _rename_graph_input supports that form.
            sub_graph = _rename_graph_input(att.g, old_name, new_name)
            atts.append(_make_att_graph(att.name, sub_graph))

    return _make_node(
        onnx_node.op_type, renamed_inputs, same_outputs,
        name=onnx_node.name, domain=onnx_node.domain, attributes=atts)
def _copy_value_info_proto(new_name, obj):
    """
    Copies a ValueInfoProto under a different name.

    @param      new_name    name of the copy
    @param      obj         object to copy, its fields *type* and
                            *doc_string* are read
    @return                 new ValueInfoProto with the same type
                            and the same documentation
    """
    value_info = ValueInfoProto()
    value_info.name = new_name
    value_info.type.CopyFrom(obj.type)  # pylint: disable=E1101
    # doc_string is a field of ValueInfoProto, TypeProto does not define it,
    # reading it from obj.type fails.
    if obj.doc_string:
        value_info.doc_string = obj.doc_string
    return value_info
def _rename_graph_output(graph, old_name, new_name):
    """
    Renames an output and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    name of the output to rename
    @param      new_name    new name of the output
    @return                 modified graph

    The function returns a new graph. The node *Identity*
    (*old_name* to *new_name*) is added after the existing nodes,
    even when no output is called *old_name*.
    """
    renamed_outputs = [
        o if old_name != o.name else _copy_value_info_proto(new_name, o)
        for o in graph.output]
    all_nodes = list(graph.node)
    # new_name is computed from old_name
    all_nodes.append(_make_node('Identity', [old_name], [new_name]))
    result = make_graph(all_nodes, graph.name, graph.input, renamed_outputs,
                        graph.initializer)
    result.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return result
def _rename_graph_input(graph, old_name, new_name):
    """
    Renames an input and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    name of the input to rename
    @param      new_name    new name of the input
    @return                 modified graph

    The function returns a new graph. The node *Identity*
    (*new_name* to *old_name*) is added even when no input
    is called *old_name*.
    """
    inputs = [
        i if old_name != i.name else _copy_value_info_proto(new_name, i)
        for i in graph.input]
    # The Identity node produces old_name from new_name. It is inserted
    # first so that the nodes consuming old_name come after the node
    # producing it: the list of nodes stays topologically sorted.
    nodes = [_make_node('Identity', [new_name], [old_name])]
    nodes.extend(graph.node)
    new_graph = make_graph(nodes, graph.name, inputs, graph.output,
                           graph.initializer)
    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph
def _make_att_graph(name, new_body):
    """
    Builds an attribute of type GRAPH.

    @param      name        attribute name
    @param      new_body    graph, copied into the attribute
    @return                 AttributeProto
    """
    graph_att = AttributeProto()
    graph_att.type = AttributeProto.GRAPH  # pylint: disable=E1101
    graph_att.name = name
    graph_att.g.CopyFrom(new_body)  # pylint: disable=E1101
    return graph_att
def _rename_node_output(onnx_node, old_name, new_name):
    """
    Builds a copy of a node with an output renamed.

    @param      onnx_node       onnx_node
    @param      old_name        old name
    @param      new_name        new name
    @return                     new node

    Inputs are left unchanged. Every attribute of type GRAPH is
    replaced by the graph returned by @see fn _rename_graph_output.
    """
    same_inputs = list(onnx_node.input)
    renamed_outputs = [_replace(name, old_name, new_name)
                       for name in onnx_node.output]

    if not hasattr(onnx_node, 'attribute'):
        atts = None  # pragma: no cover
    else:
        atts = []
        for att in onnx_node.attribute:
            holds_graph = (
                att.type == AttributeProto.GRAPH and  # pylint: disable=E1101
                hasattr(att, 'g') and att.g is not None)
            if not holds_graph:
                atts.append(att)
                continue
            sub_graph = _rename_graph_output(att.g, old_name, new_name)
            atts.append(_make_att_graph(att.name, sub_graph))

    return _make_node(
        onnx_node.op_type, same_inputs, renamed_outputs,
        name=onnx_node.name, domain=onnx_node.domain, attributes=atts)