Coverage for mlprodict/onnxrt/ops_cpu/op_expression.py: 84%

37 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from ...onnx_tools.onnx2py_helper import guess_dtype 

8from ._op import OpRun 

9from ._new_ops import OperatorSchema 

10 

11 

class Expression(OpRun):
    """
    Runtime for operator *Expression*: runs the object stored in
    attribute ``expression`` and returns its outputs.
    """

    atts = {
        'expression': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Checks attribute ``expression`` exposes a method ``run`` and
        caches what is needed to execute it.

        :param onnx_node: ONNX node, its first attribute is expected to
            hold the graph of the expression
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises RuntimeError: if ``expression`` has no method ``run``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Expression.atts,
                       **options)
        if not hasattr(self.expression, 'run'):
            # The message must report the type of 'expression'
            # (this class has no attribute 'then_branch').
            raise RuntimeError(  # pragma: no cover
                "Parameter 'expression' must have a method 'run', "
                f"type {type(self.expression)}.")

        # 'run_in_scan' is preferred when the expression provides it.
        self._run_expression = (self.expression.run_in_scan
                                if hasattr(self.expression, 'run_in_scan')
                                else self.expression.run)
        self.additional_inputs = list(self.expression.static_inputs)
        # Names of the inputs declared by the graph stored
        # in the first attribute of the node.
        self.input_names = [
            i.name for i in self.onnx_node.attribute[0].g.input]

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of operator *Expression*.

        :param op_name: operator name
        :return: instance of ``ExpressionSchema``
        :raises RuntimeError: if *op_name* is not ``'Expression'``
        """
        if op_name == "Expression":
            return ExpressionSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        This operator always answers `True`.
        """
        return True

    def _run(self, *inputs, named_inputs=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):
        """
        Runs the expression.

        :param inputs: positional inputs, matched with ``self.input_names``
            when *named_inputs* is None
        :param named_inputs: dictionary ``{name: value}``, built from
            *inputs* if None
        :param context: forwarded to the expression
        :param attributes: forwarded to the expression
        :param verbose: verbosity, logs if positive and *fLOG* is given
        :param fLOG: logging function
        :return: tuple of results ordered as ``self.expression.output_names``
        :raises RuntimeError: if the number of positional inputs does not
            match the number of expected input names
        """
        if verbose > 0 and fLOG is not None:
            # context defaults to None, list(None) would fail.
            known = [] if context is None else list(context)
            fLOG(  # pragma: no cover
                f'  -- expression> {known!r}')
        if named_inputs is None:
            if len(inputs) != len(self.input_names):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs (%d != %d): %r." % (
                        len(inputs), len(self.input_names), self.input_names))
            named_inputs = dict(zip(self.input_names, inputs))
        outputs = self._run_expression(named_inputs, context=context,
                                       attributes=attributes,
                                       verbose=verbose, fLOG=fLOG)
        if verbose > 0 and fLOG is not None:
            fLOG('  -- expression<')  # pragma: no cover
        return tuple(outputs[name] for name in self.expression.output_names)

    def _pick_type(self, res, name):
        """
        Returns the type associated to result *name*.

        :param res: dictionary of known results
        :param name: result name
        :return: ``res[name]`` if available, otherwise the type guessed
            from the element type of the matching graph output
        :raises ValueError: if *name* is neither in *res* nor an output
            of the expression graph
        """
        if name in res:
            return res[name]
        out = {o.name: o for o in self.expression.obj.graph.output}
        if name not in out:
            raise ValueError(  # pragma: no cover
                "Unable to find name=%r in %r or %r." % (
                    name, sorted(res), sorted(out)))
        dt = out[name].type.tensor_type.elem_type
        return guess_dtype(dt)

81 

82 

class ExpressionSchema(OperatorSchema):
    """
    Defines the schema of operator *Expression*,
    an operator added in this package.
    """

    def __init__(self):
        super().__init__('Expression')
        # The schema shares the attribute dictionary
        # declared by the runtime class.
        self.attributes = Expression.atts