Coverage for mlprodict/onnxrt/ops_cpu/op_clip.py: 100%
48 statements
« prev ^ index » next — coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from onnx.defs import onnx_opset_version
10from ._op import OpRunUnaryNum
class Clip_6(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip* in its attribute form: the lower and
    upper bounds come from the node attributes *min* and *max*
    (read below as ``self.min`` and ``self.max``).
    """

    # Default bounds are minus/plus the largest finite float32 value.
    atts = {'min': -3.4028234663852886e+38,
            'max': 3.4028234663852886e+38}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Clip_6.atts,
                               **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Clips *data* between ``self.min`` and ``self.max``.

        The computation happens in place when input 0 is flagged as
        overwritable and the array is writeable. Otherwise a new array
        is produced and cast back to the input dtype if numpy promoted it.

        :param data: numpy array to clip
        :return: 1-tuple holding the clipped array
        """
        overwrite_allowed = (self.inplaces.get(0, False)
                             and data.flags['WRITEABLE'])
        if overwrite_allowed:
            return self._run_inplace(data)
        clipped = numpy.clip(data, self.min, self.max)
        if clipped.dtype != data.dtype:
            # Keep the output dtype identical to the input dtype.
            clipped = clipped.astype(data.dtype)
        return (clipped, )

    def _run_inplace(self, data):
        """
        Clips *data* in place, writing the result back into *data*.

        :param data: writeable numpy array
        :return: 1-tuple holding the result of ``numpy.clip(..., out=data)``
        """
        lowest, highest = self.min, self.max
        return (numpy.clip(data, lowest, highest, out=data), )

    def to_python(self, inputs):
        """
        Returns the imports and the python code equivalent to this operator.

        :param inputs: input names, only the first one is used
        :return: tuple ``(imports, code)``
        """
        body = f"return numpy.clip({inputs[0]}, min_, max_)"
        return ("import numpy", body)
class Clip_11(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Clip* in its input form (opset >= 11):
    the lower and upper bounds are the optional inputs *min* and *max*
    instead of attributes.
    """

    version_higher_than = 11
    mandatory_inputs = ['X']
    # Default values are minus/plus the largest finite float32 value.
    optional_inputs = OrderedDict([
        ('min', -3.4028234663852886e+38),
        ('max', 3.4028234663852886e+38)
    ])

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               **options)

    def run(self, x, *minmax, attributes=None, verbose=0, fLOG=None):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.

        :param x: array to clip
        :param minmax: optional bounds *min* then *max* (each may be None)
        :return: 1-tuple holding the clipped array
        :raises TypeError: re-raised with the types of every input
        """
        try:
            res = self._run(x, *minmax, attributes=attributes,
                            verbose=verbose, fLOG=fLOG)
        except TypeError as e:  # pragma: no cover
            # Report every input type (bounds included), not only x.
            raise TypeError("Issues with types {} (binary operator {}).".format(
                ", ".join(str(type(_)) for _ in (x, ) + tuple(minmax)),
                self.__class__.__name__)) from e
        return res

    @staticmethod
    def _split_bounds(minmax):
        """
        Extracts the optional bounds from the extra positional inputs.

        :param minmax: sequence of 0, 1 or 2 elements (*min*, *max*)
        :return: tuple ``(amin, amax)``, a missing bound is None
        """
        amin = minmax[0] if len(minmax) > 0 else None
        amax = minmax[1] if len(minmax) > 1 else None
        return amin, amax

    def _run(self, data, *minmax, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Clips *data* between the optional bounds.

        The computation happens in place when input 0 is flagged as
        overwritable and the array is writeable.

        :param data: numpy array to clip
        :param minmax: optional bounds *min* then *max*
        :return: 1-tuple holding the clipped array, same dtype as *data*
        """
        if self.inplaces.get(0, False) and data.flags['WRITEABLE']:
            return self._run_inplace(data, *minmax)
        amin, amax = self._split_bounds(minmax)
        if amin is None and amax is None:
            # numpy.clip rejects two None bounds (numpy 1.x raises
            # ValueError); -inf leaves every value unchanged.
            amin = -numpy.inf
        res = numpy.clip(data, amin, amax)
        return (res, ) if res.dtype == data.dtype else (res.astype(data.dtype), )

    def _run_inplace(self, data, *minmax):  # pylint: disable=W0221
        """
        Clips *data* in place, writing the result back into *data*.

        :param data: writeable numpy array
        :param minmax: optional bounds *min* then *max*
        :return: 1-tuple holding *data*
        """
        amin, amax = self._split_bounds(minmax)
        if amin is None and amax is None:
            # Nothing to clip. Calling numpy.clip with two None bounds
            # would raise (numpy 1.x), unlike the non in-place path.
            return (data, )
        res = numpy.clip(data, amin, amax, out=data)
        return (res, )

    def to_python(self, inputs):
        """
        Returns the imports and the python code equivalent to this operator.

        :param inputs: input names, only the first one is used
        :return: tuple ``(imports, code)``
        """
        # NOTE(review): min/max are inputs for this version, not attributes;
        # confirm the code generator defines ``min_``/``max_`` here
        # (``inputs[1]``/``inputs[2]`` may be the intended names).
        return ("import numpy",
                f"return numpy.clip({inputs[0]}, min_, max_)")
# Public alias: use the input-based implementation whenever the installed
# onnx package supports opset 11, otherwise fall back to the attribute one.
Clip = Clip_11 if onnx_opset_version() >= 11 else Clip_6