Coverage for mlprodict/onnxrt/onnx_micro_runtime.py: 91%

120 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Micro runtime for ONNX. 

4 

5.. versionadded:: 0.6 

6""" 

7import numpy 

8from ..onnx_tools.onnx2py_helper import _var_as_dict 

9 

10 

class OnnxMicroRuntime:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_micro_runtime import OnnxMicroRuntime
        from mlprodict.npy.xop import loadop

        OnnxAdd = loadop('Add')

        dtype = numpy.float32
        opset = 15
        x = numpy.array([1, 2, 4, 5, 5, 4]).astype(
            numpy.float32).reshape((3, 2))
        cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([2], dtype=dtype), op_version=opset,
                       output_names=['Y'])
        model_def = cop4.to_onnx({'X': x}, target_opset=opset)
        rt = OnnxMicroRuntime(model_def)
        out = rt.run({'X': x})
        pprint.pprint(out)
    """

    def __init__(self, model_onnx):
        if not hasattr(model_onnx, 'graph'):
            raise TypeError(
                f"model_onnx is not an ONNX graph but {type(model_onnx)!r}.")
        self.model_onnx = model_onnx

    @property
    def input_names(self):
        "Returns input names."
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        return [i.name for i in self.model_onnx.graph.output]

    def run(self, inputs):
        """
        Computes the outputs of the graph.

        Nodes are evaluated in the order they are stored in the graph.
        Each node is dispatched to the method ``_op_<op_type lowercased>``,
        its attributes being passed as keyword arguments.

        :param inputs: dictionary ``{name: value}`` for the graph inputs
        :return: all intermediate results and outputs as a dictionary
            (graph inputs, initializers and every node output)
        :raises TypeError: if *inputs* is not a dictionary
        :raises NotImplementedError: if a node uses an operator this
            runtime does not implement
        """
        if not isinstance(inputs, dict):
            raise TypeError(
                f"inputs must be a dictionary not {type(inputs)!r}.")
        results = inputs.copy()

        for init in self.model_onnx.graph.initializer:
            results[init.name] = _var_as_dict(init)['value']

        for node in self.model_onnx.graph.node:
            op_type = node.op_type
            # ONNX uses an empty name to mark an omitted optional input,
            # the operator then receives None for it.
            inp = [results[n] if n else None for n in node.input]
            meth = getattr(self, f"_op_{op_type.lower()}", None)
            if meth is None:
                raise NotImplementedError(
                    f"OnnxMicroRuntime does not implement operator {op_type!r}.")
            kwargs = {at.name: _var_as_dict(at)['value']
                      for at in node.attribute}
            out = meth(*inp, **kwargs)
            for n, o in zip(node.output, out):
                results[n] = o

        return results

    ########################
    # Runtime for operators
    ########################

    @staticmethod
    def _to_axes(axes):
        """
        Converts *axes* (an int, a sequence or an array, possibly 0-d)
        into a tuple of Python ints. ``None`` is returned unchanged.
        """
        if axes is None:
            return None
        return tuple(int(a) for a in numpy.asarray(axes).ravel())

    def _op_abs(self, x):
        "Runtime for operator :epkg:`Op:Abs`."
        return (numpy.abs(x), )

    def _op_add(self, x, y):
        "Runtime for operator :epkg:`Op:Add`."
        return (x + y, )

    def _op_concat(self, *args, axis=None):
        """
        Runtime for operator :epkg:`Op:Concat`.

        Inputs whose rank is not greater than *axis* are padded with
        trailing dimensions of size 1 before being concatenated.
        """
        def _preprocess(a, axis):
            if axis >= len(a.shape):
                new_shape = a.shape + (1, ) * (axis + 1 - len(a.shape))
                return a.reshape(new_shape)
            return a

        targs = tuple(_preprocess(a, axis) for a in args)
        return (numpy.concatenate(targs, axis), )

    def _op_gemm(self, a, b, c=None, alpha=None, beta=None,
                 transA=False, transB=False):
        """
        Runtime for operator :epkg:`Op:Gemm`.

        Computes ``alpha * A' @ B' + beta * C`` where ``A'`` (resp. ``B'``)
        is ``A`` (resp. ``B``) transposed when *transA* (resp. *transB*)
        is set. Missing *alpha* / *beta* default to 1 (the ONNX defaults)
        and *c* is optional.
        """
        if not isinstance(transA, (int, bool, numpy.int64)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for transA: {type(transA)!r}.")
        if not isinstance(transB, (int, bool, numpy.int64)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for transB: {type(transB)!r}.")
        if alpha is None:
            alpha = 1.
        if beta is None:
            beta = 1.
        if transA:
            a = a.T
        if transB:
            b = b.T
        o = numpy.dot(a, b) * alpha
        if c is not None and beta != 0:
            o += c * beta
        return (o, )

    def _op_gather(self, x, indices, axis=None):
        """
        Runtime for operator :epkg:`Op:Gather`.

        A missing *axis* means 0 as in ONNX (``numpy.take`` would
        otherwise flatten *x*).
        """
        if axis is None:
            axis = 0
        if not x.flags['C_CONTIGUOUS']:
            x = numpy.ascontiguousarray(x)
        if not indices.flags['C_CONTIGUOUS']:
            indices = numpy.ascontiguousarray(indices)
        return (numpy.take(x, indices, axis=axis), )

    def _op_identity(self, x):
        "Runtime for operator :epkg:`Op:Identity`."
        return (x, )

    def _op_matmul(self, x, y):
        "Runtime for operator :epkg:`Op:MatMul`."
        return (numpy.matmul(x, y), )

    def _op_max(self, *inps):
        """
        Runtime for operator :epkg:`Op:Max`.

        ONNX Max is variadic while :func:`numpy.maximum` is binary
        (a third positional argument would be used as its ``out``
        buffer), so the inputs are folded pairwise.
        """
        if not inps:
            raise TypeError("Operator Max expects at least one input.")
        res = inps[0]
        for other in inps[1:]:
            res = numpy.maximum(res, other)
        return (res, )

    def _op_mul(self, x, y):
        "Runtime for operator :epkg:`Op:Mul`."
        return (x * y, )

    def _op_reduceprod(self, data, axes=None, keepdims=None,
                       noop_with_empty_axes=None):
        """
        Runtime for operator :epkg:`Op:ReduceProd`.

        Missing or empty *axes* reduce every axis unless
        *noop_with_empty_axes* is set, in which case *data* is returned
        unchanged. A missing *keepdims* means 1 as in ONNX.
        """
        axes = self._to_axes(axes)
        if not axes:
            # None or (): reduce over all axes (numpy axis=None)
            if noop_with_empty_axes:
                return (data, )
            axes = None
        keepdims = True if keepdims is None else bool(keepdims)
        return (numpy.prod(data, axis=axes,
                           keepdims=keepdims,
                           dtype=data.dtype), )

    def _op_reducesum(self, data, axes=None, keepdims=None,
                      noop_with_empty_axes=None):
        """
        Runtime for operator :epkg:`Op:ReduceSum`.

        Missing or empty *axes* reduce every axis unless
        *noop_with_empty_axes* is set, in which case *data* is returned
        unchanged. A missing *keepdims* means 1 as in ONNX.
        """
        axes = self._to_axes(axes)
        if not axes:
            # None or (): reduce over all axes (numpy axis=None)
            if noop_with_empty_axes:
                return (data, )
            axes = None
        keepdims = True if keepdims is None else bool(keepdims)
        return (numpy.sum(data, axis=axes,
                          keepdims=keepdims,
                          dtype=data.dtype), )

    def _op_reshape(self, x, shape, allowzero=0):
        """
        Runtime for operator :epkg:`Op:Reshape`.

        Unless *allowzero* is set, a 0 in *shape* copies the matching
        dimension of *x* (ONNX semantics; numpy would create an empty
        dimension instead). A -1 is inferred by numpy.
        """
        if not allowzero:
            dims = numpy.asarray(shape)
            if (dims == 0).any():
                shape = tuple(x.shape[i] if d == 0 else int(d)
                              for i, d in enumerate(dims.ravel()))
        return (x.reshape(shape), )

    def _op_shape(self, x, start=None, end=None):
        """
        Runtime for operator :epkg:`Op:Shape`.

        *start* and *end* (opset 15) slice the returned shape with
        Python slicing rules: negative values count from the end and
        out-of-range values are clamped.
        """
        return (numpy.array(list(x.shape[start:end]), dtype=numpy.int64), )

    def _op_squeeze(self, x, axes=None):
        """
        Runtime for operator :epkg:`Op:Squeeze`.

        Without *axes*, every dimension of size 1 is removed.
        """
        if axes is None:
            return (numpy.squeeze(x), )
        return (numpy.squeeze(x, axis=self._to_axes(axes)), )

    def _op_transpose(self, x, perm=None):
        "Runtime for operator :epkg:`Op:Transpose`."
        return (numpy.transpose(x, perm), )

    def _op_unsqueeze(self, x, axes=None):
        "Runtime for operator :epkg:`Op:Unsqueeze`."
        if axes is None:
            return (x, )
        return (numpy.expand_dims(x, axis=self._to_axes(axes)), )