Coverage for mlprodict/onnx_tools/compress.py: 80%

122 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Functions to simplify, compress an ONNX graph. 

4 

5.. versionadded:: 0.9 

6""" 

7import logging 

8from onnx import ModelProto, GraphProto, FunctionProto 

9from onnx.helper import ( 

10 make_function, make_model, make_value_info, make_graph, 

11 make_tensor_type_proto, make_node, make_operatorsetid) 

12 

13 

14logger = logging.getLogger('onnx:compress') 

15 

16 

17def _check_expression(expe): 

18 att = expe.attribute[0].g 

19 inputs = [i.name for i in att.input] 

20 if list(expe.input) != inputs: 

21 raise RuntimeError( # pragma: no cover 

22 f'Name mismatch in node Expression {expe.input!r} != {inputs!r}.') 

23 outputs = [o.name for o in att.output] 

24 if list(expe.output) != outputs: 

25 raise RuntimeError( # pragma: no cover 

26 f'Name mismatch in node Expression {expe.input!r} != {inputs!r}.') 

27 

28 

29def _fuse_node(o, node, node_next): 

30 """ 

31 Merges two nodes having one input/output in common. 

32 

33 :param o: output name 

34 :param node: first node (it outputs the results) 

35 :param node_next: second node (it ingests the result) 

36 :return: merged node 

37 """ 

38 type_expression = ('mlprodict', 'Expression') 

39 if list(node.output) != [o]: 

40 raise RuntimeError( # pragma: no cover 

41 f"The only output of the first node should be {[o]!r} not {node.output!r}.") 

42 cannot_do = {('', 'If'), ('', 'Loop'), ('', 'Scan')} 

43 key1 = node.domain, node.op_type 

44 if key1 in cannot_do: 

45 return None 

46 key2 = node_next.domain, node_next.op_type 

47 if key2 in cannot_do: 

48 return None 

49 

50 if key1 == type_expression: 

51 _check_expression(node) 

52 if key2 == type_expression: 

53 _check_expression(node_next) 

54 

55 graph = None 

56 

57 if node.domain == '' and node_next.domain == '': 

58 # Simple case 

59 inputs = [make_value_info(name, make_tensor_type_proto(0, [])) 

60 for name in node.input] 

61 outputs = [make_value_info(name, make_tensor_type_proto(0, [])) 

62 for name in node_next.output] 

63 graph = make_graph([node, node_next], "expression", inputs, outputs) 

64 

65 elif key1 == type_expression and node_next.domain == '': 

66 att = node.attribute[0].g 

67 inputs = att.input 

68 outputs = [make_value_info(name, make_tensor_type_proto(0, [])) 

69 for name in node_next.output] 

70 graph = make_graph(list(att.node) + [node_next], 

71 "expression", inputs, outputs) 

72 

73 elif node.domain == '' and key2 == type_expression: 

74 att = node_next.attribute[0].g 

75 inputs = [make_value_info(name, make_tensor_type_proto(0, [])) 

76 for name in node.input] 

77 outputs = att.output 

78 graph = make_graph([node] + list(att.node), 

79 "expression", inputs, outputs) 

80 

81 elif key1 == type_expression and key2 == type_expression: 

82 att1 = node.attribute[0].g 

83 att2 = node_next.attribute[0].g 

84 inputs = att1.input 

85 outputs = att2.output 

86 graph = make_graph(list(att1.node) + list(att2.node), 

87 "expression", inputs, outputs) 

88 

89 if graph is not None: 

90 new_node = make_node( 

91 'Expression', node.input, node_next.output, domain='mlprodict', 

92 expression=graph) 

93 return new_node 

94 

95 raise NotImplementedError( # pragma: no cover 

96 "Unable to merge nodes '%s/%s' and '%s/%s'." % ( 

97 node.domain, node.op_type, node_next.domain, node_next.op_type)) 

98 

99 

100def _compress_nodes_once(nodes, verbose=0): 

101 """ 

102 Compresses a sequence of node to make it more 

103 readable. If possible, it creates a node `Expression` 

104 with a graph as an attribute. 

105 

106 :param nodes: sequence of nodes to compress 

107 :return: compressed sequence of nodes 

