Coverage for mlprodict/onnxrt/ops_cpu/op_expression.py: 84%
37 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from ...onnx_tools.onnx2py_helper import guess_dtype
8from ._op import OpRun
9from ._new_ops import OperatorSchema
class Expression(OpRun):
    """
    Runtime for the custom operator *Expression*.

    The operator wraps a sub-expression stored in attribute
    ``expression``. Running the node runs that sub-expression and
    returns its outputs in the order given by
    ``self.expression.output_names``.
    """

    # Attributes expected by the operator, also reused by the schema.
    atts = {
        'expression': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the runtime and caches what `_run` needs.

        :param onnx_node: node being wrapped, its first attribute
            must hold a graph (``attribute[0].g``)
        :param desc: forwarded to `OpRun.__init__`
        :param options: forwarded to `OpRun.__init__`
        :raises RuntimeError: if ``self.expression`` has no method ``run``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Expression.atts,
                       **options)
        # NOTE(review): self.expression is presumably set by
        # OpRun.__init__ from the node attributes -- confirm.
        if not hasattr(self.expression, 'run'):
            # The message reports the type of the faulty attribute
            # itself; the class defines no attribute 'then_branch'.
            raise RuntimeError(  # pragma: no cover
                "Parameter 'expression' must have a method 'run', "
                f"type {type(self.expression)}.")

        # Method run_in_scan is preferred over run when it is available.
        self._run_expression = (self.expression.run_in_scan
                                if hasattr(self.expression, 'run_in_scan')
                                else self.expression.run)
        self.additional_inputs = list(self.expression.static_inputs)
        # Input names of the graph stored in the first attribute;
        # positional inputs received by _run are mapped onto them.
        self.input_names = [
            i.name for i in self.onnx_node.attribute[0].g.input]

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.

        :param op_name: operator name, only ``'Expression'`` is known
        :return: an instance of ``ExpressionSchema``
        :raises RuntimeError: for any other operator name
        """
        if op_name == "Expression":
            return ExpressionSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them. This operator always answers `True`.
        """
        return True

    def _run(self, *inputs, named_inputs=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):
        """
        Runs the sub-expression.

        :param inputs: positional inputs, matched with
            ``self.input_names`` when *named_inputs* is None
        :param named_inputs: inputs as a dictionary ``{name: value}``,
            built from *inputs* if None
        :param context: forwarded to the sub-expression
        :param attributes: forwarded to the sub-expression
        :param verbose: logs through *fLOG* if positive
        :param fLOG: logging function, may be None
        :return: tuple of the outputs listed in
            ``self.expression.output_names``, in that order
        :raises RuntimeError: if *named_inputs* is None and the number
            of positional inputs differs from the number of input names
        """
        log = verbose > 0 and fLOG is not None
        if log:
            # context may be None, an empty list is logged in that case
            # instead of failing on list(None).
            fLOG(  # pragma: no cover
                f' -- expression> {list(context or ())!r}')
        if named_inputs is None:
            if len(inputs) != len(self.input_names):
                raise RuntimeError(  # pragma: no cover
                    f"Unpexpected number of inputs ({len(inputs)} != "
                    f"{len(self.input_names)}): {self.input_names!r}.")
            named_inputs = dict(zip(self.input_names, inputs))
        outputs = self._run_expression(named_inputs, context=context,
                                       attributes=attributes,
                                       verbose=verbose, fLOG=fLOG)
        if log:
            fLOG(' -- expression<')  # pragma: no cover
        return tuple(outputs[name]
                     for name in self.expression.output_names)

    def _pick_type(self, res, name):
        """
        Returns the type associated to result *name*.

        :param res: known types, ``{name: type}``
        :param name: result name
        :return: ``res[name]`` if known, otherwise the type guessed by
            ``guess_dtype`` from the element type of the output *name*
            of the graph ``self.expression.obj.graph``
        :raises ValueError: if *name* is neither in *res* nor an output
            of the graph
        """
        if name in res:
            return res[name]
        out = {o.name: o for o in self.expression.obj.graph.output}
        if name not in out:
            raise ValueError(  # pragma: no cover
                f"Unable to find name={name!r} in "
                f"{sorted(res)!r} or {sorted(out)!r}.")
        dt = out[name].type.tensor_type.elem_type
        return guess_dtype(dt)
class ExpressionSchema(OperatorSchema):
    """
    Schema of the custom operator *Expression*.

    It is registered under the name ``'Expression'`` and exposes
    the attributes declared in ``Expression.atts``.
    """

    def __init__(self):
        super().__init__('Expression')
        # Same dictionary object as the one held by the runtime class.
        self.attributes = Expression.atts