Coverage for mlprodict/onnx_tools/compress.py: 80%
122 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Functions to simplify, compress an ONNX graph.
5.. versionadded:: 0.9
6"""
7import logging
8from onnx import ModelProto, GraphProto, FunctionProto
9from onnx.helper import (
10 make_function, make_model, make_value_info, make_graph,
11 make_tensor_type_proto, make_node, make_operatorsetid)
14logger = logging.getLogger('onnx:compress')
17def _check_expression(expe):
18 att = expe.attribute[0].g
19 inputs = [i.name for i in att.input]
20 if list(expe.input) != inputs:
21 raise RuntimeError( # pragma: no cover
22 f'Name mismatch in node Expression {expe.input!r} != {inputs!r}.')
23 outputs = [o.name for o in att.output]
24 if list(expe.output) != outputs:
25 raise RuntimeError( # pragma: no cover
26 f'Name mismatch in node Expression {expe.input!r} != {inputs!r}.')
def _fuse_node(o, node, node_next):
    """
    Merges two nodes having one input/output in common.

    :param o: output name
    :param node: first node (it outputs the results)
    :param node_next: second node (it ingests the result)
    :return: merged node
    """
    expression_key = ('mlprodict', 'Expression')
    if list(node.output) != [o]:
        raise RuntimeError(  # pragma: no cover
            f"The only output of the first node should be {[o]!r} not {node.output!r}.")

    # Operators carrying their own subgraphs cannot be fused.
    forbidden = {('', 'If'), ('', 'Loop'), ('', 'Scan')}
    first_key = (node.domain, node.op_type)
    second_key = (node_next.domain, node_next.op_type)
    if first_key in forbidden or second_key in forbidden:
        return None

    first_is_expr = first_key == expression_key
    second_is_expr = second_key == expression_key
    if first_is_expr:
        _check_expression(node)
    if second_is_expr:
        _check_expression(node_next)

    def value_infos(names):
        # Placeholder value infos: the element type is unknown here.
        return [make_value_info(name, make_tensor_type_proto(0, []))
                for name in names]

    graph = None
    if node.domain == '' and node_next.domain == '':
        # Simple case: two plain nodes become one two-node subgraph.
        graph = make_graph(
            [node, node_next], "expression",
            value_infos(node.input), value_infos(node_next.output))
    elif first_is_expr and node_next.domain == '':
        # Append the plain node to the existing expression subgraph.
        sub = node.attribute[0].g
        graph = make_graph(
            list(sub.node) + [node_next], "expression",
            sub.input, value_infos(node_next.output))
    elif node.domain == '' and second_is_expr:
        # Prepend the plain node to the existing expression subgraph.
        sub = node_next.attribute[0].g
        graph = make_graph(
            [node] + list(sub.node), "expression",
            value_infos(node.input), sub.output)
    elif first_is_expr and second_is_expr:
        # Concatenate both expression subgraphs.
        sub1 = node.attribute[0].g
        sub2 = node_next.attribute[0].g
        graph = make_graph(
            list(sub1.node) + list(sub2.node), "expression",
            sub1.input, sub2.output)

    if graph is None:
        raise NotImplementedError(  # pragma: no cover
            "Unable to merge nodes '%s/%s' and '%s/%s'." % (
                node.domain, node.op_type,
                node_next.domain, node_next.op_type))

    return make_node(
        'Expression', node.input, node_next.output, domain='mlprodict',
        expression=graph)
100def _compress_nodes_once(nodes, verbose=0):
101 """
102 Compresses a sequence of node to make it more
103 readable. If possible, it creates a node `Expression`
104 with a graph as an attribute.
106 :param nodes: sequence of nodes to compress
107 :return: compressed sequence of nodes
108 """
109 # check that a result is used only once
110 order = {}
111 results = {}
112 for node in list(nodes):
113 order[id(node)] = (len(order), node)
114 for name in node.input:
115 if name in results:
116 results[name] += 1
117 else:
118 results[name] = 1
120 once = {k: v for k, v in results.items() if v == 1}
121 if len(once) == 0:
122 return nodes
124 once_nodes_o = {}
125 once_nodes_i = {}
126 for node in nodes:
127 if len(node.output) != 1:
128 continue
129 for o in node.output:
130 if o in once:
131 once_nodes_o[o] = node
132 for i in node.input:
133 if i in once:
134 once_nodes_i[i] = node
136 if len(once_nodes_o) == 0:
137 return nodes
139 if verbose > 0:
140 logger.debug(
141 "Results to compress: %r", list(sorted(once_nodes_o)))
143 while len(once_nodes_o) > 0:
144 o, node = once_nodes_o.popitem()
145 node_next = once_nodes_i[o]
146 new_node = _fuse_node(o, node, node_next)
147 if new_node is None:
148 # nothing can be done
149 continue
150 once_nodes_o.update({o: new_node for o in node_next.output
151 if o in once_nodes_o})
152 once_nodes_i.update({i: new_node for i in node.input
153 if i in once_nodes_i})
154 order[id(new_node)] = (order[id(node)][0], new_node)
155 del order[id(node)]
156 del order[id(node_next)]
158 ordered = list(sorted((v[0], k, v[1]) for k, v in order.items()))
159 return [v[-1] for v in ordered]
def _compress_nodes(nodes, verbose=0):
    """
    Rewrites a sequence of nodes into a more compact, readable form.
    Whenever possible, it builds `Expression` nodes holding a
    subgraph as an attribute.

    :param nodes: sequence of nodes to compress
    :param verbose: logging level forwarded to the compression pass
    :return: compressed sequence of nodes
    """
    compressed = _compress_nodes_once(nodes, verbose=verbose)
    return compressed
def compress_proto(proto, verbose=0):
    """
    Compresses a :epkg:`ModelProto`, :epkg:`FunctionProto`,
    :epkg:`GraphProto`. The function detects nodes outputting
    results only used once. It then fuses it with the node
    taking it as an input.

    :param proto: :epkg:`ModelProto`, :epkg:`FunctionProto`,
        :epkg:`GraphProto`
    :param verbose: logging
    :return: same type

    .. versionadded:: 0.9
    """
    if isinstance(proto, FunctionProto):
        return _compress_function_proto(proto, verbose)
    if isinstance(proto, ModelProto):
        return _compress_model_proto(proto, verbose)
    if isinstance(proto, GraphProto):
        return _compress_graph_proto(proto, verbose)
    raise TypeError(  # pragma: no cover
        "Unexpected type for proto %r, it should ModelProto, "
        "GraphProto or FunctionProto." % type(proto))


def _compress_function_proto(proto, verbose):
    # Compresses the nodes of a FunctionProto; returns *proto* itself
    # when nothing changed so callers can detect it by identity.
    nodes = _compress_nodes(proto.node, verbose=verbose)
    if len(nodes) == len(proto.node):
        # unchanged
        return proto
    if verbose:
        logger.debug(  # pragma: no cover
            "Compressed function %r/%r from %d nodes to %d.",
            proto.domain, proto.name, len(proto.node), len(nodes))
    # The fused Expression nodes live in domain 'mlprodict'.
    opsets = {op.domain: op.version for op in proto.opset_import}
    opsets['mlprodict'] = 1
    return make_function(
        proto.domain, proto.name,
        proto.input, proto.output, nodes,
        opset_imports=[
            make_operatorsetid(k, v) for k, v in opsets.items()],
        attributes=proto.attribute,
        doc_string=proto.doc_string)


def _compress_model_proto(proto, verbose):
    # Compresses the main graph and every local function of a
    # ModelProto; returns *proto* itself when nothing was modified.
    modified = 0
    new_graph = compress_proto(proto.graph, verbose=verbose)
    if id(new_graph) != id(proto.graph):
        modified += 1
    fcts = []
    for f in proto.functions:
        new_f = compress_proto(f, verbose=verbose)
        if id(new_f) != id(f):
            modified += 1
        fcts.append(new_f)
    if modified == 0:
        return proto
    # The fused Expression nodes live in domain 'mlprodict'.
    opsets = {op.domain: op.version for op in proto.opset_import}
    opsets['mlprodict'] = 1
    if verbose:
        logger.debug(  # pragma: no cover
            "Compressed model %s modified=%d.", proto.name, modified)
    return make_model(
        new_graph, functions=fcts,
        opset_imports=[
            make_operatorsetid(k, v) for k, v in opsets.items()],
        producer_name=proto.producer_name,
        producer_version=proto.producer_version,
        ir_version=proto.ir_version,
        doc_string=proto.doc_string,
        domain=proto.domain,
        model_version=proto.model_version)


def _compress_graph_proto(proto, verbose):
    # Compresses the nodes of a GraphProto; returns *proto* itself
    # when nothing changed so callers can detect it by identity.
    nodes = _compress_nodes(proto.node, verbose=verbose)
    if len(nodes) == len(proto.node):
        # unchanged
        return proto
    if verbose:
        logger.debug(  # pragma: no cover
            "Compressed graph %s from %d nodes to %d.",
            proto.name, len(proto.node), len(nodes))
    return make_graph(
        nodes, proto.name, proto.input, proto.output,
        proto.initializer, sparse_initializer=proto.sparse_initializer)