Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py: 99%
110 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import copy
6import hashlib
7import logging
8from onnx import FunctionProto
9from onnx.helper import make_graph, make_function
10from ._onnx_optimisation_common import ( # pylint: disable=E0611
11 _rename_node_input,
12 _rename_node_output,
13 _apply_optimisation_on_graph,
14 _apply_remove_node_fct_node)
# Module-level logger shared by the optimisation helpers in this file.
logger = logging.getLogger('onnx:optim')
def _hash_obj_content(obj, max_size=1000):
    """
    Hashes the content of an object with SHA-256.

    Two kinds of objects are supported:

    * an operator (anything exposing ``op_type``): the hash covers the
      operator type, the number of outputs, the input names and,
      recursively, every attribute;
    * anything else (typically an initializer): the object is serialized
      with its ``name`` and ``doc_string`` blanked out so that two
      tensors with identical content but different names hash equally.

    @param obj object to hash (ONNX node, attribute or initializer)
    @param max_size maximum number of bytes kept from the digest
    @return digest truncated to *max_size* bytes
    """
    digestor = hashlib.sha256()
    if hasattr(obj, 'op_type'):
        # An operator: hash its type, output arity and input names.
        digestor.update(obj.op_type.encode('ascii'))
        digestor.update(len(obj.output).to_bytes(8, byteorder='big'))
        for input_name in obj.input:
            digestor.update(input_name.encode('ascii'))
        if hasattr(obj, 'attribute'):
            # Attributes hash through the serialization branch below
            # (they expose no op_type), with their name mixed in first.
            for att in obj.attribute:
                digestor.update(att.name.encode('ascii'))
                digestor.update(_hash_obj_content(att))
    else:
        # An initializer (or attribute): serialize a copy whose
        # identifying fields are cleared so only the content matters.
        stripped = copy.deepcopy(obj)
        stripped.name = ""
        stripped.doc_string = ""
        digestor.update(stripped.SerializeToString())

    # Slicing never fails even when the digest is shorter than max_size.
    return digestor.digest()[:max_size]
def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type and parameters.

    @param onnx_model onnx model (ModelProto, GraphProto or FunctionProto)
    @param recursive looks into subgraphs
    @param debug_info debug information (private)
    @param max_hash_size limit the size of a hash used to detect
        identical subgraphs
    @param options additional options (unused)
    @return new onnx _model
    """
    # Track the chain of types traversed so far for debugging purposes.
    if debug_info is None:
        debug_info = [str(type(onnx_model)).rsplit(
            '.', maxsplit=1)[-1].strip("'>")]
    else:
        debug_info = (debug_info +
                      [str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")])

    # A ModelProto wraps a GraphProto: delegate so this function is
    # re-invoked on the inner graph and the model is rebuilt around it.
    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    def _enumerate_rename_list_nodes_inputs(nodes, rename):
        # Yields (was_renamed, index, node) for every node; None nodes
        # (already removed) pass through untouched.
        for i, node in enumerate(nodes):
            if node is None:
                yield False, i, None
                continue
            if any(set(node.input) & set(rename)):
                yield True, i, _rename_node_input(node, rename)
                continue
            yield False, i, node

    graph = onnx_model
    logger.debug("onnx_remove_node_redundant:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)

    # Detects duplicated initializers.
    # Identical content (ignoring name/doc_string) maps to the first
    # initializer seen; later duplicates are dropped and their names
    # redirected through `rename`.
    hashes = {}
    names = []
    rename = {}
    if is_function:
        # FunctionProto has no initializer list.
        new_inits = []
    else:
        for init in graph.initializer:
            hs = _hash_obj_content(init, max_size=max_hash_size)
            if hs in hashes:
                # Already seen.
                rename[init.name] = hashes[hs]  # pragma: no cover
            else:
                # New.
                hashes[hs] = init.name
                names.append(init.name)
        new_inits = [init for init in graph.initializer
                     if init.name in set(names)]

    # Renames node inputs.
    new_nodes = []
    new_nodes = list(graph.node)
    new_nodes = list(
        _[2] for _ in _enumerate_rename_list_nodes_inputs(new_nodes, rename))

    # Detects duplicated operators.
    if is_function:
        # FunctionProto.output is already a list of names.
        graph_outputs = set(graph.output)
    else:
        graph_outputs = set(o.name for o in graph.output)
    node_hashes = {}
    changed = 1
    replace = {}
    # Fix-point loop: removing one duplicate can make other nodes
    # identical (their inputs get renamed), so iterate until stable.
    while changed > 0:
        changed = 0
        nnodes = len(new_nodes)
        for i in range(nnodes):
            if i in replace:
                # Already removed.
                continue
            node = new_nodes[i]
            # NOTE(review): `hash` shadows the builtin here.
            hash = _hash_obj_content(node, max_size=max_hash_size)
            if hash in node_hashes:
                ni = node_hashes[hash]
                if ni == i:
                    continue
                # Node i duplicates node ni: keep ni, drop i.
                replace[i] = ni
                changed += 1

                # Specifies what to rename.
                # One exception: the output is one of the graph output.
                rep = new_nodes[ni]
                for old, nn in zip(node.output, rep.output):
                    if old in graph_outputs:
                        # Keep the graph-level output name: rename the
                        # survivor's output instead of the duplicate's.
                        rename[nn] = old
                        new_nodes[ni] = _rename_node_output(
                            new_nodes[ni], nn, old)
                    else:
                        rename[old] = nn

                # Renames inputs.
                # NOTE(review): this loop clobbers `changed` (with the
                # last yielded boolean) and `node`; the outer while-loop
                # condition therefore depends on the final flag rather
                # than the removal count — confirm this is intended.
                new_new_nodes = []
                renew_index = set()
                for changed, ci, node in _enumerate_rename_list_nodes_inputs(new_nodes, rename):
                    if changed:
                        renew_index.add(ci)
                    new_new_nodes.append(node)
                new_nodes = new_new_nodes

                # Renews hashes.
                # Renamed nodes must be re-hashed on the next pass.
                renew_hash = set(
                    k for k, v in node_hashes.items() if v in renew_index)
                for hs in renew_hash:
                    del node_hashes[hs]
                # Mark the duplicate as removed (filtered out below).
                new_nodes[i] = None
            else:
                node_hashes[hash] = i

    if recursive:
        # Handles subgraphs.
        # Nodes with attributes may carry subgraphs (If, Loop, Scan...);
        # apply the same optimisation inside them.
        for i in range(len(new_nodes)):  # pylint: disable=C0200
            node = new_nodes[i]
            if node is None or not (node.attribute):  # pylint: disable=C0325
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = list(filter(lambda n: n is not None, new_nodes))
    if is_function:
        logger.debug("onnx_remove_node_redundant:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits)
    # make_graph does not copy value_info; carry it over explicitly.
    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    logger.debug("onnx_remove_node_redundant:end graph with %d nodes.",
                 len(nodes))
    return graph