Coverage for deeponnxcustom/onnxtorch/tchrun.py: 100%

127 statements  

coverage.py v6.4.1, created at 2022-06-06 02:28 +0200

1""" 

2@file 

3@brief Executes ONNX graph with pytorch. 

4""" 

5from onnx.numpy_helper import to_array 

6import torch 

7from ..tools.math_helper import decompose_permutation 

8 

9 

class _function_OnnxTorchRuntime:

    @staticmethod
    def _concat(*tensors, axis=0):
        nonnull = [t for t in tensors if len(t.shape) > 0]
        if len(nonnull) == 0:
            raise NotImplementedError(
                "Cannot concatenate empty tensors.")
        if len(nonnull) == 1:
            return nonnull[0]
        try:
            return torch.cat(nonnull, dim=axis)  # pylint: disable=E1101
        except RuntimeError as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to run 'cat' with shapes=%r and axis=%r." % (
                    ", ".join(str(t.shape) for t in tensors),
                    axis)) from e
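
    # A small illustration of the filtering above (illustrative values):
    # rank-0 tensors are dropped before calling torch.cat, so
    #   _concat(torch.tensor(0.), torch.tensor([1.]), torch.tensor([2., 3.]))
    # returns tensor([1., 2., 3.]).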

    @staticmethod
    def _gather(t, indices, axis=0):
        return torch.gather(t, axis, indices)  # pylint: disable=E1101

    @staticmethod
    def _gemm(a, b, c=None, alpha=1., beta=1., transA=False, transB=False):
        # Defaults follow the ONNX Gemm specification
        # (alpha=1.0, beta=1.0).
        if transA:
            a = a.T
        if transB:
            b = b.T
        res = torch.matmul(a, b) * alpha  # pylint: disable=E1101
        if c is not None:
            res += c * beta
        return res

    @staticmethod
    def _reduceprod(data, axes=None, keepdims=1):
        if axes is None:
            if len(data.shape) == 1:
                return torch.prod(  # pylint: disable=E1101
                    data, 0, keepdims == 1)
            raise NotImplementedError(
                "Unable to prod(...) with shape=%r axes=%r keepdims=%r." % (
                    tuple(data.shape), axes, keepdims))
        if len(axes) != 1:
            for a in reversed(axes):
                data = torch.prod(  # pylint: disable=E1101
                    data, dim=a, keepdim=keepdims == 1)
            return data
        return torch.prod(  # pylint: disable=E1101
            data, dim=axes[0], keepdim=keepdims == 1)
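
    # Illustration of the keepdim handling above (illustrative values):
    # for data of shape (2, 3) and axes=(1,), keepdims=1 calls
    # torch.prod(data, dim=1, keepdim=True) and returns shape (2, 1);
    # with keepdims=0 the reduced axis is dropped, giving shape (2,).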

    @staticmethod
    def _reducesum(data, axes=None, keepdims=1):
        if axes is None:
            if len(data.shape) == 1:
                return torch.sum(  # pylint: disable=E1101
                    data, 0, keepdims == 1)
            raise NotImplementedError(
                "Unable to sum(...) with shape=%r axes=%r keepdims=%r." % (
                    tuple(data.shape), axes, keepdims))
        return torch.sum(  # pylint: disable=E1101
            data, dim=axes, keepdim=keepdims == 1)

    @staticmethod
    def _reshape(t, shape):
        return torch.reshape(t, tuple(shape))  # pylint: disable=E1101

    @staticmethod
    def _shape(t):
        return torch.tensor(t.shape)  # pylint: disable=E1101

    @staticmethod
    def _squeeze(data, axes=None):
        if axes is None:
            return torch.squeeze(data)  # pylint: disable=E1101
        if len(axes) == 1:
            return torch.squeeze(data, axes[0])  # pylint: disable=E1101
        for a in reversed(axes):
            data = torch.squeeze(data, a)  # pylint: disable=E1101
        return data

    @staticmethod
    def _transpose(t, perm):
        # torch.transpose only swaps two dimensions, so the permutation
        # is decomposed into a sequence of pairwise transpositions.
        transitions = decompose_permutation(perm)
        for a, b in transitions:
            t = torch.transpose(t, a, b)  # pylint: disable=E1101
        return t

    @staticmethod
    def _unsqueeze(t, dim):
        if tuple(dim.shape) == (0, ):
            return t
        if len(dim) == 1:
            return torch.unsqueeze(t, dim[0])  # pylint: disable=E1101
        v = t
        for d in dim:
            v = torch.unsqueeze(v, d)  # pylint: disable=E1101
        return v


class OnnxTorchRuntime:
    """
    Executes an ONNX graph using :epkg:`torch` functions.
    This is a very simple runtime: it walks through every
    node in the ONNX graph and executes it with the
    corresponding torch function.

    :param onnx_model: ONNX model

    The class is very basic. It does not handle subgraphs and
    supports a limited number of operators.

    .. runpython::
        :showcode:

        import pprint
        from deeponnxcustom.onnxtorch.tchrun import OnnxTorchRuntime

        pprint.pprint(list(sorted(OnnxTorchRuntime._mapping)))
    """

    _mapping = {
        'Concat': _function_OnnxTorchRuntime._concat,
        'Gather': _function_OnnxTorchRuntime._gather,
        'Gemm': _function_OnnxTorchRuntime._gemm,
        'Identity': lambda x: x,
        'MatMul': torch.matmul,  # pylint: disable=E1101
        'Max': torch.max,  # pylint: disable=E1101
        'ReduceProd': _function_OnnxTorchRuntime._reduceprod,
        'ReduceSum': _function_OnnxTorchRuntime._reducesum,
        'Reshape': _function_OnnxTorchRuntime._reshape,
        'Shape': _function_OnnxTorchRuntime._shape,
        'Squeeze': _function_OnnxTorchRuntime._squeeze,
        'Transpose': _function_OnnxTorchRuntime._transpose,
        'Unsqueeze': _function_OnnxTorchRuntime._unsqueeze,
    }

    def __init__(self, onnx_model):
        self._onnx_model = onnx_model
        self._inits = OnnxTorchRuntime._extract_init(onnx_model)
        self._atts = OnnxTorchRuntime._extract_atts(onnx_model)

    @staticmethod
    def _extract_init(onnx_model):
        """
        Builds a dictionary with all initializers
        converted into torch tensors.
        """
        res = {}
        for init in onnx_model.graph.initializer:
            if init.name in res:
                raise RuntimeError(  # pragma: no cover
                    "Duplicated initializer name %r for type %r." % (
                        init.name, init.data_type))
            res[init.name] = torch.from_numpy(  # pylint: disable=E1101
                to_array(init))
        return res

    @staticmethod
    def _extract_atts(onnx_model):
        """
        Builds a dictionary with the attributes of every node,
        converted into python types.
        """
        res = {}
        for i, node in enumerate(onnx_model.graph.node):
            node_name = "N%d_%s" % (i, node.name)
            res[node_name] = {}
            for at in node.attribute:
                if node.op_type in ('ReduceSum', 'ReduceProd'):
                    if at.name == 'axes':
                        res[node_name][at.name] = tuple(at.ints)
                    else:
                        res[node_name][at.name] = at.i
                elif node.op_type == 'Transpose':
                    res[node_name][at.name] = tuple(at.ints)
                elif node.op_type in ('Gather', 'Concat'):
                    res[node_name][at.name] = at.i
                elif node.op_type == 'Gemm':
                    if at.name in ('alpha', 'beta'):
                        res[node_name][at.name] = at.f
                    else:
                        res[node_name][at.name] = at.i
        return res
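
    # Shape of the resulting dictionary (node names are illustrative):
    # for a graph whose first node is named 'gemm' with op_type 'Gemm'
    # and attributes alpha=2.0 and transB=1, res would contain
    # {'N0_gemm': {'alpha': 2.0, 'transB': 1}}.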

    def _run_op(self, node_name, node, *inputs):
        """
        Executes a node with :epkg:`pytorch`.
        Returns a tuple of results.
        """
        if len(node.output) != 1:
            raise NotImplementedError(
                "Unable to execute a node with more than one "
                "output (type=%r)." % node.op_type)
        tf = OnnxTorchRuntime._mapping[node.op_type]
        try:
            res = tf(*inputs, **self._atts[node_name])
        except (TypeError, IndexError, RuntimeError) as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to run operator %r with len(inputs)=%d, atts=%r.\n%r"
                "" % (node.op_type, len(inputs),
                      self._atts[node_name], inputs)) from e
        if isinstance(res, tuple):
            return res  # pragma: no cover
        return (res, )

    def run(self, *inputs, verbose=False):
        """
        Executes the ONNX graph.

        :param inputs: inputs of the function
        :param verbose: displays more information while running the graph
        :return: a result or a tuple of results
        """
        keep = self._inits.copy()
        for i, v in zip(self._onnx_model.graph.input, inputs):
            keep[i.name] = v

        for i, node in enumerate(self._onnx_model.graph.node):
            node_name = "N%d_%s" % (i, node.name)
            node_inputs = [keep[name] for name in node.input]
            res = self._run_op(node_name, node, *node_inputs)
            if verbose:
                print(  # pragma: no cover
                    "[OnnxTorchRuntime.run] op=%r, shapes=[%s] "
                    "-> %s, name=%r in [%r, %r], atts=%r" % (
                        node.op_type,
                        ", ".join(map(
                            lambda x: str(tuple(getattr(x, 'shape', '?'))),
                            node_inputs)),
                        ", ".join(map(
                            lambda x: str(tuple(getattr(x, 'shape', '?'))),
                            res)),
                        node.name,
                        float(min(t.min() for t in res)),
                        float(max(t.max() for t in res)),
                        self._atts[node_name]))
            for name, value in zip(node.output, res):
                if not isinstance(value, torch.Tensor):
                    raise TypeError(  # pragma: no cover
                        "Unexpected value for name=%r, type=%r." % (
                            name, type(value)))
                keep[name] = value

        res = tuple(keep[o.name] for o in self._onnx_model.graph.output)
        if len(res) == 1:
            return res[0]
        return res  # pragma: no cover
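

if __name__ == "__main__":
    # Minimal usage sketch, assuming the module is imported as part of the
    # deeponnxcustom package (the relative import above prevents running
    # this file directly; use `python -m deeponnxcustom.onnxtorch.tchrun`).
    # Graph, node and tensor names below are illustrative.
    from onnx import TensorProto
    from onnx.helper import (
        make_graph, make_model, make_node, make_tensor_value_info)

    # A graph with a single MatMul node computing Z = X @ Y.
    X = make_tensor_value_info('X', TensorProto.FLOAT, [2, 3])
    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [3, 2])
    Z = make_tensor_value_info('Z', TensorProto.FLOAT, [2, 2])
    node = make_node('MatMul', ['X', 'Y'], ['Z'], name='matmul')
    model = make_model(make_graph([node], 'demo', [X, Y], [Z]))

    rt = OnnxTorchRuntime(model)
    x = torch.randn(2, 3)
    y = torch.randn(3, 2)
    # Equivalent to torch.matmul(x, y).
    print(rt.run(x, y, verbose=True))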