Coverage for mlprodict/onnx_tools/optim/_onnx_optimisation_common.py: 93%

122 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Common functions to reduce the number of 

4nodes of an :epkg:`ONNX` graphs. 

5""" 

6from onnx.helper import make_graph, make_model, make_attribute 

7from onnx import AttributeProto, NodeProto, ValueInfoProto 

8 

9 

10def _apply_optimisation_on_graph(fct, onnx_model, recursive=True, debug_info=None, 

11 **kwargs): 

12 """ 

13 Applies an optimisation function *fct* on a graph 

14 and not on the model. 

15 

16 @param fct function to optimize like 

17 @see fn onnx_remove_node_identity 

18 @param onnx_model onnx model 

19 @param recursive looks into subgraphs 

20 @param debug_info debug information (private) 

21 @param kwargs additional parameters 

22 @return new onnx _model 

23 """ 

24 if hasattr(onnx_model, 'graph'): 

25 if debug_info is None: 

26 debug_info = [] 

27 graph = fct( 

28 onnx_model.graph, debug_info=debug_info + ['GRAPH'], 

29 **kwargs) 

30 new_model = make_model(graph, functions=onnx_model.functions) 

31 new_model.ir_version = onnx_model.ir_version 

32 new_model.producer_name = onnx_model.producer_name 

33 new_model.producer_version = onnx_model.producer_version 

34 new_model.domain = onnx_model.domain 

35 new_model.model_version = onnx_model.model_version 

36 new_model.doc_string = onnx_model.doc_string 

37 if hasattr(onnx_model, 'value_info'): 

38 graph.value_info.extend(onnx_model.value_info) # pragma: no cover 

39 while len(new_model.opset_import) > 0: # pylint: disable=E1101 

40 new_model.opset_import.pop() # pylint: disable=E1101 

41 for oimp in onnx_model.opset_import: 

42 op_set = new_model.opset_import.add() # pylint: disable=E1101 

43 op_set.domain = oimp.domain 

44 op_set.version = oimp.version 

45 return new_model 

46 raise TypeError( # pragma: no cover 

47 f"This function only works on 'ModelProto' anod not not on {type(onnx_model)}.") 

48 

49 

50def _apply_remove_node_fct_node(fct, node, recursive, debug_info): 

51 """ 

52 Applies an optimizing function on a subgraphs. 

53 

54 @param node onnx node 

55 @param recursive does it in subgraphs as well 

56 @return new node 

57 """ 

58 if not hasattr(node, 'attribute'): 

59 return node # pragma: no cover 

60 modified = 0 

61 new_atts = [] 

62 for att in node.attribute: 

63 if att.name in ('body', 'then_branch', 'else_branch'): 

64 new_body = fct( 

65 att.g, recursive=recursive, 

66 debug_info=debug_info + [att.name]) 

67 new_atts.append(_make_att_graph(att.name, new_body)) 

68 modified += 1 

69 else: 

70 new_atts.append(att) 

71 if modified > 0: 

72 new_node = _make_node(node.op_type, node.input, 

73 node.output, name=node.name, 

74 attributes=new_atts) 

75 return new_node 

76 return node 

77 

78 

def _make_node(op_type, inputs, outputs, name=None, doc_string=None,
               domain=None, attributes=None):
    """
    Builds a NodeProto from its components.

    :param op_type: name of the operator
    :param inputs: names of the node inputs
    :param outputs: names of the node outputs
    :param name: node name, left unset when empty or None
    :param doc_string: node documentation, left unset when empty or None
    :param domain: operator domain, left unset when None
        (NodeProto then keeps its default, the empty string)
    :param attributes: either a dictionary ``{name: value}`` whose items
        are converted with :epkg:`make_attribute` in key order, or an
        iterable of attributes appended as they are
    :return: the new NodeProto
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)  # pylint: disable=E1101
    proto.output.extend(outputs)  # pylint: disable=E1101
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string  # pragma: no cover
    if domain is not None:
        proto.domain = domain
    if not attributes:
        # None, an empty dictionary or an empty sequence: nothing to add.
        return proto
    if isinstance(attributes, dict):
        proto.attribute.extend(  # pylint: disable=E1101
            make_attribute(key, attributes[key]) for key in sorted(attributes))
    else:
        proto.attribute.extend(attributes)  # pylint: disable=E1101
    return proto

115 

116 

117def _replace(name, old_name, new_name): 

118 if isinstance(old_name, dict) and new_name is None: 

119 return old_name.get(name, name) 

120 if name == old_name: 

121 return new_name 

122 return name 

123 

124 

