Coverage for mlprodict/onnx_tools/onnx_grammar/onnx_translation.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

"""
@file
@brief Translates a python function into python code which builds
an :epkg:`ONNX` graph (see @see fn translate_fct2onnx), plus a few
``py_*`` helpers and default contexts used by that translation.
"""

5import inspect 

6import ast 

7from textwrap import dedent 

8import numpy 

9from scipy.spatial.distance import squareform, pdist 

10from .node_visitor_translator import CodeNodeVisitor 

11 

12 

def py_make_float_array(cst, op_version=None):
    """
    Wraps a constant into a one-element :epkg:`numpy` array
    of type ``float32``.

    @param      cst             constant to wrap
    @param      op_version      ignored, present for signature compatibility
    @return                     array holding *cst* as its only element

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnx_tools.onnx_grammar.onnx_translation import py_make_float_array
        print(py_make_float_array(5.5))
    """
    wrapped = [cst]
    as_array = numpy.array(wrapped, dtype=numpy.float32)
    return as_array

30 

31 

def py_pow(x, p, op_version=None):
    """
    Exposes python operator ``**`` as a regular function.

    @param      x               base (float or array)
    @param      p               exponent
    @param      op_version      ignored
    @return                     :math:`x^p`
    """
    # two-argument pow(a, b) is defined as a ** b
    powered = pow(x, p)
    return powered

42 

43 

def py_mul(*x, op_version=None):
    """
    Function for python operator ``*`` applied to any number of operands.

    @param      x               operands (floats or arrays), at least one
    @param      op_version      unused
    @return                     product ``x[0] * x[1] * ... * x[-1]``

    The product is computed out of place, so none of the operands
    is modified, even when it is a mutable :epkg:`numpy` array.
    Calling it without any operand raises an *IndexError*.
    """
    product = x[0]
    for factor in x[1:]:
        # ``product *= factor`` would write into the caller's first operand
        # when it is an array, and would fail when an int array is multiplied
        # by a float (in-place results cannot change dtype).
        product = product * factor
    return product

58 

59 

def py_opp(x, op_version=None):
    """
    Exposes python unary operator ``-`` as a regular function.

    @param      x               float or array to negate
    @param      op_version      ignored
    @return                     ``-x``
    """
    negated = -x
    return negated

69 

70 

def squareform_pdist(X, metric='sqeuclidean', op_version=None):
    """
    Computes the square matrix of pairwise distances between the rows
    of *X* by chaining `pdist
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.pdist.html>`_
    and `squareform
    <http://scipy.github.io/devdocs/generated/scipy.spatial.distance.squareform.html>`_.

    @param      X               observations, one per row
    @param      metric          distance name given to *pdist*
    @param      op_version      ignored
    @return                     square symmetric distance matrix
    """
    condensed = pdist(X, metric=metric)
    square = squareform(condensed)
    return square

79 

80 

def get_default_context():
    """
    Returns a default context useful for most of the conversion
    from a function using :epkg:`numpy` into :epkg:`ONNX`.

    The context maps the ``py_*`` helpers of this module to themselves,
    ``cdist`` and ``squareform_pdist`` to their own names (as strings),
    and ``numpy.<name>`` / ``np.<name>`` to the matching :epkg:`numpy`
    object for a fixed selection of :epkg:`numpy` functions.

    @return     dictionary ``{name: object}``
    """
    context = {'py_pow': py_pow, 'py_make_float_array': py_make_float_array,
               'py_mul': py_mul, 'py_opp': py_opp,
               'cdist': 'cdist', 'squareform_pdist': 'squareform_pdist'}
    # Every literal ends with a space: adjacent literals are concatenated,
    # and a missing space would merge two names into one unknown token
    # (e.g. 'divideequal'), silently dropping both functions.
    allow = set(
        ('abs add arccos arccosh arcsin arcsinh arctan arctanh ceil cos cosh '
         'divide equal exp floor greater invert less log matmul maximum '
         'minimum mod multiply power sign sin sinh sqrt square subtract '
         'tan tanh transpose').split())
    for k, v in numpy.__dict__.items():
        if k not in allow:
            continue
        context[f'numpy.{k}'] = v
        context[f'np.{k}'] = v
    return context

