Coverage for mlprodict/onnxrt/ops_cpu/op_cast.py: 100%
32 statements
Generated by coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.onnx_pb import TensorProto # pylint: disable=E0611
9from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
10from ._op import OpRun
class Cast(OpRun):
    """
    Runtime for the *Cast* operator: converts its single input to the
    numpy dtype selected by the ``to`` attribute.

    ``TensorProto.STRING`` is special-cased to ``numpy.str_``; every other
    value of ``to`` is looked up in ``TENSOR_TYPE_TO_NP_TYPE``.
    """

    # Attributes handed to OpRun.__init__ as expected_attributes.
    # NOTE(review): presumably OpRun exposes each one as an instance
    # attribute (``self.to`` below) — confirm in ._op.
    atts = {'to': None}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Cast.atts,
                       **options)
        if self.to == TensorProto.STRING:  # pylint: disable=E1101
            # Strings are not resolved through the mapping table.
            self._dtype = numpy.str_
        else:
            try:
                self._dtype = TENSOR_TYPE_TO_NP_TYPE[self.to]
            except KeyError as e:
                # Keep the exception type callers may catch, but report
                # which value of ``to`` could not be mapped.
                raise KeyError(
                    "Unsupported value to=%r for operator Cast."
                    % (self.to, )) from e

    def _cast(self, x):
        """Returns a new array holding *x* converted to ``self._dtype``."""
        return x.astype(self._dtype)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Casts *x* and returns the result as a 1-tuple.

        When input 0 is flagged in ``self.inplaces`` and *x* is writeable,
        delegates to :meth:`_run_inplace`, which may return *x* itself.
        Otherwise a converted copy is always returned.
        """
        if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
            return self._run_inplace(x)
        return (self._cast(x), )

    def _run_inplace(self, x):
        """
        Returns ``(x, )`` unchanged when *x* already has the target dtype,
        otherwise a 1-tuple holding a converted copy.
        """
        if x.dtype == self._dtype:
            return (x, )
        return (self._cast(x), )
class CastLike(OpRun):
    """
    Runtime for the *CastLike* operator: converts the first input to the
    dtype of the second input (only ``y.dtype`` is used, not its values).
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, x, y, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns a 1-tuple holding *x* converted to ``y.dtype``.

        The in-place path is taken only when input 0 is flagged in
        ``self.inplaces`` and *x* is writeable; otherwise a copy is made.
        """
        may_reuse_input = self.inplaces.get(0, False) and x.flags['WRITEABLE']
        if not may_reuse_input:
            return (x.astype(y.dtype), )
        return self._run_inplace(x, y)

    def _run_inplace(self, x, y):
        """
        Returns ``(x, )`` itself when its dtype already matches ``y.dtype``,
        otherwise a 1-tuple holding a converted copy.
        """
        wanted = y.dtype
        result = x if x.dtype == wanted else x.astype(wanted)
        return (result, )