Coverage for deeponnxcustom/onnxtorch/tchrun.py: 100%
127 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 02:28 +0200
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 02:28 +0200
1"""
2@file
3@brief Executes ONNX graph with pytorch.
4"""
5from onnx.numpy_helper import to_array
6import torch
7from ..tools.math_helper import decompose_permutation
10class _function_OnnxTorchRuntime:
12 @staticmethod
13 def _concat(*tensors, axis=0):
14 nonnull = [t for t in tensors if len(t.shape) > 0]
15 if len(nonnull) == 0:
16 raise NotImplementedError(
17 "Cannot concatenate empty tensors.")
18 if len(nonnull) == 1:
19 return nonnull[0]
20 try:
21 return torch.cat(nonnull, dim=axis) # pylint: disable=E1101
22 except RuntimeError as e: # pragma: no cover
23 raise RuntimeError(
24 "Unable to run 'cat' with shape=%r and axis=%r." % (
25 ", ".join(str(t.shape) for t in tensors),
26 axis)) from e
28 @staticmethod
29 def _gather(t, indices, axis=0):
30 return torch.gather(t, axis, indices) # pylint: disable=E1101
32 @staticmethod
33 def _gemm(a, b, c=None, alpha=1, beta=0, transA=False, transB=False):
34 if transA:
35 a = a.T
36 if transB:
37 b = b.T
38 res = torch.matmul(a, b) * alpha # pylint: disable=E1101
39 if c is not None:
40 res += c * beta
41 return res
43 @staticmethod
44 def _reduceprod(data, axes=None, keepdims=1):
45 if axes is None:
46 if len(data.shape) == 1:
47 return torch.prod( # pylint: disable=E1101
48 data, 0, keepdims == 1)
49 raise NotImplementedError(
50 "Unable to prod(...) with shape=%r axes=%r keepdims=%r." % (
51 tuple(data.shape), axes, keepdims))
52 if len(axes) != 1:
53 for a in reversed(axes):
54 data = torch.prod( # pylint: disable=E1101
55 data, dim=a, keepdim=keepdims == 1)
56 return data
57 return torch.prod( # pylint: disable=E1101
58 data, dim=axes[0], keepdim=keepdims == 1)
60 @staticmethod
61 def _reducesum(data, axes=None, keepdims=1):
62 if axes is None:
63 if len(data.shape) == 1:
64 return torch.sum( # pylint: disable=E1101
65 data, 0, keepdims == 1)
66 raise NotImplementedError(
67 "Unable to prod(...) with shape=%r axes=%r keepdims=%r." % (
68 tuple(data.shape), axes, keepdims))
69 return torch.sum( # pylint: disable=E1101
70 data, dim=axes, keepdim=keepdims == 1)
72 @staticmethod
73 def _reshape(t, shape):
74 return torch.reshape(t, tuple(shape)) # pylint: disable=E1101
76 @staticmethod
77 def _shape(t):
78 return torch.tensor(t.shape) # pylint: disable=E1101
80 @staticmethod
81 def _squeeze(data, axes=None):
82 if axes is None:
83 return torch.squeeze(data) # pylint: disable=E1101
84 if len(axes) == 1:
85 return torch.squeeze(data, axes[0]) # pylint: disable=E1101
86 for a in reversed(axes):
87 data = torch.squeeze(data, a) # pylint: disable=E1101
88 return data
90 @staticmethod
91 def _transpose(t, perm):
92 transitions = decompose_permutation(perm)
93 for a, b in transitions:
94 t = torch.transpose(t, a, b) # pylint: disable=E1101
95 return t
97 @staticmethod
98 def _unqueeze(t, dim):
99 if tuple(dim.shape) == (0, ):
100 return t
101 if len(dim) == 1:
102 return torch.unsqueeze(t, dim[0]) # pylint: disable=E1101
103 v = t
104 for d in dim:
105 v = torch.unsqueeze(v, d) # pylint: disable=E1101
106 return v
def _onnx_max(*tensors):
    """
    ONNX Max: element-wise maximum of one or more tensors.
    A single input is returned unchanged.
    """
    res = tensors[0]
    for t in tensors[1:]:
        res = torch.max(res, t)  # pylint: disable=E1101
    return res


