Coverage for mlprodict/onnx_tools/onnx_tools.py: 92%

138 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Functions to manipulate ONNX file. 

4""" 

5from onnx import helper, AttributeProto 

6 

7 

def find_node_name(model, name):
    """
    Returns the first node of *model* whose name equals *name*.

    :param model: onnx model, anything exposing ``model.graph.node``
    :param name: node name to look for
    :return: the first matching node, None when no node has that name
    :raises TypeError: if *model* has no ``graph`` attribute
    """
    if hasattr(model, "graph"):
        return next(
            (candidate for candidate in model.graph.node
             if candidate.name == name),
            None)
    raise TypeError(  # pragma: no cover
        f"Parameter model is not an ONNX model but {type(model)}")

22 

23 

def find_node_input_name(node, name):
    """
    Finds a node input by its name.

    :param node: onnx node (its ``input`` attribute is a sequence of
        result names)
    :param name: name of the input result to look for
    :return: index of the first input with that name, -1 if not found
    """
    # node.input holds plain strings (result names), not nodes.
    for i, input_name in enumerate(node.input):
        if input_name == name:
            return i
    return -1

35 

36 

def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
    """
    Inserts a node before one node input.

    The new node consumes the result currently plugged into input
    *input_index* of *node*. That input is then rewired to the output
    of the new node, and the nodes are re-sorted with
    :func:`ensure_topological_order`.

    :param model: onnx model (must expose ``model.graph``)
    :param op_type: operator type of the inserted node
    :param node: node or node name
    :param input_index: input index or input name
    :param new_name: name of the inserted node's output, defaults to
        ``op_type.lower()``; it is suffixed with ``_0``, ``_1``, ...
        until it does not collide with an existing name
    :param attrs: node attributes
    :return: updated model
    :raises RuntimeError: if *node* (given by name) or *input_index*
        (given by name) cannot be found

    .. note::
        *node* itself is modified in place (one of its inputs is renamed),
        so *model* is modified too when *node* belongs to it.
    """
    if isinstance(node, str):
        inode = find_node_name(model, node)
        if inode is None:
            raise RuntimeError(
                f"Unable to find node {node!r} in the model.")
    else:
        inode = node
    if isinstance(input_index, str):
        # Resolve the input on the actual node object, not on its name.
        input_index_ = find_node_input_name(inode, input_index)
        if input_index_ == -1:
            raise RuntimeError(  # pragma: no cover
                "Unable to find input_index %r in node %r." % (
                    input_index, inode.name))
        input_index = input_index_

    # Guess a new name that collides with no existing name, including
    # graph inputs, outputs, initializers and names used in subgraphs.
    names = set(enumerate_onnx_names(model))
    if new_name is None:
        new_name = op_type.lower()
    root_name = new_name
    i = 0
    while new_name in names:
        new_name = "%s_%d" % (root_name, i)
        i += 1

    new_node = helper.make_node(
        op_type, [inode.input[input_index]], [new_name], **attrs)
    # In-place rewiring: inode now consumes the output of new_node.
    inode.input[input_index] = new_name
    keep_nodes = list(model.graph.node)
    keep_nodes.append(new_node)
    keep_nodes = ensure_topological_order(
        model.graph.input, model.graph.initializer, keep_nodes)

    graph = helper.make_graph(
        keep_nodes, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # Replace the default opset imports set by make_model with the
    # original model's ones.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model

106 

107 

def ensure_topological_order(inputs, initializers, nodes):
    """
    Ensures and modifies the order of nodes to have
    a topological order (every node in the list
    can only be an input for a node later in this list).
    The function raises an exception if a cycle is detected.

    Every node receives a level that is higher than the level of all
    its inputs (graph inputs and initializers are level 0). Nodes are
    then sorted by level. Nodes sharing the same level keep their
    relative position in *nodes*. Empty names, which ONNX uses for
    omitted optional inputs or outputs, are ignored.

    :param inputs: graph inputs (objects with a ``name`` attribute)
    :param initializers: graph initializers (objects with a ``name``
        attribute)
    :param nodes: graph nodes
    :return: list of ordered nodes
    :raises RuntimeError: if two nodes produce the same output, or if
        some nodes cannot be ordered (cycle, or an input produced by no
        node and absent from *inputs* and *initializers*)
    """
    # Materialize once so the id()-based keys stay valid for the whole call.
    nodes = list(nodes)
    order = {}
    for value in inputs:
        order[value.name] = 0
    for value in initializers:
        order[value.name] = 0

    n_iter = 0
    # Initialized here so an empty node list does not leave them unbound.
    missing_names = set()
    missing_ops = []
    # Each pass orders every node whose inputs are all known; a chain of
    # nodes listed in reverse order needs one pass per node.
    while n_iter < len(nodes) * 2:
        n_iter += 1
        missing_names = set()
        missing_ops = []
        for node in nodes:
            maxi = 0
            for name in node.input:
                if not name:
                    # omitted optional input
                    continue
                if name not in order:
                    maxi = None
                    missing_names.add(name)
                    break
                maxi = max(maxi, order[name])
            if maxi is None:
                missing_ops.append(node)
                continue
            key = id(node)
            if key in order:
                continue
            maxi += 1
            order[key] = maxi
            maxi += 1
            for name in node.output:
                if not name:
                    # omitted optional output, may appear in many nodes
                    continue
                if name in order:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to sort a node (cycle). An output was "
                        "already ordered %r (iteration=%r)." % (
                            name, n_iter))
                order[name] = maxi
        if not missing_names:
            # Every node is ordered, further passes would change nothing.
            break

    if missing_ops:  # pragma: no cover
        def nstr(name):
            if name in order:
                return "%s#%d" % (name, order[name])
            return name

        def describe(n):
            return "%s(%s) -> [%s]" % (
                n.name or n.op_type,
                ', '.join(map(nstr, n.input)),
                ', '.join(n.output))

        rows = [""]
        rows.extend(describe(n) for n in missing_ops)
        rows.extend(["--", "--all-nodes--", "--"])
        rows.extend(describe(n) for n in nodes)
        raise RuntimeError(
            "After %d iterations for %d nodes, still unable "
            "to sort names %r. The graph may be disconnected. "
            "List of operators: %s" % (
                n_iter, len(nodes), missing_names,
                "\n".join(rows)))

    # Sort by level; sorted() is stable so ties keep their original order.
    positions = sorted(range(len(nodes)), key=lambda i: order[id(nodes[i])])
    return [nodes[i] for i in positions]

191 

192 

def enumerate_onnx_names(onx):
    """
    Enumerates all existing names in one ONNX graph
    (:epkg:`ModelProto`, :epkg:`FunctionProto`, :epkg:`GraphProto`).
    The function is recursive: names used in the subgraphs held by node
    attributes (a single graph or a list of graphs) are enumerated too.
    A name may be yielded several times.

    :param onx: one onnx object
    :return: iterator on names
    """
    if hasattr(onx, 'graph'):
        # ModelProto: inputs/outputs are ValueInfoProto, hence ``.name``
        yield from (i.name for i in onx.graph.initializer)
        yield from (i.name for i in onx.graph.input)
        yield from (i.name for i in onx.graph.output)
        nodes = onx.graph.node
    elif hasattr(onx, 'initializer'):
        # GraphProto
        yield from (i.name for i in onx.initializer)
        yield from (i.name for i in onx.input)
        yield from (i.name for i in onx.output)
        nodes = onx.node
    else:
        # FunctionProto: inputs/outputs are plain strings
        if hasattr(onx, 'input'):
            yield from onx.input
        if hasattr(onx, 'output'):
            yield from onx.output
        nodes = onx.node
    for node in nodes:
        yield from node.input
        yield from node.output
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:  # pylint: disable=E0611,E1101
                if hasattr(att, 'g') and att.g is not None:
                    yield from enumerate_onnx_names(att.g)
            elif att.type == AttributeProto.GRAPHS:  # pylint: disable=E0611,E1101
                for subgraph in att.graphs:
                    yield from enumerate_onnx_names(subgraph)

236 

237 

def enumerate_onnx_nodes(onx):
    """
    Enumerates all nodes in one ONNX graph
    (:epkg:`ModelProto`, :epkg:`FunctionProto`, :epkg:`GraphProto`)
    or in a list of nodes.
    The function is recursive: every node is followed by the nodes of
    the subgraphs held by its attributes (a single graph or a list of
    graphs).

    :param onx: one onnx object or a list of nodes
    :return: iterator on nodes
    """
    if isinstance(onx, list):
        nodes = onx
    elif hasattr(onx, 'graph'):
        nodes = onx.graph.node
    else:
        nodes = onx.node
    for node in nodes:
        yield node
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:  # pylint: disable=E0611,E1101
                if hasattr(att, 'g') and att.g is not None:
                    yield from enumerate_onnx_nodes(att.g)
            elif att.type == AttributeProto.GRAPHS:  # pylint: disable=E0611,E1101
                for subgraph in att.graphs:
                    yield from enumerate_onnx_nodes(subgraph)