Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_unused.py: 97%
75 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import logging
6from onnx import FunctionProto, GraphProto
7from onnx.helper import make_graph, make_function
8from ._onnx_optimisation_common import ( # pylint: disable=E0611
9 _apply_optimisation_on_graph, _apply_remove_node_fct_node)
12logger = logging.getLogger('onnx:optim')
15def _process_node(node, data, edges, paths, prefix="", sep="::", path=None):
16 node_name = prefix + node.name
17 data[node_name, 1] = node
18 path = [] if path is None else path.copy()
19 paths[node_name, 1] = path
20 path = path.copy()
21 path.append(node_name)
22 for inp in node.input:
23 data[inp, 0] = node
24 edges[(inp, 0), (node_name, 1)] = node
25 paths[inp, 0] = path
26 if '::' in node_name:
27 # We need to link an input to the parent node
28 # if the node is part of subgraph.
29 # path_r = paths[inp, 0]
30 if len(path) <= 1:
31 raise RuntimeError( # pragma: no cover
32 f"Unexpected path {path!r}.")
33 edges[(inp, 0), (path[-2], 1)] = node
35 for out in node.output:
36 data[out, 0] = node
37 paths[out, 0] = node_name
38 edges[(node_name, 1), (out, 0)] = node
39 if len(node.attribute) > 0:
40 for att in node.attribute:
41 if not hasattr(att, 'g'):
42 continue
43 if not isinstance(att.g, GraphProto):
44 continue
45 for no in att.g.node:
46 _process_node(no, data, edges, paths,
47 prefix=node_name + sep, path=path)
def _collect_used(edges, roots):
    """
    Returns every key from which one of *roots* can be reached by
    following *edges*.

    :param edges: iterable of pairs ``(e1, e2)`` meaning *e1* is needed
        to compute *e2*
    :param roots: keys known to be needed (the graph outputs)
    :return: set of needed keys, *roots* included
    """
    # Reverse adjacency: key -> keys it depends on. A single worklist pass
    # visits every edge once instead of rescanning all edges until nothing
    # changes (quadratic on long chains).
    requires = {}
    for e1, e2 in edges:
        requires.setdefault(e2, []).append(e1)
    used = set(roots)
    stack = list(used)
    while stack:
        key = stack.pop()
        for dep in requires.get(key, ()):
            if dep not in used:
                used.add(dep)
                stack.append(dep)
    return used


def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    :param onnx_model: onnx model: an object with a ``graph`` attribute,
        a GraphProto or a FunctionProto
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: forwarded to the graph-level dispatcher when
        *onnx_model* has a ``graph`` attribute, ignored otherwise
    :return: new onnx model: whatever the dispatcher returns for an object
        with a ``graph`` attribute, a new FunctionProto for a function,
        a new GraphProto otherwise

    NOTE(review): nodes are identified by name only, so nodes sharing a
    name (e.g. several unnamed nodes) share one validity flag and are
    kept or dropped together.
    """
    model_type = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    debug_info = [model_type] if debug_info is None else debug_info + [model_type]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    graph = onnx_model
    logger.debug("onnx_remove_node_unused:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)
    data = {}
    edges = {}
    paths = {}

    if not is_function:
        for init in graph.initializer:
            data[init.name, 0] = init

    for node in graph.node:
        _process_node(node, data, edges, paths)

    # Function outputs are plain names, graph outputs are ValueInfoProto.
    roots = [(out if is_function else out.name, 0) for out in graph.output]
    valid = _collect_used(edges, roots)

    new_nodes = [n for n in graph.node if (n.name, 1) in valid]
    if not is_function:
        new_inits = [n for n in graph.initializer if (n.name, 0) in valid]

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [n for n in new_nodes if n is not None]
    if is_function:
        logger.debug("onnx_remove_node_unused:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits, doc_string=onnx_model.doc_string)
    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    logger.debug("onnx_remove_node_unused:end graph with %d nodes.",
                 len(nodes))
    return graph