Coverage for mlprodict/npy/onnx_numpy_compiler.py: 97%
204 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Implements :epkg:`numpy` functions with onnx and a runtime.
5.. versionadded:: 0.6
6"""
7import inspect
8import logging
9from typing import Any
10import numpy
11from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations
12from .onnx_version import FctVersion
13from .onnx_numpy_annotation import get_args_kwargs
14from .xop_variable import Variable
15from .xop import OnnxOperator, OnnxOperatorTuple
# Logger named 'xop' (not __name__); OnnxNumpyCompiler._to_onnx emits its
# debug traces through it.
logger = logging.getLogger(name='xop')
class OnnxNumpyFunction:
    """
    Class wrapping a function built with
    @see cl OnnxNumpyCompiler.

    :param compiler: the compiler which created this instance
    :param rt: runtime instance executing the ONNX graph
    :param inputs: list of @see cl Variable, the graph inputs
    :param outputs: list of @see cl Variable, the graph outputs
    :param n_optional: number of trailing inputs which may be omitted
        when the function is called
    :param n_variables: when > 0, the number of positional arguments
        is not checked before running the graph

    .. versionadded:: 0.6
    """

    def __init__(self, compiler, rt, inputs, outputs,
                 n_optional, n_variables):
        if any(not isinstance(n, Variable) for n in inputs):
            raise TypeError(  # pragma: no cover
                f"All inputs must be of type Variable: {inputs!r}.")
        if any(not isinstance(n, Variable) for n in outputs):
            raise TypeError(  # pragma: no cover
                f"All outputs must be of type Variable: {outputs!r}.")
        self.compiler = compiler
        self.inputs = inputs
        self.outputs = outputs
        self.rt = rt
        self.n_optional = n_optional
        self.n_variables = n_variables
        if n_optional < 0:
            raise RuntimeError(  # pragma: no cover
                f"Wrong configuration, n_optional {n_optional!r} must be >= 0.")
        if n_optional >= len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be < %r "
                "the number of inputs." % (n_optional, len(inputs)))

    def _check_(self, *args, **kwargs):
        """
        Checks the number of positional arguments before running the graph.
        It must be in `[len(inputs) - n_optional, len(inputs)]`.
        The check is skipped when `n_variables > 0`.

        :raises RuntimeError: if the number of arguments is out of range
        """
        if self.n_variables > 0:
            # variable number of inputs, the count cannot be checked
            return
        n_max = len(self.inputs)
        n_min = n_max - self.n_optional
        if not n_min <= len(args) <= n_max:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %d. It should be in "
                "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
                "\nargs=%s\nkwargs=%s\ninputs=%s" % (
                    len(args), n_min, n_max,
                    len(args), self.n_optional, self.n_variables,
                    args, kwargs, self.inputs))
class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction executing the graph
    with an instance of @see cl OnnxInference.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Feeds the positional arguments to the graph inputs (matched by
        order), runs the graph and returns the outputs as a tuple ordered
        like `self.outputs`. Keyword arguments are forwarded to the
        runtime's `run` method.
        """
        self._check_(*args, **kwargs)
        feeds = {}
        for var, value in zip(self.inputs, args):
            feeds[var.name] = value
        results = self.rt.run(feeds, **kwargs)
        n_expected = len(self.outputs)
        if len(results) != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    len(results), n_expected))
        return tuple(results[var.name] for var in self.outputs)
