Coverage for mlprodict/onnxrt/ops_cpu/op_clip.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from onnx.defs import onnx_opset_version 

10from ._op import OpRunUnaryNum 

11 

12 

class Clip_6(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip*, attribute-based version
    (``Clip_6``): the bounds are the node attributes *min* and *max*,
    defaulting to +/-3.4028234663852886e+38 (the largest finite float32).
    """

    # Expected attributes with their default values.
    atts = {'min': -3.4028234663852886e+38,
            'max': 3.4028234663852886e+38}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Clip_6.atts, **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Clips *data* into ``[self.min, self.max]``.

        The computation happens in place when input 0 is registered as
        in-place and *data* is writeable. Otherwise a new array is
        produced and cast back to the dtype of *data* if numpy promoted it.
        """
        can_overwrite = self.inplaces.get(0, False) and data.flags['WRITEABLE']
        if can_overwrite:
            return self._run_inplace(data)
        clipped = numpy.clip(data, self.min, self.max)
        if clipped.dtype != data.dtype:
            clipped = clipped.astype(data.dtype)
        return (clipped, )

    def _run_inplace(self, data):
        """
        Clips *data* and writes the result back into *data* itself.
        """
        clipped = numpy.clip(data, self.min, self.max, out=data)
        return (clipped, )

    def to_python(self, inputs):
        """
        Returns the import statement and the python code computing
        this node. ``min_`` and ``max_`` are not defined here, presumably
        the caller binds them to the attribute values (TODO confirm).
        """
        code = f"return numpy.clip({inputs[0]}, min_, max_)"
        return ("import numpy", code)

35 

36 

class Clip_11(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip*, input-based version (``Clip_11``):
    the bounds *min* and *max* are optional inputs instead of attributes.
    A missing bound means no clipping on that side.
    """

    version_higher_than = 11
    mandatory_inputs = ['X']
    optional_inputs = OrderedDict([
        ('min', -3.4028234663852886e+38),
        ('max', 3.4028234663852886e+38)
    ])

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               **options)

    def run(self, x, *minmax, attributes=None, verbose=0, fLOG=None):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.

        :param x: tensor to clip
        :param minmax: optional bounds, *min* first then *max*
        :return: tuple with one element, the clipped tensor
        :raises TypeError: if ``_run`` raises a TypeError, rethrown
            with the types of every input for easier debugging
        """
        try:
            res = self._run(x, *minmax, attributes=attributes,
                            verbose=verbose, fLOG=fLOG)
        except TypeError as e:  # pragma: no cover
            # Report every input, not only x: the bounds are inputs too.
            raise TypeError("Issues with types {} (binary operator {}).".format(
                ", ".join(str(type(_)) for _ in (x, ) + minmax),
                self.__class__.__name__)) from e
        return res

    @staticmethod
    def _bounds(minmax):
        """
        Extracts the optional bounds from the extra inputs.

        :param minmax: tuple of extra inputs (possibly empty)
        :return: ``(amin, amax)``, each one None when not provided
        """
        amin = minmax[0] if len(minmax) > 0 else None
        amax = minmax[1] if len(minmax) > 1 else None
        return amin, amax

    def _run(self, data, *minmax, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Clips *data* with the optional bounds in *minmax*.

        Works in place when input 0 is registered as in-place and *data*
        is writeable. Otherwise returns a new array with the dtype of
        *data*.
        """
        if self.inplaces.get(0, False) and data.flags['WRITEABLE']:
            return self._run_inplace(data, *minmax)
        amin, amax = self._bounds(minmax)
        if amin is None and amax is None:
            # No bound, so clipping is the identity. numpy.clip (at least
            # numpy 1.x) raises ValueError when both bounds are None.
            # Substituting -inf would promote integer inputs to float64 and
            # lose precision for large int64/uint64 values.
            return (data.copy(), )
        res = numpy.clip(data, amin, amax)
        # numpy may promote the dtype (e.g. with a python float bound).
        return (res, ) if res.dtype == data.dtype else (res.astype(data.dtype), )

    def _run_inplace(self, data, *minmax):  # pylint: disable=W0221
        """
        Clips *data* with the optional bounds in *minmax* and writes the
        result back into *data* itself.
        """
        amin, amax = self._bounds(minmax)
        if amin is None and amax is None:
            # Same as in _run: nothing to clip, data is already the result.
            return (data, )
        res = numpy.clip(data, amin, amax, out=data)
        return (res, )

    def to_python(self, inputs):
        """
        Returns the import statement and the python code computing
        this node. ``min_`` and ``max_`` are not defined here, presumably
        the caller binds them (TODO confirm).
        """
        return ("import numpy",
                f"return numpy.clip({inputs[0]}, min_, max_)")

84 

85 

# Expose the implementation matching the opset supported by the
# installed onnx package: the input-based version from opset 11 on,
# the attribute-based version before.
Clip = Clip_11 if onnx_opset_version() >= 11 else Clip_6