Coverage for mlprodict/onnxrt/ops_cpu/op_constant.py: 96%
57 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun, RefAttrName
def _check_dtype(val):
    """
    Validates the element type of a constant.

    :param val: object exposing a ``dtype`` attribute (an
        AttributeError is raised by the attribute access otherwise)
    :raises TypeError: if ``val.dtype`` is not a :class:`numpy.dtype`
        instance and not one of the accepted scalar types
    """
    element_type = val.dtype
    if isinstance(element_type, numpy.dtype):
        # every numpy array ends up here
        return
    accepted_types = {
        numpy.bool_, numpy.int8, numpy.int16, numpy.int32, numpy.int64,
        numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64,
        numpy.float16, numpy.float32, numpy.float64, numpy.str_,
        bool, str,
    }
    if element_type in accepted_types:
        return
    raise TypeError(  # pragma: no cover
        f"Type ({element_type}, {type(element_type)}) is not a numpy "
        f"type (operator 'Constant')")
class Constant_9(OpRun):
    """
    Constant operator, implementation selected when the installed onnx
    opset is lower than 11. The output is the tensor held by attribute
    ``value``.
    """

    # default values given to OpRun as *expected_attributes*
    atts = {'value': numpy.array([0], dtype=numpy.float32)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Constant_9.atts, **options)
        constant = self.value
        self.cst = constant
        _check_dtype(constant)

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns the stored constant as a one-element tuple,
        *attributes*, *verbose* and *fLOG* are ignored.
        """
        result = self.cst
        return result,
class Constant_11(OpRun):
    """
    Constant operator, implementation selected when the installed onnx
    opset is 11. The output is attribute ``sparse_value`` when it is
    set (not None), attribute ``value`` otherwise.
    """

    # default values given to OpRun as *expected_attributes*
    atts = {
        'value': numpy.array([0], dtype=numpy.float32),
        'sparse_value': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Constant_11.atts, **options)
        sparse = getattr(self, 'sparse_value', None)
        self.cst = self.value if sparse is None else sparse
        _check_dtype(self.cst)

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns the stored constant as a one-element tuple,
        *attributes*, *verbose* and *fLOG* are ignored.
        """
        return self.cst,
class Constant_12(OpRun):
    """
    Constant operator, implementation selected when the installed onnx
    opset is 12 or higher.

    The constant comes from the first attribute set (not None) in this
    order: ``sparse_value``, ``value_float``, ``value_floats``,
    ``value_int``, ``value_ints``, ``value_string``, ``value_strings``,
    and ``value`` last. ``atts`` declares a non-None default for
    ``value``, so it must come last or it would hide the other
    attributes (assuming OpRun applies these defaults to missing
    attributes).

    The selected attribute may be a :class:`RefAttrName`, a reference to
    another attribute. It is resolved from *attributes* in `_run`.
    """

    # default values given to OpRun as *expected_attributes*
    atts = {'value': numpy.array([0], dtype=numpy.float32),
            'sparse_value': None,
            'value_float': None,
            'value_floats': None,
            'value_int': None,
            'value_ints': None,
            'value_string': None,
            'value_strings': None,
            }

    # Typed attributes in priority order and the element type of the
    # tensor each of them defines (ONNX specification of Constant).
    _typed_attributes = (
        ('value_float', numpy.float32),
        ('value_floats', numpy.float32),
        ('value_int', numpy.int64),
        ('value_ints', numpy.int64),
        ('value_string', numpy.str_),
        ('value_strings', numpy.str_),
    )

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_12.atts,
                       **options)
        self.cst = self._select_constant()
        self.is_linked_attribute = isinstance(self.cst, RefAttrName)
        if not self.is_linked_attribute:
            # A RefAttrName is only a reference, it carries no tensor
            # type, the actual value is retrieved in _run.
            _check_dtype(self.cst)

    def _select_constant(self):
        """
        Returns the constant defined by the node attributes.

        :return: a tensor, or a :class:`RefAttrName` for a linked attribute
        :raises AttributeError: if no attribute defines the constant
        """
        sparse = getattr(self, 'sparse_value', None)
        if sparse is not None:
            return sparse
        for name, dtype in Constant_12._typed_attributes:
            value = getattr(self, name, None)
            if value is not None:
                return Constant_12._to_tensor(value, dtype)
        value = getattr(self, 'value', None)
        if value is not None:
            return value
        raise AttributeError(  # pragma: no cover
            "No constant is defined for operator 'Constant'.")

    @staticmethod
    def _to_tensor(value, dtype):
        """
        Converts the value of a typed attribute into a numpy tensor.

        :param value: attribute value as stored on the operator
        :param dtype: element type of the expected tensor
        :return: a numpy array, or *value* unchanged when it is a
            :class:`RefAttrName` (resolved later in `_run`)
        """
        if isinstance(value, RefAttrName):
            return value
        if dtype is numpy.str_:
            if hasattr(value, 'dtype'):
                # already a tensor, its dtype (str or object) is kept
                return value
            if isinstance(value, bytes):
                value = value.decode('utf-8')
            elif isinstance(value, (list, tuple)):
                value = [v.decode('utf-8') if isinstance(v, bytes) else v
                         for v in value]
            return numpy.array(value, dtype=numpy.str_)
        # A python scalar gives a 0-D tensor, which is what the ONNX
        # specification requires for value_float and value_int; lists and
        # arrays give 1-D tensors for value_floats and value_ints.
        return numpy.array(value, dtype=dtype)

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns the constant.

        :param attributes: dictionary ``{name: {'value': ...}}`` used to
            resolve a linked attribute, ignored otherwise
        :param verbose: unused
        :param fLOG: unused
        :return: tuple with one element
        :raises RuntimeError: if the constant is a linked attribute and
            *attributes* is None or does not contain it
        """
        if self.is_linked_attribute:
            if attributes is None:
                raise RuntimeError(  # pragma: no cover
                    f"Attributes are empty, cannot retrieve value for {self.cst!r}.")
            if self.cst.name not in attributes:
                raise RuntimeError(  # pragma: no cover
                    f"Cannot find attribute {self.cst!r} in {list(attributes)!r}.")
            return (attributes[self.cst.name]['value'], )
        return (self.cst, )
# Default implementation: the most recent one the installed onnx
# package supports (opset 12+, then 11, then 9).
Constant = (
    Constant_12 if onnx_opset_version() >= 12
    else Constant_11 if onnx_opset_version() >= 11
    else Constant_9)