Coverage for mlprodict/onnxrt/ops_cpu/op_cast.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.onnx_pb import TensorProto # pylint: disable=E0611 

9from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

10from ._op import OpRun 

11 

12 

class Cast(OpRun):
    # Runtime for ONNX operator *Cast*: converts the input tensor to the
    # element type given by the node attribute *to*.

    atts = {'to': None}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Cast.atts,
                       **options)
        if self.to == TensorProto.STRING:  # pylint: disable=E1101
            # Flexible unicode type: numpy chooses a width large enough
            # for every converted value instead of truncating.
            self._dtype = numpy.str_
        else:
            self._dtype = TENSOR_TYPE_TO_NP_TYPE[self.to]

    def _cast(self, x):
        """
        Returns *x* converted to the target dtype (always a new array,
        ``astype`` copies by default).

        A regular method rather than a lambda stored on the instance:
        a lambda closing over *self* creates a reference cycle and makes
        the operator unpicklable.
        """
        return x.astype(self._dtype)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Casts *x* and returns a 1-tuple with the result.

        When the runtime allows reusing input 0 (``self.inplaces``,
        presumably filled by *OpRun*) and *x* is writeable, the in-place
        path may return *x* itself instead of a copy.
        """
        if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
            return self._run_inplace(x)
        return (self._cast(x), )

    def _run_inplace(self, x):
        """
        Returns *x* unchanged if it already has the target dtype,
        otherwise a converted copy.
        """
        if x.dtype == self._dtype:
            return (x, )
        return (self._cast(x), )

36 

37 

class CastLike(OpRun):
    # Runtime for ONNX operator *CastLike*: converts the first input to
    # the element type of the second input.

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    @staticmethod
    def _target_dtype(y):
        """
        Returns the dtype *x* must be converted to so that it matches *y*.

        Fixed-width string dtypes (``<U3``, ``|S3``...) are replaced by the
        corresponding flexible type: casting to ``y.dtype`` directly would
        truncate every converted value to the width of *y*.
        """
        kind = y.dtype.kind
        if kind == 'U':
            return numpy.str_
        if kind == 'S':
            return numpy.bytes_
        return y.dtype

    def _run(self, x, y, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Casts *x* to the element type of *y* and returns a 1-tuple.

        When the runtime allows reusing input 0 (``self.inplaces``,
        presumably filled by *OpRun*) and *x* is writeable, the in-place
        path may return *x* itself instead of a copy.
        """
        if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
            return self._run_inplace(x, y)
        return (x.astype(self._target_dtype(y)), )

    def _run_inplace(self, x, y):
        """
        Returns *x* unchanged if it already has the dtype of *y*,
        otherwise a converted copy.
        """
        if x.dtype == y.dtype:
            return (x, )
        return (x.astype(self._target_dtype(y)), )