class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction executing the graph
    with an `InferenceSession` from :epkg:`onnxruntime`.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Feeds the positional arguments to the graph inputs (matched by
        order), runs the graph and returns all its outputs as a tuple.
        Keyword arguments are not supported by this runtime.
        """
        self._check_(*args, **kwargs)
        if kwargs:
            raise RuntimeError(  # pragma: no cover
                f"kwargs is not used but it is not empty: {kwargs!r}.")
        feeds = dict(zip((var.name for var in self.inputs), args))
        results = self.rt.run(None, feeds)

        n_expected = len(self.outputs)
        if len(results) != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    len(results), n_expected))
        return tuple(results)
class OnnxNumpyCompiler:
    """
    Implements a class which runs an ONNX graph.

    :param fct: a function with annotations which returns an ONNX graph,
        it can also be an ONNX graph.
    :param op_version: :epkg:`ONNX` opset to use, None
        for the latest one
    :param runtime: runtime to choose to execute the onnx graph,
        `python`, `onnxruntime`, `onnxruntime1`
    :param signature: used when the function is not annotated
    :param version: the same function can be instantiated with
        different type, this parameter is None or a numpy type
        if the signature allows multiple types, it must be an instance
        of type @see cl FctVersion
    :param fctsig: function used to overwrite the fct signature
        in case this one is using `*args, **kwargs`

    .. versionadded:: 0.6
    """

    def __init__(self, fct, op_version=None, runtime=None, signature=None,
                 version=None, fctsig=None):
        if version is not None and not isinstance(version, FctVersion):
            raise TypeError(  # pragma: no cover
                "version must be of Type 'FctVersion' not %s - %s"
                "." % (type(version), version))
        self.fctsig = fctsig
        if op_version is None:
            from .. import __max_supported_opset__
            op_version = __max_supported_opset__
        if hasattr(fct, 'SerializeToString'):
            # fct is already an ONNX graph, there is nothing to convert
            self.fct_ = None
            self.onnx_ = fct
        else:
            self.fct_ = fct
            if not inspect.isfunction(fct):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type for fct={type(fct)!r}, it must be a function.")
            self.onnx_ = None
            self.onnx_ = self._to_onnx(
                op_version=op_version, signature=signature,
                version=version)
        self.runtime_ = self._build_runtime(
            op_version=op_version, runtime=runtime,
            signature=signature, version=version)
        ann = self._parse_annotation(signature=signature, version=version)
        inputs, outputs, kwargs, n_optional, n_variables = ann
        n_opt = 0 if signature is None else signature.n_optional
        args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        self.meta_ = dict(op_version=op_version, runtime=runtime,
                          signature=signature, version=version,
                          inputs=inputs, outputs=outputs,
                          kwargs=kwargs, n_optional=n_optional,
                          n_variables=n_variables,
                          args=args, kwargs2=kwargs2,
                          annotations=self._fct_annotations())

    def _fct_annotations(self):
        """
        Returns the annotations of function `fct_`, or an empty
        dictionary when there is no function (the compiler wraps
        an ONNX graph or was restored with `__setstate__`).
        """
        fct = getattr(self, 'fct_', None)
        if fct is None:
            return {}
        return fct.__annotations__

    def __getstate__(self):
        """
        Serializes everything but function `fct_`.
        Function `fct_` is used to build the onnx graph
        and is not needed anymore.
        """
        return dict(onnx_=self.onnx_, meta_=self.meta_)

    def __setstate__(self, state):
        """
        Restores serialized data and rebuilds the runtime.
        `fct_` and `fctsig` are not serialized, they are set to None
        so that every attribute the class reads exists.
        """
        self.fct_ = None
        self.fctsig = None
        for k, v in state.items():
            setattr(self, k, v)
        self.runtime_ = self._build_runtime(
            op_version=self.meta_['op_version'],
            runtime=self.meta_['runtime'],
            signature=self.meta_['signature'],
            version=self.meta_['version'])

    def __repr__(self):
        "usual"
        fct = getattr(self, 'fct_', None)
        if fct is not None:
            return f"{self.__class__.__name__}({fct!r})"
        if getattr(self, 'onnx_', None) is not None:
            return f"{self.__class__.__name__}(... ONNX ... )"
        raise NotImplementedError(  # pragma: no cover
            "fct_ and onnx_ are empty.")

    def _to_onnx_shape(self, shape):
        """
        Converts an annotated shape into the shape given to
        @see cl Variable: `Any` or `...` means an unknown shape (None),
        inside a tuple every `Any` or `...` dimension becomes None.

        :raises RuntimeError: for any other kind of shape
        """
        if shape is Any or shape is Ellipsis:
            return None
        if isinstance(shape, tuple):
            return [None if s is Any or s is Ellipsis else s
                    for s in shape]
        raise RuntimeError(  # pragma: no cover
            f"Unexpected annotated shape {shape!r}.")

    def _parse_annotation(self, signature, version):
        """
        Returns the annotations for function `fct_`.

        :param signature: needed if the annotation is missing,
            then version might be needed to specify which type
            to use if the signature allows many
        :param version: version inside the many signatures possible
        :return: *tuple(inputs, outputs, kwargs, n_optional, n_variables)*,
            *inputs* and *outputs* are lists of @see cl Variable,
            *kwargs* is the dictionary of additional parameters
        """
        n_opt = 0 if signature is None else signature.n_optional
        if hasattr(self, 'meta_'):
            args, kwargs = self.meta_['args'], self.meta_['kwargs2']
        else:
            args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        if version is not None:
            nv = len(version) - len(args) - n_opt
            if (signature is not None and not
                    signature.n_variables and nv > len(kwargs)):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch (%d - %d - %d ? %d) between version=%r and kwargs=%r for "
                    "function %r, optional argument is %d, "
                    "signature=%r." % (
                        len(version), len(args), n_opt, len(kwargs),
                        version, kwargs, self.fct_,
                        signature.n_variables, signature))
            vvers = {} if version.kwargs is None else version.kwargs
            # values of version.kwargs are matched with the additional
            # parameters by position
            up = dict(zip(kwargs, vvers))
            kwargs = kwargs.copy()
            kwargs.update(up)

        for k, v in kwargs.items():
            if isinstance(v, (type, numpy.dtype)):
                raise RuntimeError(  # pragma: no cover
                    f"Unexpected value for argument {k!r}: {v!r} from {kwargs!r}.")

        if signature is not None:
            inputs, kwargs, outputs, n_optional, n_variables = (
                signature.get_inputs_outputs(args, kwargs, version))
            inputs = [Variable(i[0], i[1]) for i in inputs]
            outputs = [Variable(i[0], i[1]) for i in outputs]
            return inputs, outputs, kwargs, n_optional, n_variables

        # From here, signature is None: the function annotations are used.
        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        if hasattr(self, 'meta_'):
            annotations = self.meta_['annotations']
        else:
            annotations = self._fct_annotations()
        inputs = []
        for a in args:
            if a == "op_version":
                # given by the compiler, not an input of the graph
                continue
            if a not in annotations:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find annotation for argument %r. "
                    "You should annotate the arguments and the results "
                    "or specify a signature." % a)
            shape, dtype = annotations[a].__args__
            shape = self._to_onnx_shape(shape)
            inputs.append(Variable(a, dtype, shape=shape))

        ret = annotations['return']
        # a tuple annotation means multiple outputs
        rets = ret if isinstance(ret, tuple) else (ret,)
        # output names must differ from input names and from each other
        taken = set(inp.name for inp in inputs)
        outputs = []
        for shape_dtype in rets:
            shape, dtype = shape_dtype.__args__
            shape = self._to_onnx_shape(shape)
            name_out = None
            for name in _possible_names():
                if name not in taken:
                    name_out = name
                    break
            outputs.append(Variable(name_out, dtype, shape=shape))
            taken.add(name_out)
        # no signature: no optional input, no variable number of inputs
        return inputs, outputs, kwargs, 0, False

    def _find_hidden_algebras(self, onx_var, onx_algebra):  # pylint: disable=W0613
        """
        Subgraph are using inputs not linked to the others nodes.
        This function retrieves them as they are stored in
        attributes `alg_hidden_var_`. The function looks into every
        node linked to the inputs and their predecessors.

        :param onx_var: @see cl OnnxVar
        :param onx_algebra: OnnxOperator (unused)
        :return: tuple(dictionary `{id(obj): (var, obj)}`,
            all instance of @see cl OnnxVarGraph)
        """
        keep_hidden = {}
        var_graphs = []
        stack = [onx_var]
        while stack:
            var = stack.pop()
            hidden = getattr(var, 'alg_hidden_var_', None)
            if hidden is not None:
                if any(len(x) > 0
                       for x in var.alg_hidden_var_inputs.values()):
                    keep_hidden.update(hidden)
                    var_graphs.append(var)
            if hasattr(var, 'inputs'):
                stack.extend(var.inputs)
        return keep_hidden, var_graphs

    def _to_onnx(self, op_version=None, signature=None, version=None):
        """
        Returns the onnx graph produced by function `fct_`.
        The graph is built only once and kept in attribute `onnx_`.

        :raises RuntimeError: if no graph can be obtained
        """
        if self.onnx_ is None and self.fct_ is not None:
            from .onnx_variable import OnnxVar
            logger.debug('OnnxNumpyCompiler._to_onnx(op_version=%r, '
                         'signature=%r, version=%r)',
                         op_version, signature, version)
            inputs, outputs, kwargs, n_optional, n_variables = (  # pylint: disable=W0612
                self._parse_annotation(
                    signature=signature, version=version))
            if ((signature is None or not signature.n_variables) and
                    isinstance(version, tuple) and
                    len(inputs) > len(version)):
                raise NotImplementedError(  # pragma: no cover
                    "Mismatch between additional parameters %r "
                    "(n_optional=%r) and version %r for function %r from %r."
                    "" % (kwargs, n_optional, version, self.fct_,
                          getattr(self.fct_, '__module__', None)))
            names_in = [oi.name for oi in inputs]
            names_out = [oi.name for oi in outputs]
            names_var = [OnnxVar(n, dtype=dt.dtype)
                         for n, dt in zip(names_in, inputs)]

            logger.debug('OnnxNumpyCompiler._to_onnx:names_in=%r', names_in)
            logger.debug('OnnxNumpyCompiler._to_onnx:names_out=%r', names_out)

            # A function using op_version receives the input names and
            # returns an onnx operator, the others receive OnnxVar and
            # return an OnnxVar. NOTE(review): co_varnames also lists
            # local variables, not only parameters.
            if 'op_version' in self.fct_.__code__.co_varnames:
                onx_var = None
                onx_algebra = self.fct_(
                    *names_in, op_version=op_version, **kwargs)
            else:
                onx_var = self.fct_(*names_var, **kwargs)
                if not hasattr(onx_var, 'to_algebra'):
                    raise TypeError(  # pragma: no cover
                        "The function %r to convert must return an instance of "
                        "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
                onx_algebra = onx_var.to_algebra(op_version=op_version)

            logger.debug('OnnxNumpyCompiler._to_onnx:onx_var=%r',
                         type(onx_var))
            logger.debug('OnnxNumpyCompiler._to_onnx:onx_algebra=%r',
                         type(onx_algebra))

            if not isinstance(onx_algebra, (OnnxOperator, OnnxOperatorTuple)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for onx_algebra %r "
                    "(It should be OnnxOperator or OnnxOperatorTuple), "
                    "function is %r." % (type(onx_algebra), self.fct_))
            hidden_algebras, var_graphs = self._find_hidden_algebras(
                onx_var, onx_algebra)
            if len(hidden_algebras) > 0:
                logger.debug(  # pragma: no cover
                    'OnnxNumpyCompiler._to_onnx:len(hidden_algebras)=%r',
                    len(hidden_algebras))
                raise NotImplementedError(  # pragma: no cover
                    "Subgraphs only support constants (operator If, Loop, "
                    "Scan). hidden_algebras=%r var_graphs=%r" % (
                        hidden_algebras, var_graphs))

            if isinstance(onx_algebra, str):
                raise RuntimeError(  # pragma: no cover
                    f"Unexpected str type {onx_algebra!r}.")
            if isinstance(onx_algebra, tuple):
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented when the function returns multiple results.")
            if hasattr(onx_algebra, 'to_onnx'):
                onx_algebra.output_names = [Variable(n) for n in names_out]
                onx = onx_algebra.to_onnx(
                    inputs=inputs, target_opset=op_version, outputs=outputs)
                # optimisation
                onx_optimized = onnx_optimisations(onx)
                self.onnx_ = onx_optimized

        if self.onnx_ is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to get the ONNX graph (class %r, fct_=%r)" % (
                    type(self), self.fct_))
        return self.onnx_

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        :return: ONNX graph
        """
        if len(kwargs) > 0:
            raise NotImplementedError(  # pragma: no cover
                "kwargs is not empty, this case is not implemented. "
                "kwargs=%r." % kwargs)
        if hasattr(self, 'onnx_'):
            return self.onnx_
        raise NotImplementedError(  # pragma: no cover
            "Attribute 'onnx_' is missing.")

    def _build_runtime(self, op_version=None, runtime=None,
                       signature=None, version=None):
        """
        Creates the runtime for the :epkg:`ONNX` graph.

        :param op_version: :epkg:`ONNX` opset to use, None
            for the latest one
        :param runtime: runtime to choose to execute the onnx graph,
            `python`, `onnxruntime`, `onnxruntime1`
        :param signature: used when the function is not annotated
        :param version: version inside the many signatures possible
        :return: the callable stored in attribute `rt_fct_`
        """
        onx = self._to_onnx(op_version=op_version, signature=signature,
                            version=version)
        inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
            signature=signature, version=version)
        if runtime not in ('onnxruntime', 'onnxruntime-cuda'):
            from ..onnxrt import OnnxInference
            rt = OnnxInference(onx, runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionOnnxInference(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        else:
            from ..tools.ort_wrapper import InferenceSession
            rt = InferenceSession(onx.SerializeToString(), runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionInferenceSession(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        return self.rt_fct_

    def __call__(self, *args, **kwargs):
        """
        Executes the function and returns the results.

        :param args: arguments
        :return: results, a single value if the graph has one output,
            a tuple otherwise
        """
        res = self.rt_fct_(*args, **kwargs)
        if len(res) == 1:
            return res[0]
        return res