Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_identity.py: 100%

112 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import logging 

6from onnx import FunctionProto, AttributeProto 

7from onnx.helper import make_graph, make_function 

8from ._onnx_optimisation_common import ( # pylint: disable=E0611 

9 _rename_node_input, 

10 _rename_node_output, 

11 _apply_optimisation_on_graph, 

12 _apply_remove_node_fct_node) 

13 

14 

# Logger used by this module to trace the identity-removal pass at DEBUG level.
logger = logging.getLogger(name='onnx:optim')

16 

17 

def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node, and into every subgraph if
    *recursive* is True, for identity nodes. Unless such a node directly
    connects an input (or an initializer) to an output, it is removed
    and the other nodes get their inputs or outputs renamed accordingly.
    Names read by subgraphs from the enclosing scope are never renamed.

    :param onnx_model: onnx model, graph (*GraphProto*) or function
        (*FunctionProto*)
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private), a list extended
        with the type name of *onnx_model*
    :param options: additional options (unused here, only forwarded
        when *onnx_model* has a ``graph`` attribute)
    :return: a new *FunctionProto* for a function, a new *GraphProto*
        for a graph, otherwise the result of
        ``_apply_optimisation_on_graph`` (presumably a new model)
    """
    # e.g. "<class 'onnx.onnx_ml_pb2.GraphProto'>" -> "GraphProto"
    type_name = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    debug_info = [type_name] if debug_info is None else debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        # A model: the common helper applies this function to its content.
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model
    logger.debug("onnx_remove_node_identity:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)

    if is_function:
        # Function inputs/outputs are plain names and a function has no
        # initializers. inputs_inits must still be defined, otherwise the
        # main loop fails on the first Identity node of a function.
        inputs = set(graph.input)
        outputs = set(graph.output)
        inputs_inits = inputs
    else:
        inputs = set(i.name for i in graph.input)
        inits = set(i.name for i in graph.initializer)
        inputs_inits = inputs | inits
        outputs = set(o.name for o in graph.output)

    def retrieve_idnodes(existing_nodes):
        """
        Returns ``(index, node, input name, output name)`` for every
        Identity node still present (not None) in *existing_nodes*.
        """
        return [(i, node, node.input[0], node.output[0])
                for i, node in enumerate(existing_nodes)
                if node is not None and node.op_type == 'Identity']

    def collect_subgraph_names(sub, subgraph=True):
        """
        Returns the names consumed by nodes of *sub* (only if *subgraph*
        is True) and of its nested subgraphs which are not produced
        inside the (sub)graph they are read from.
        """
        # Only the (sub)graph's own inputs/initializers and the outputs
        # of its previous nodes are considered known. Names coming from
        # the enclosing scope are not, so every outer name a subgraph
        # reads is reported. NOTE(review): presumably this is required
        # because the renaming helpers do not rewrite subgraphs.
        if isinstance(sub, FunctionProto):
            known = set(sub.input)
        else:
            known = set(i.name for i in sub.input)
            known |= set(i.name for i in sub.initializer)
        local_var = set()
        for node in sub.node:
            if subgraph:
                local_var.update(name for name in node.input
                                 if name not in known)
            known.update(node.output)
            for att in node.attribute:
                if (att.type == AttributeProto.GRAPH and  # pylint: disable=E1101
                        hasattr(att, 'g') and att.g is not None):
                    local_var |= collect_subgraph_names(att.g)
        return local_var

    local_vars = collect_subgraph_names(graph, subgraph=False)
    logger.debug('onnx_remove_node_identity:local_vars:%r', local_vars)
    # Names which must keep their spelling: the outputs of this graph and
    # the names its subgraphs read from this scope.
    ext_outputs = outputs | local_vars

    nodes = list(graph.node)
    rem = 1
    while rem > 0:
        rem = 0
        restart = False
        for i, _, inp, out in retrieve_idnodes(nodes):
            if restart:
                # An Identity node was renamed: the cached names are stale,
                # the list of identity nodes is recomputed.
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover
            if inp in inputs_inits and out in ext_outputs:
                # Cannot be removed: it connects an input to an output.
                continue
            if out not in ext_outputs:
                # The name 'out' may disappear: its consumers read 'inp'.
                for j, node in enumerate(nodes):
                    if node is None or out not in node.input:
                        continue
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'out=%r:inp=%r',
                                 node.op_type, node.input,
                                 node.output, out, inp)
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                logger.debug('onnx_remove_node_identity:1:remove:%s:%r->%r:',
                             nodes[i].op_type, nodes[i].input, nodes[i].output)
                nodes[i] = None
                rem += 1
                continue
            if inp not in inputs_inits and inp not in ext_outputs:
                # 'out' must be kept, 'inp' may disappear: its producer
                # writes 'out' directly and its consumers read 'out'.
                for j, node in enumerate(nodes):
                    if node is None:
                        continue
                    if inp in node.output:
                        logger.debug('onnx_remove_node_identity:'
                                     '_rename_node_output:%s:%r->%r:'
                                     'inp=%r:out=%r',
                                     node.op_type, node.input,
                                     node.output, inp, out)
                        node = _rename_node_output(node, inp, out)
                        nodes[j] = node
                        rem += 1
                        if node.op_type == 'Identity':
                            restart = True  # pragma: no cover
                    if inp in node.input:
                        logger.debug('onnx_remove_node_identity:'
                                     '_rename_node_input:%s:%r->%r:'
                                     'inp=%r:out=%r',
                                     node.op_type, node.input,
                                     node.output, inp, out)
                        node = _rename_node_input(node, inp, out)
                        nodes[j] = node
                        rem += 1
                        # Also true for the removed Identity node itself,
                        # whose input is 'inp'.
                        if node.op_type == 'Identity':
                            restart = True
                logger.debug('onnx_remove_node_identity:2:remove:%s:%r->%r:',
                             nodes[i].op_type, nodes[i].input, nodes[i].output)
                nodes[i] = None
                rem += 1

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    if not nodes:
        # something went wrong, the original nodes are kept
        nodes = list(graph.node)
    if is_function:
        logger.debug("onnx_remove_node_identity:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    # NOTE(review): only nodes, name, inputs, outputs, initializers and
    # value_info are copied into the new graph; confirm other GraphProto
    # fields (e.g. doc_string) are not needed.
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    logger.debug("onnx_remove_node_identity: end graph with %d nodes.",
                 len(nodes))
    return graph