108 """ 

109 # check that a result is used only once 

110 order = {} 

111 results = {} 

112 for node in list(nodes): 

113 order[id(node)] = (len(order), node) 

114 for name in node.input: 

115 if name in results: 

116 results[name] += 1 

117 else: 

118 results[name] = 1 

119 

120 once = {k: v for k, v in results.items() if v == 1} 

121 if len(once) == 0: 

122 return nodes 

123 

124 once_nodes_o = {} 

125 once_nodes_i = {} 

126 for node in nodes: 

127 if len(node.output) != 1: 

128 continue 

129 for o in node.output: 

130 if o in once: 

131 once_nodes_o[o] = node 

132 for i in node.input: 

133 if i in once: 

134 once_nodes_i[i] = node 

135 

136 if len(once_nodes_o) == 0: 

137 return nodes 

138 

139 if verbose > 0: 

140 logger.debug( 

141 "Results to compress: %r", list(sorted(once_nodes_o))) 

142 

143 while len(once_nodes_o) > 0: 

144 o, node = once_nodes_o.popitem() 

145 node_next = once_nodes_i[o] 

146 new_node = _fuse_node(o, node, node_next) 

147 if new_node is None: 

148 # nothing can be done 

149 continue 

150 once_nodes_o.update({o: new_node for o in node_next.output 

151 if o in once_nodes_o}) 

152 once_nodes_i.update({i: new_node for i in node.input 

153 if i in once_nodes_i}) 

154 order[id(new_node)] = (order[id(node)][0], new_node) 

155 del order[id(node)] 

156 del order[id(node_next)] 

157 

158 ordered = list(sorted((v[0], k, v[1]) for k, v in order.items())) 

159 return [v[-1] for v in ordered] 

160 

161 

def _compress_nodes(nodes, verbose=0):
    """
    Compresses a sequence of node to make it more
    readable. If possible, it creates a node `Expression`
    with a graph as an attribute.

    :param nodes: sequence of nodes to compress
    :param verbose: forwarded to :func:`_compress_nodes_once`
    :return: compressed sequence of nodes
    """
    # Currently a single pass; kept as a separate entry point so that
    # repeated passes can be added without changing callers.
    compressed = _compress_nodes_once(nodes, verbose=verbose)
    return compressed

172 

173 

def compress_proto(proto, verbose=0):
    """
    Compresses a :epkg:`ModelProto`, :epkg:`FunctionProto`,
    :epkg:`GraphProto`. The function detects nodes outputting
    results only used once. It then fuses it with the node
    taking it as an input.

    :param proto: :epkg:`ModelProto`, :epkg:`FunctionProto`,
        :epkg:`GraphProto`
    :param verbose: logging
    :return: same type

    .. versionadded:: 0.9
    """
    if isinstance(proto, ModelProto):
        # Compress the main graph and every local function; a different
        # object identity signals that a compression happened.
        new_graph = compress_proto(proto.graph, verbose=verbose)
        fcts = [compress_proto(f, verbose=verbose) for f in proto.functions]
        modified = int(new_graph is not proto.graph) + sum(
            int(new_f is not f) for new_f, f in zip(fcts, proto.functions))
        if modified == 0:
            return proto
        # The fused Expression nodes live in domain 'mlprodict'.
        opsets = {op.domain: op.version for op in proto.opset_import}
        opsets['mlprodict'] = 1
        if verbose:
            logger.debug(  # pragma: no cover
                "Compressed model %s modified=%d.", proto.name, modified)
        return make_model(
            new_graph, functions=fcts,
            opset_imports=[
                make_operatorsetid(dom, ver) for dom, ver in opsets.items()],
            producer_name=proto.producer_name,
            producer_version=proto.producer_version,
            ir_version=proto.ir_version,
            doc_string=proto.doc_string,
            domain=proto.domain,
            model_version=proto.model_version)

    if isinstance(proto, FunctionProto):
        new_nodes = _compress_nodes(proto.node, verbose=verbose)
        if len(new_nodes) == len(proto.node):
            # unchanged
            return proto
        if verbose:
            logger.debug(  # pragma: no cover
                "Compressed function %r/%r from %d nodes to %d.",
                proto.domain, proto.name, len(proto.node), len(new_nodes))
        opsets = {op.domain: op.version for op in proto.opset_import}
        opsets['mlprodict'] = 1
        return make_function(
            proto.domain, proto.name,
            proto.input, proto.output, new_nodes,
            opset_imports=[
                make_operatorsetid(dom, ver) for dom, ver in opsets.items()],
            attributes=proto.attribute,
            doc_string=proto.doc_string)

    if isinstance(proto, GraphProto):
        new_nodes = _compress_nodes(proto.node, verbose=verbose)
        if len(new_nodes) == len(proto.node):
            # unchanged
            return proto
        if verbose:
            logger.debug(  # pragma: no cover
                "Compressed graph %s from %d nodes to %d.",
                proto.name, len(proto.node), len(new_nodes))
        # NOTE(review): value_info and doc_string of the original graph
        # are not propagated here — confirm this is intentional.
        return make_graph(
            new_nodes, proto.name, proto.input, proto.output,
            proto.initializer, sparse_initializer=proto.sparse_initializer)

    raise TypeError(  # pragma: no cover
        "Unexpected type for proto %r, it should ModelProto, "
        "GraphProto or FunctionProto." % type(proto))