Coverage for mlprodict/onnxrt/ops_cpu/op_reduce_sum_square.py: 92%
25 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunReduceNumpy, OpRun
class ReduceSumSquare_1(OpRunReduceNumpy):
    """
    Runtime for ONNX operator *ReduceSumSquare* for opsets < 18,
    where *axes* is a node attribute. It computes the sum of the
    squared elements of the input along the configured axes.
    """

    # Default attribute values: no explicit axes, keep reduced dimensions.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceSumSquare_1.atts,
                                  **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data*.

        :param data: input tensor
        :return: 1-tuple holding the reduced tensor, with the same
            dtype as *data*
        """
        # self.axes and self.keepdims are assumed to be populated by
        # OpRunReduceNumpy.__init__ (defined in _op, not shown) -- TODO confirm.
        # dtype=data.dtype keeps the output type equal to the input type,
        # consistent with ReduceSumSquare_18. Without it, numpy promotes
        # integer inputs narrower than the platform integer
        # (e.g. int32 -> int64 on most 64-bit platforms).
        return (numpy.sum(numpy.square(data), axis=self.axes,
                          keepdims=self.keepdims, dtype=data.dtype), )
class ReduceSumSquare_18(OpRun):
    """
    Runtime for ONNX operator *ReduceSumSquare* for opset >= 18,
    where *axes* is an optional input instead of an attribute.
    It computes the sum of the squared elements of the input
    along *axes*.
    """

    # Default attribute values: keep reduced dimensions, and reduce
    # every axis when *axes* is empty.
    atts = {'keepdims': 1, 'noop_with_empty_axes': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ReduceSumSquare_18.atts,
                       **options)

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data*.

        :param data: input tensor
        :param axes: optional axes to reduce: None or an array-like of
            integers. None, a 0-d array, or an array whose first
            dimension is 0 counts as *empty* for *noop_with_empty_axes*.
        :return: 1-tuple holding the reduced tensor, with the same
            dtype as *data* (or *data* itself when the reduction is
            a no-op)
        :raises TypeError: if numpy cannot reduce *data* along *axes*
        """
        if axes is not None:
            # Also accept lists, tuples or plain integers.
            axes = numpy.asarray(axes)
        empty = axes is None or axes.ndim == 0 or axes.shape[0] == 0
        if empty and self.noop_with_empty_axes:
            # NOTE(review): more recent onnx reference implementations may
            # return numpy.square(data) here rather than data -- confirm
            # against the onnx version this runtime targets.
            return (data, )
        # Normalise axes into what numpy.sum expects: None (reduce every
        # axis) or a tuple of Python ints. Never use the truth value of
        # the array itself. It is ambiguous for a 0-d array (axes=0 would
        # mean "reduce everything"), and it is deprecated (an error in
        # recent NumPy) for an empty array.
        if axes is None or axes.size == 0:
            axes = None
        else:
            axes = tuple(int(a) for a in axes.ravel())
        try:
            return (numpy.sum(numpy.square(data), axis=axes,
                              keepdims=self.keepdims,
                              dtype=data.dtype), )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                f"Unable to reduce shape {data.shape!r} with axes={axes!r}.") from e
# Public alias: the opset-18 implementation (axes given as an input) when
# the installed onnx package supports opset 18, otherwise the older one
# (axes given as an attribute).
ReduceSumSquare = (ReduceSumSquare_18 if onnx_opset_version() >= 18
                   else ReduceSumSquare_1)