Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_unused.py: 97%

75 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import logging 

6from onnx import FunctionProto, GraphProto 

7from onnx.helper import make_graph, make_function 

8from ._onnx_optimisation_common import ( # pylint: disable=E0611 

9 _apply_optimisation_on_graph, _apply_remove_node_fct_node) 

10 

11 

# Logger of this module, registered under the explicit name 'onnx:optim'
# (not ``__name__``); onnx_remove_node_unused reports its node counts
# through it at DEBUG level.
logger = logging.getLogger(name='onnx:optim')

13 

14 

def _process_node(node, data, edges, paths, prefix="", sep="::", path=None):
    """
    Registers one node, and recursively the nodes of its subgraphs,
    into the dependency structures used to detect unused nodes.

    Keys are tuples ``(name, kind)``. *kind* is ``1`` for a node, whose
    name is prefixed by the names of its enclosing nodes. *kind* is
    ``0`` for a result (input or output name, never prefixed).

    :param node: node to process (it must expose ``name``, ``input``,
        ``output`` and ``attribute``)
    :param data: dictionary updated in place, it maps every key
        to the last node which registered it
    :param edges: dictionary updated in place, a key ``(a, b)`` means
        *b* needs *a* (an input is needed by the node consuming it,
        a node is needed by the results it produces)
    :param paths: dictionary updated in place. For a node it stores
        the names of its enclosing nodes. For an input it stores those
        names followed by the consuming node's name. For an output it
        stores the producing node's name.
    :param prefix: prefix added to the node name, empty for a node
        of the main graph
    :param sep: separator between an enclosing node name and the name
        of a node inside one of its subgraphs
    :param path: names of the enclosing nodes, outermost first,
        None for a node of the main graph
    """
    node_name = prefix + node.name
    data[node_name, 1] = node
    # Names of the enclosing nodes, outermost first (empty at top level).
    ancestors = [] if path is None else list(path)
    paths[node_name, 1] = ancestors
    path = ancestors + [node_name]

    for inp in node.input:
        data[inp, 0] = node
        edges[(inp, 0), (node_name, 1)] = node
        paths[inp, 0] = path
        # A node inside a subgraph may consume a result defined in an
        # outer scope: that result must be kept whenever an enclosing
        # node is kept. It is linked to every ancestor, not only the
        # direct parent, so that nested subgraphs also reach the
        # top-level node. Relying on the path (and not on a separator
        # found in the name) keeps top-level nodes whose name contains
        # the separator working, and honours a custom *sep*.
        for parent_name in ancestors:
            edges[(inp, 0), (parent_name, 1)] = node

    for out in node.output:
        data[out, 0] = node
        paths[out, 0] = node_name
        edges[(node_name, 1), (out, 0)] = node

    for att in node.attribute:
        subgraph = getattr(att, 'g', None)
        if not isinstance(subgraph, GraphProto):
            continue
        for sub_node in subgraph.node:
            _process_node(sub_node, data, edges, paths,
                          prefix=node_name + sep, path=path)

48 

49 

def _backward_closure(edges, roots):
    """
    Returns the set of keys from which one of *roots* can be reached.

    An edge ``(a, b)`` means *b* needs *a*: whenever *b* is needed,
    *a* is needed as well. The traversal visits every edge once.

    :param edges: iterable of pairs ``(source, target)``
    :param roots: keys known to be needed
    :return: set of needed keys, *roots* included
    """
    # Reverse adjacency: target -> sources it needs.
    needs = {}
    for source, target in edges:
        needs.setdefault(target, []).append(source)
    valid = set(roots)
    stack = list(valid)
    while stack:
        key = stack.pop()
        for source in needs.get(key, ()):
            if source not in valid:
                valid.add(source)
                stack.append(source)
    return valid


def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    :param onnx_model: onnx model, graph or function
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: unused
    :return: new onnx _model
    """
    # Type names of the visited protos. A new list is built, so the
    # caller's list is never modified.
    debug_info = (([] if debug_info is None else list(debug_info)) +
                  [type(onnx_model).__name__])

    if hasattr(onnx_model, 'graph'):
        # Delegates to the common helper, which receives this function
        # as the optimisation to apply.
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    graph = onnx_model
    logger.debug("onnx_remove_node_unused:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)
    data = {}
    edges = {}
    paths = {}

    if not is_function:
        for init in graph.initializer:
            data[init.name, 0] = init

    for node in graph.node:
        _process_node(node, data, edges, paths)

    # FunctionProto outputs are plain names, GraphProto outputs are
    # ValueInfoProto.
    roots = [(out if is_function else out.name, 0) for out in graph.output]
    # Everything the outputs depend on, computed in a single traversal
    # rather than by rescanning every edge until nothing changes.
    valid = _backward_closure(edges, roots)

    new_nodes = [n for n in graph.node if (n.name, 1) in valid]

    if recursive:
        # Handles subgraphs: only nodes holding attributes may have some.
        new_nodes = [
            _apply_remove_node_fct_node(
                onnx_remove_node_unused, node, recursive=True,
                debug_info=debug_info + [node.name])
            if node.attribute else node
            for node in new_nodes]

    # Finally create the new graph.
    nodes = [n for n in new_nodes if n is not None]
    if is_function:
        logger.debug("onnx_remove_node_unused:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    new_inits = [i for i in graph.initializer if (i.name, 0) in valid]
    new_graph = make_graph(nodes, onnx_model.name,
                           onnx_model.input, onnx_model.output,
                           new_inits, doc_string=onnx_model.doc_string)
    new_graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    # Sparse initializers are named after their values. Keep the ones
    # still consumed, otherwise the kept nodes would reference missing
    # names.
    new_graph.sparse_initializer.extend(  # pylint: disable=E1101
        [s for s in onnx_model.sparse_initializer
         if (s.values.name, 0) in valid])
    logger.debug("onnx_remove_node_unused:end graph with %d nodes.",
                 len(nodes))
    return new_graph