def _rename_node_input(onnx_node, old_name, new_name=None):
    """
    Renames an input from a node. Graph attributes (subgraphs)
    are updated with @see fn _rename_graph_input for every renamed pair.

    @param      onnx_node       onnx_node
    @param      old_name        old name, or a dictionary ``{old: new}``
                                if *new_name* is None
    @param      new_name        new name or None if *old_name* is a dictionary
    @return                     new node
    """
    inputs = [_replace(name, old_name, new_name) for name in onnx_node.input]
    outputs = list(onnx_node.output)
    # _rename_graph_input handles one name at a time: the dictionary mode
    # is expanded into pairs, and no-op renames are skipped because they
    # would add an Identity node whose input and output are the same name.
    if isinstance(old_name, dict) and new_name is None:
        renames = [(old, new) for old, new in old_name.items() if old != new]
    elif old_name != new_name:
        renames = [(old_name, new_name)]
    else:
        renames = []
    if hasattr(onnx_node, 'attribute'):
        atts = []
        for att in onnx_node.attribute:
            if not renames or att.type != AttributeProto.GRAPH:  # pylint: disable=E1101
                atts.append(att)
                continue
            body = att.g
            for old, new in renames:
                body = _rename_graph_input(body, old, new)
            atts.append(_make_att_graph(att.name, body))
    else:
        atts = None  # pragma: no cover
    return _make_node(
        onnx_node.op_type, inputs, outputs, name=onnx_node.name,
        doc_string=onnx_node.doc_string, domain=onnx_node.domain,
        attributes=atts)

156 

157 

def _copy_value_info_proto(new_name, obj):
    """
    Copies a ValueInfoProto under a new name.

    @param      new_name    name of the copy
    @param      obj         ValueInfoProto to copy (its type and doc_string)
    @return                 new ValueInfoProto
    """
    value_info = ValueInfoProto()
    value_info.name = new_name
    value_info.type.CopyFrom(obj.type)  # pylint: disable=E1101
    # doc_string is a field of ValueInfoProto itself, not of its TypeProto.
    if obj.doc_string:
        value_info.doc_string = obj.doc_string
    return value_info

165 

166 

def _rename_graph_output(graph, old_name, new_name):
    """
    Renames an output and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    name of the output to rename
    @param      new_name    new name of this output
    @return                 modified graph (a new GraphProto); the nodes are
                            unchanged if no output is named *old_name*
    """
    outputs = []
    renamed = False
    for o in graph.output:
        if o.name == old_name:
            outputs.append(_copy_value_info_proto(new_name, o))
            renamed = True
        else:
            outputs.append(o)
    nodes = list(graph.node)
    if renamed:
        # old_name is still computed inside the graph; the Identity exposes
        # it under the new name. Appended last so the node list stays
        # topologically sorted. Without a renamed output, such a node would
        # read a name which may not exist in this graph.
        nodes.append(_make_node('Identity', [old_name], [new_name]))
    return make_graph(nodes, graph.name, graph.input, outputs,
                      graph.initializer, doc_string=graph.doc_string,
                      value_info=graph.value_info,
                      sparse_initializer=graph.sparse_initializer)

187 

188 

def _rename_graph_input(graph, old_name, new_name):
    """
    Renames an input and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    name of the input to rename
    @param      new_name    new name of this input
    @return                 modified graph (a new GraphProto)
    """
    inputs = []
    for i in graph.input:
        if i.name == old_name:
            inputs.append(_copy_value_info_proto(new_name, i))
        else:
            inputs.append(i)
    # The Identity node recreates old_name from new_name (the renamed input,
    # or presumably a value of the enclosing scope when no input matches).
    # Existing nodes consume old_name, so it must come first to keep the
    # node list topologically sorted.
    nodes = [_make_node('Identity', [new_name], [old_name])]
    nodes.extend(graph.node)
    return make_graph(nodes, graph.name, inputs, graph.output,
                      graph.initializer, doc_string=graph.doc_string,
                      value_info=graph.value_info,
                      sparse_initializer=graph.sparse_initializer)

209 

210 

def _make_att_graph(name, new_body):
    """
    Wraps a graph into a graph attribute.

    @param      name        attribute name
    @param      new_body    GraphProto copied into the attribute
    @return                 AttributeProto of type GRAPH
    """
    graph_att = AttributeProto()
    graph_att.name = name
    graph_att.type = AttributeProto.GRAPH  # pylint: disable=E1101
    graph_att.g.CopyFrom(new_body)  # pylint: disable=E1101
    return graph_att

217 

218 

def _rename_node_output(onnx_node, old_name, new_name):
    """
    Renames an output from a node. Graph attributes (subgraphs)
    are updated with @see fn _rename_graph_output for every renamed pair.

    @param      onnx_node       onnx_node
    @param      old_name        old name, or a dictionary ``{old: new}``
                                if *new_name* is None
    @param      new_name        new name or None if *old_name* is a dictionary
    @return                     new node
    """
    inputs = list(onnx_node.input)
    outputs = [_replace(name, old_name, new_name) for name in onnx_node.output]
    # _rename_graph_output handles one name at a time: the dictionary mode
    # is expanded into pairs, and no-op renames are skipped.
    if isinstance(old_name, dict) and new_name is None:
        renames = [(old, new) for old, new in old_name.items() if old != new]
    elif old_name != new_name:
        renames = [(old_name, new_name)]
    else:
        renames = []
    if hasattr(onnx_node, 'attribute'):
        atts = []
        for att in onnx_node.attribute:
            if not renames or att.type != AttributeProto.GRAPH:  # pylint: disable=E1101
                atts.append(att)
                continue
            body = att.g
            for old, new in renames:
                body = _rename_graph_output(body, old, new)
            atts.append(_make_att_graph(att.name, body))
    else:
        atts = None  # pragma: no cover
    return _make_node(
        onnx_node.op_type, inputs, outputs, name=onnx_node.name,
        doc_string=onnx_node.doc_string, domain=onnx_node.domain,
        attributes=atts)