Coverage for mlprodict/onnxrt/ops_cpu/op_constant_of_shape.py: 100%
20 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class ConstantOfShape(OpRun):
    """
    Runtime for the ONNX operator *ConstantOfShape*: builds a tensor
    whose shape is given by the input and whose elements all equal
    the scalar stored in attribute *value*.
    """

    # Default attributes: a one-element float32 tensor holding 0.
    atts = {'value': numpy.array([0], dtype=numpy.float32)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ConstantOfShape.atts,
                       **options)
        # Reduce the attribute to a single scalar. ravel() makes this work
        # for 1-D, 0-d and multi-dimensional one-element tensors alike, and
        # indexing an ndarray returns a numpy scalar of the array's dtype.
        self.cst = (self.value.ravel()[0]
                    if isinstance(self.value, numpy.ndarray)
                    else self.value)
        # Plain Python scalars are converted to fixed-width numpy scalars
        # so that the dtype used by numpy.full in _run is explicit.
        if isinstance(self.cst, int):
            self.cst = numpy.int64(self.cst)
        elif isinstance(self.cst, float):
            self.cst = numpy.float64(self.cst)
        # Accept every real numeric numpy scalar (signed/unsigned integers
        # of any width, all floating types) plus booleans.
        if not isinstance(self.cst, (numpy.integer, numpy.floating,
                                     numpy.bool_)):
            raise TypeError(  # pragma: no cover
                f"cst must be a real not {type(self.cst)}")

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Returns a one-element tuple containing an array of shape
        ``tuple(data)`` filled with *self.cst*; the array dtype follows
        the dtype of *self.cst*.

        :raises RuntimeError: if numpy cannot build the array
            (wraps the original *TypeError*)
        """
        try:
            res = numpy.full(tuple(data), self.cst)
        except TypeError as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to create a constant of shape %r with value %r "
                "(raw value=%r)." % (data, self.cst, self.value)) from e
        return (res, )

    def to_python(self, inputs):
        """
        Returns a pair ``(import statement, code)`` where *code* is the
        python code equivalent to this operator; ``inputs[0]`` is the name
        of the variable holding the requested shape.
        """
        lines = ['cst = value[0] if isinstance(value, numpy.ndarray) else value',
                 f'return numpy.full(tuple({inputs[0]}), cst)']
        return ("import numpy", "\n".join(lines))