98 

99 

def get_default_context_cpl():
    """
    Builds the context used to compile the code returned by
    @see fn translate_fct2onnx. It holds the ``py_*`` helpers,
    :epkg:`numpy`, ``onnx_squareform_pdist`` and ``onnx_cdist`` when
    :epkg:`skl2onnx` provides them, and every class of
    ``skl2onnx.algebra.onnx_ops`` whose name starts with ``Onnx``
    and which derives from ``OnnxOperator``.

    @return     dictionary ``{name: object}``
    """
    ctx = dict(py_make_float_array=py_make_float_array,
               py_pow=py_pow, py_mul=py_mul, py_opp=py_opp,
               numpy=numpy)
    try:
        from skl2onnx.algebra.complex_functions import (  # delayed
            onnx_squareform_pdist, onnx_cdist)
    except ImportError:  # pragma: no cover
        # Too old version for skl2onnx.
        pass
    else:
        ctx['onnx_squareform_pdist'] = onnx_squareform_pdist
        ctx['onnx_cdist'] = onnx_cdist

    from skl2onnx.algebra import onnx_ops  # delayed
    from skl2onnx.algebra.onnx_operator import OnnxOperator  # delayed

    for name, obj in onnx_ops.__dict__.items():
        if not name.startswith("Onnx"):
            continue
        try:
            is_operator = issubclass(obj, OnnxOperator)
        except TypeError as e:
            # issubclass rejects non-classes: plain functions are skipped,
            # anything else is unexpected.
            if inspect.isfunction(obj):
                continue
            raise RuntimeError(  # pragma: no cover
                f"Issue with {name}={obj} (type={type(obj)})") from e
        if is_operator:
            ctx[name] = obj
    return ctx

130 

131 

