Coverage for mlprodict/onnxrt/ops_cpu/op_constant_of_shape.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class ConstantOfShape(OpRun):
    """
    Runtime implementation of the ONNX operator *ConstantOfShape*.

    The single input holds the output shape.
    Every element of the output equals the scalar stored in attribute *value*.
    ``atts`` declares that attribute with a float32 zero default, which matches the
    ONNX specification. Presumably OpRun applies it when the node omits *value*.
    """

    atts = {'value': numpy.array([0], dtype=numpy.float32)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ConstantOfShape.atts,
                       **options)
        # Normalised once here so that every call to _run reuses the same scalar.
        self.cst = ConstantOfShape._scalar_from_value(self.value)

    @staticmethod
    def _scalar_from_value(value):
        """
        Converts attribute *value* into a numpy scalar used to fill the output.

        :param value: a numpy array (its first element is used, whatever its
            rank) or a python / numpy scalar
        :return: numpy scalar (integer, floating or bool)
        :raises ValueError: if *value* is an empty array
        :raises TypeError: if the scalar is not a real number or a boolean
        """
        if isinstance(value, numpy.ndarray):
            # reshape(-1) also accepts 0-d arrays, where value[0] would fail.
            flat = value.reshape(-1)
            if flat.size == 0:
                raise ValueError(  # pragma: no cover
                    "Attribute value must contain at least one element.")
            cst = flat[0]
        else:
            cst = value
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(cst, bool):
            cst = numpy.bool_(cst)
        elif isinstance(cst, int):
            cst = numpy.int64(cst)
        elif isinstance(cst, float):
            cst = numpy.float64(cst)
        # Every real numpy type is accepted (all signed/unsigned integers and
        # all floating widths) plus booleans. Complex values are still rejected.
        if not isinstance(cst, (numpy.integer, numpy.floating, numpy.bool_)):
            raise TypeError(  # pragma: no cover
                f"cst must be a real not {type(cst)}")
        return cst

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Builds the constant tensor.

        :param data: sequence of dimensions giving the output shape; an
            empty sequence yields a 0-d (scalar) array
        :return: tuple with one array filled with ``self.cst``
        :raises RuntimeError: if numpy cannot build the array from *data*
        """
        try:
            # The dtype is taken from the stored scalar, so the output type
            # follows attribute *value*.
            res = numpy.full(tuple(data), self.cst, dtype=self.cst.dtype)
        except TypeError as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to create a constant of shape %r with value %r "
                "(raw value=%r)." % (data, self.cst, self.value)) from e
        return (res, )

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        The generated code applies the same value normalisation as
        ``__init__``, so both paths produce arrays with the same dtype.

        :param inputs: names of the operator inputs; only the first is used
        :return: tuple (import statement, code)
        """
        lines = ['cst = (value.reshape(-1)[0] if isinstance(value, numpy.ndarray)',
                 '       else value)',
                 'if isinstance(cst, bool):',
                 '    cst = numpy.bool_(cst)',
                 'elif isinstance(cst, int):',
                 '    cst = numpy.int64(cst)',
                 'elif isinstance(cst, float):',
                 '    cst = numpy.float64(cst)',
                 f'return numpy.full(tuple({inputs[0]}), cst)']
        return ("import numpy", "\n".join(lines))