class OnnxTorchRuntime:
    """
    Executes ONNX graph using :epkg:`torch` function.
    This is a very simple runtime. It goes through every
    node in the ONNX graph and executes it with the corresponding
    torch functions.

    :param onnx_model: ONNX model

    The class is very basic. It does not handle subgraphs and
    supports a limited number of operators.

    .. runpython::
        :showcode:

        import pprint
        from deeponnxcustom.onnxtorch.tchrun import OnnxTorchRuntime

        pprint.pprint(list(sorted(OnnxTorchRuntime._mapping)))
    """

    # ONNX op_type -> callable(*inputs, **attributes)
    _mapping = {
        'Concat': _function_OnnxTorchRuntime._concat,
        'Gather': _function_OnnxTorchRuntime._gather,
        'Gemm': _function_OnnxTorchRuntime._gemm,
        'Identity': lambda x: x,
        'MatMul': torch.matmul,  # pylint: disable=E1101
        'Max': _onnx_max,
        'ReduceProd':
            _function_OnnxTorchRuntime._reduceprod,  # pylint: disable=E1101
        'ReduceSum':
            _function_OnnxTorchRuntime._reducesum,  # pylint: disable=E1101
        'Reshape': _function_OnnxTorchRuntime._reshape,
        'Shape': _function_OnnxTorchRuntime._shape,
        'Squeeze': _function_OnnxTorchRuntime._squeeze,
        'Transpose': _function_OnnxTorchRuntime._transpose,
        'Unsqueeze': _function_OnnxTorchRuntime._unqueeze,
    }

    def __init__(self, onnx_model):
        self._onnx_model = onnx_model
        self._inits = OnnxTorchRuntime._extract_init(onnx_model)
        self._atts = OnnxTorchRuntime._extract_atts(onnx_model)

    @staticmethod
    def _extract_init(onnx_model):
        """
        Builds a dictionary with all initializers
        converted into torch arrays.

        :raises RuntimeError: if two initializers share the same name
        """
        res = {}
        for init in onnx_model.graph.initializer:
            if init.name in res:
                raise RuntimeError(  # pragma: no cover
                    "Duplicated initializer name %r." % init.name)
            # Copy so the tensor owns writable memory. NOTE(review):
            # to_array may return a read-only array depending on the
            # onnx version, which torch.from_numpy warns about.
            res[init.name] = torch.from_numpy(  # pylint: disable=E1101
                to_array(init).copy())
        return res

    @staticmethod
    def _extract_atts(onnx_model):
        """
        Builds a dictionary ``{node_key: {attribute_name: value}}`` where
        *node_key* is ``"N<index>_<node name>"``. Only attributes of the
        operators listed below are kept; the others are ignored.
        """
        res = {}
        for i, node in enumerate(onnx_model.graph.node):
            node_name = "N%d_%s" % (i, node.name)
            op = node.op_type
            atts = {}
            for at in node.attribute:
                if op in ('ReduceSum', 'ReduceProd'):
                    atts[at.name] = (
                        tuple(at.ints) if at.name == 'axes' else at.i)
                elif op == 'Transpose':
                    atts[at.name] = tuple(at.ints)
                elif op in ('Gather', 'Concat'):
                    atts[at.name] = at.i
                elif op == 'Gemm':
                    atts[at.name] = (
                        at.f if at.name in ('alpha', 'beta') else at.i)
                elif op in ('Squeeze', 'Unsqueeze') and at.name == 'axes':
                    # opset < 13 stores axes as an attribute; convert it to
                    # the 1-D int64 tensor used by the opset >= 13 input.
                    axes = torch.tensor(  # pylint: disable=E1101
                        list(at.ints), dtype=torch.int64)  # pylint: disable=E1101
                    atts['axes' if op == 'Squeeze' else 'dim'] = axes
            res[node_name] = atts
        return res

    def _run_op(self, node_name, node, *inputs):
        """
        Executes a node with :epkg:`pytorch`.
        Returns a tuple of results.

        :raises NotImplementedError: if the node has several outputs
        :raises KeyError: if the operator is not in ``_mapping``
        """
        if len(node.output) != 1:
            raise NotImplementedError(
                "Unable to execute a node with more than one "
                "output (type=%r)." % node.op_type)
        tf = OnnxTorchRuntime._mapping[node.op_type]
        try:
            res = tf(*inputs, **self._atts[node_name])
        except (TypeError, IndexError, RuntimeError) as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to run operator %r with len(inputs)=%d, atts=%r.\n%r"
                "" % (node.op_type, len(inputs),
                      self._atts[node_name], inputs)) from e
        if isinstance(res, tuple):
            return res  # pragma: no cover
        return (res, )

    def run(self, *inputs, verbose=False):
        """
        Executes the ONNX graph.

        :param inputs: inputs of the function, matched positionally
            with the graph inputs
        :param verbose: displays more information while running the graph
        :return: a result or a tuple of results
        """
        keep = self._inits.copy()
        for i, v in zip(self._onnx_model.graph.input, inputs):
            keep[i.name] = v

        for i, node in enumerate(self._onnx_model.graph.node):
            node_name = "N%d_%s" % (i, node.name)
            # ONNX marks an omitted optional input with an empty name.
            node_inputs = [keep[name] if name else None
                           for name in node.input]
            res = self._run_op(node_name, node, *node_inputs)
            if verbose:
                print(  # pragma: no cover
                    "[OnnxTorchRuntime.run] op=%r, shapes=[%s] "
                    "-> %s, name=%r in [%r, %r], atts=%r" % (
                        node.op_type,
                        ", ".join(map(
                            lambda x: str(tuple(getattr(x, 'shape', '?'))),
                            node_inputs)),
                        ", ".join(map(
                            lambda x: str(tuple(getattr(x, 'shape', '?'))),
                            res)),
                        node.name,
                        float(min(t.min() for t in res)),
                        float(max(t.max() for t in res)),
                        self._atts[node_name]))
            for name, value in zip(node.output, res):
                if not isinstance(value, torch.Tensor):
                    raise TypeError(  # pragma: no cover
                        "Unexpected value for name=%r, type=%r." % (
                            name, type(value)))
                keep[name] = value

        res = tuple(keep[o.name] for o in self._onnx_model.graph.output)
        if len(res) == 1:
            return res[0]
        return res  # pragma: no cover