def translate_fct2onnx(fct, context=None, cpl=False,
                       context_cpl=None, output_names=None,
                       dtype=numpy.float32,
                       verbose=0, fLOG=None):
    """
    Translates a function into :epkg:`ONNX`. The code it produces
    is using classes *OnnxAbs*, *OnnxAdd*, ...

    @param      fct             function to convert, or its source code
                                given as a string
    @param      context         context of the function to convert
                                something like ``{'numpy.transpose': numpy.transpose}``,
                                if *context* is None, it receives a default value
                                returned by @see fn get_default_context
    @param      cpl             compile the function after it was
                                created
    @param      context_cpl     context used at compiling time
                                if *context_cpl* is None, it receives a default value
                                returned by @see fn get_default_context_cpl
    @param      output_names    names of the output in the :epkg:`ONNX` graph
    @param      dtype           :epkg:`numpy` float type used to produce the model
                                (currently not used by this function)
    @param      verbose         integer, display more information
    @param      fLOG            logging function
    @return                     code or compiled code
    @raises     TypeError       if *fct* is neither a string nor a callable
    @raises     ValueError      if *cpl* is True and *fct* is a string
                                which defines no function
    @raises     SyntaxError     if the generated code cannot be compiled

    When *cpl* is True, the compiled function is the one named like
    *fct* (``fct.__name__``), or, when *fct* is a string, like the first
    function defined in that string.

    .. exref::
        :title: Convert a function into ONNX code

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx2.py

            import numpy
            from mlprodict.onnx_tools.onnx_grammar import translate_fct2onnx

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            onnx_code = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose})
            print(onnx_code)

    The next example goes further and compiles the outcome.

    .. exref::
        :title: Convert a function into ONNX code and run

        The following code parses a python function and returns
        another python function which produces an :epkg:`ONNX`
        graph if executed. The example executes the function,
        creates an :epkg:`ONNX` then uses @see cl OnnxInference
        to compute *predictions*. Finally it compares
        them to the original.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx3.py

            import numpy
            from mlprodict.onnx_tools.onnx_grammar import translate_fct2onnx
            from mlprodict.plotting.text_plot import onnx_simple_text_plot
            from mlprodict.onnxrt import OnnxInference
            from mlprodict.npy.xop import loadop


            OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity = loadop(
                'Add', 'Transpose', 'Mul', 'Identity')


            ctx = {'OnnxAdd': OnnxAdd,
                   'OnnxTranspose': OnnxTranspose,
                   'OnnxMul': OnnxMul,
                   'OnnxIdentity': OnnxIdentity}

            def trs(x, y):
                z = x + numpy.transpose(y, axes=[1, 0])
                return x * z

            inputs = {'x': numpy.array([[1, 2]], dtype=numpy.float32),
                      'y': numpy.array([[-0.3, 0.4]], dtype=numpy.float32).T}

            original = trs(inputs['x'], inputs['y'])

            print('original output:', original)

            onnx_fct = translate_fct2onnx(
                trs, context={'numpy.transpose': numpy.transpose},
                cpl=True, context_cpl=ctx, output_names=['Z'])

            onnx_code = onnx_fct('x', 'y', op_version=12)

            onnx_g = onnx_code.to_onnx(inputs, target_opset=12)
            print("ONNX model")
            print(onnx_simple_text_plot(onnx_g))

            oinf = OnnxInference(onnx_g)
            res = oinf.run(inputs)

            print('-----------')
            print("ONNX inference:", res['Z'])

    The function to be converted may include python functions
    which must not be converted. In that case, their name
    must be prefixed by ``py_``. The execution of the function
    this one builds produces the following error::

        TypeError: Parameter to MergeFrom() must be instance of same class:
        expected onnx.TensorProto got onnx.AttributeProto.

    It indicates that constants in the code merge multiple types,
    usually floats and tensors of floats. Floats should be converted
    using the following function::

        def py_make_float_array(cst, op_version=None):
            return numpy.array([cst], dtype=numpy.float32)

    The function replaces empty contexts by default values which
    cover many :epkg:`numpy` functions. The tutorial
    :ref:`l-onnx-tutorial` gives an example of how it can be used
    on a more complex function.
    """
    def compile_code(name, code, context=None):
        """
        Compiles a python function with the given
        context.

        @param      name        function name
        @param      code        python code
        @param      context     context used at compilation
        @return                 compiled function
        """
        if context is None:
            context = {}  # pragma: no cover
        try:
            obj = compile(code, "", "exec")
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError(f"Unable to compile\n{code}") from e
        # copies so that executing the code never alters the caller's context
        context_g = context.copy()
        context_l = context.copy()
        exec(obj, context_g, context_l)  # pylint: disable=W0122
        return context_l[name]

    if isinstance(fct, str):
        code = fct
    elif callable(fct):
        code = inspect.getsource(fct)
    else:
        raise TypeError(  # pragma: no cover
            f"Unable to guess code from type {type(fct)}.")
    node = ast.parse(dedent(code))
    v = CodeNodeVisitor()
    v.visit(node)
    if context is None:
        context = get_default_context()
    onnx_code = v.export(context=context,
                         output_names=output_names)
    if not cpl:
        return onnx_code
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG('[translate_fct2onnx] python code')
        fLOG(code)
        fLOG('[translate_fct2onnx] ONNX code')
        fLOG(onnx_code)
    if context_cpl is None:
        context_cpl = get_default_context_cpl()
    if 'numpy' not in context_cpl:
        context_cpl = context_cpl.copy()
        context_cpl['numpy'] = numpy

    if isinstance(fct, str):
        # A string has no __name__: use the first function defined in it.
        # NOTE(review): assumes the exported code keeps the original function
        # name, which is what the callable branch (fct.__name__) relies on too.
        fct_names = [n.name for n in node.body
                     if isinstance(n, ast.FunctionDef)]
        if not fct_names:
            raise ValueError(
                f"Unable to find a function definition in\n{code}")
        fct_name = fct_names[0]
    else:
        fct_name = fct.__name__
    return compile_code(fct_name, onnx_code, context_cpl)