Coverage for mlprodict/onnxrt/ops_cpu/op_constant.py: 96%

57 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun, RefAttrName 

10 

11 

12def _check_dtype(val): 

13 a = val.dtype 

14 if not isinstance(a, numpy.dtype) and a not in { 

15 numpy.int8, numpy.uint8, numpy.float16, numpy.float32, 

16 numpy.float64, numpy.int32, numpy.int64, numpy.int16, 

17 numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_, 

18 numpy.uint64, bool, str, }: 

19 raise TypeError( # pragma: no cover 

20 f"Type ({a}, {type(a)}) is not a numpy type (operator 'Constant')") 

21 

22 

class Constant_9(OpRun):
    """
    ``Constant`` operator variant whose only attribute is ``value``.
    Execution returns that value unchanged.
    """

    # Expected attributes and their default values.
    atts = {'value': numpy.array([0], dtype=numpy.float32)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_9.atts,
                       **options)
        # The constant is fixed once at construction and validated.
        constant = self.value
        self.cst = constant
        _check_dtype(constant)

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """Returns the stored constant wrapped in a one-element tuple."""
        return self.cst,

36 

37 

class Constant_11(OpRun):
    """
    ``Constant`` operator variant accepting either ``value`` or
    ``sparse_value``. When ``sparse_value`` is set (not None) it takes
    precedence over ``value``.
    """

    # Expected attributes and their default values.
    atts = {'value': numpy.array([0], dtype=numpy.float32),
            'sparse_value': None, }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_11.atts,
                       **options)
        sparse = getattr(self, 'sparse_value', None)
        # Fall back on the dense value only when no sparse one is given.
        self.cst = self.value if sparse is None else sparse
        _check_dtype(self.cst)

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """Returns the stored constant wrapped in a one-element tuple."""
        return self.cst,

55 

56 

class Constant_12(OpRun):
    """
    ``Constant`` operator variant accepting every constant attribute:
    ``sparse_value``, ``value_float``, ``value_floats``, ``value_int``,
    ``value_ints``, ``value_string``, ``value_strings`` and ``value``.
    The first one set (not None), in that order, defines the output.

    The selected attribute may also be a :class:`RefAttrName`. Its value
    is then looked up by name, at execution time, in the *attributes*
    given to :meth:`_run`.
    """

    # Expected attributes and their default values.
    atts = {'value': numpy.array([0], dtype=numpy.float32),
            'sparse_value': None,
            'value_float': None,
            'value_floats': None,
            'value_int': None,
            'value_ints': None,
            'value_string': None,
            'value_strings': None,
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_12.atts,
                       **options)
        self.cst = self._select_constant()
        self.is_linked_attribute = isinstance(self.cst, RefAttrName)
        if not self.is_linked_attribute:
            # A reference has no dtype yet: its value is only known
            # when _run receives the attributes, so it is not checked.
            _check_dtype(self.cst)

    def _select_constant(self):
        """
        Returns the constant defined by the first attribute set among
        ``sparse_value``, ``value_float``, ``value_floats``, ``value_int``,
        ``value_ints``, ``value_string``, ``value_strings``, ``value``.

        Scalar and list attributes are converted into numpy arrays.
        ``value_float`` and ``value_int`` become 0-D arrays, matching the
        ONNX definition of a scalar output. References
        (:class:`RefAttrName`) are returned untouched.

        :raises AttributeError: if none of the attributes is set
        """
        candidates = (
            ('sparse_value', None),
            ('value_float', lambda v: numpy.array(v, dtype=numpy.float32)),
            ('value_floats', lambda v: numpy.array(v, dtype=numpy.float32)),
            ('value_int', lambda v: numpy.array(v, dtype=numpy.int64)),
            ('value_ints', lambda v: numpy.array(v, dtype=numpy.int64)),
            ('value_string', numpy.array),
            ('value_strings', numpy.array),
            ('value', None),
        )
        for name, convert in candidates:
            val = getattr(self, name, None)
            if val is None:
                continue
            if convert is None or isinstance(val, RefAttrName):
                return val
            return convert(val)
        raise AttributeError(  # pragma: no cover
            "No constant is defined for operator 'Constant'.")

    def _run(self, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns the constant wrapped in a one-element tuple.

        :param attributes: required when the constant is a
            :class:`RefAttrName`; maps attribute names to dictionaries
            holding the actual value under key ``'value'``
        :param verbose: unused
        :param fLOG: unused
        :raises RuntimeError: if *attributes* is missing or does not
            contain the referenced attribute
        """
        if not self.is_linked_attribute:
            return (self.cst, )
        if attributes is None:
            raise RuntimeError(  # pragma: no cover
                f"Attributes are empty, cannot retrieve value for {self.cst!r}.")
        if self.cst.name not in attributes:
            raise RuntimeError(  # pragma: no cover
                f"Cannot find attribute {self.cst!r} in {list(attributes)!r}.")
        return (attributes[self.cst.name]['value'], )

108 

109 

# Bind ``Constant`` to the most recent implementation the installed
# onnx package supports, judged by onnx_opset_version().
if onnx_opset_version() < 11:  # pragma: no cover
    Constant = Constant_9
elif onnx_opset_version() < 12:  # pragma: no cover
    Constant = Constant_11
else:
    Constant = Constant_12