Coverage for mlprodict/onnxrt/ops_cpu/op_reduce_log_sum.py: 88%
32 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunReduceNumpy, OpRun
class ReduceLogSum_1(OpRunReduceNumpy):
    """
    Runtime for operator *ReduceLogSum* (opset 1 to 17).
    Computes ``log(sum(data, axis=axes, keepdims=keepdims))`` where
    *axes* and *keepdims* are node attributes.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceLogSum_1.atts,
                                  **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data* with a sum over ``self.axes`` (all axes when the
        attribute is empty) and returns the logarithm of the result.

        :param data: input tensor
        :return: one-element tuple holding the reduced tensor
        """
        tax = tuple(self.axes) if self.axes else None
        res = numpy.sum(data, axis=tax, keepdims=self.keepdims)
        if (isinstance(res, numpy.ndarray) and len(res.shape) > 0 and
                numpy.issubdtype(res.dtype, numpy.floating)):
            # In-place logarithm saves an allocation, but it is only valid
            # when the buffer can hold floating values.
            return (numpy.log(res, out=res), )
        # Integer sums (numpy.log(int_array, out=int_array) raises a
        # casting error) and 0-d results: compute the logarithm into a new
        # buffer and cast back to the input type.
        return (numpy.log(res).astype(data.dtype), )
class ReduceLogSum_18(OpRun):
    """
    Runtime for operator *ReduceLogSum* (opset 18+).
    Computes ``log(sum(data, axis=axes, keepdims=keepdims))`` where
    *axes* is an optional input tensor instead of an attribute.
    """

    atts = {'keepdims': 1, 'noop_with_empty_axes': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ReduceLogSum_18.atts,
                       **options)

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data* with a sum over *axes* and returns its logarithm.

        :param data: input tensor
        :param axes: optional tensor (or integer) of axes to reduce;
            when missing or empty, all axes are reduced unless attribute
            *noop_with_empty_axes* is set, in which case *data* is
            returned unchanged
        :return: one-element tuple holding the reduced tensor
        :raises TypeError: if numpy cannot reduce *data* with *axes*
        """
        if axes is None or numpy.asarray(axes).size == 0:
            if self.noop_with_empty_axes:
                return (data, )
            # Empty axes means "reduce over every axis".
            axes = None
        else:
            # numpy.sum expects an int or a tuple of Python ints,
            # not an ndarray.
            axes = tuple(int(a) for a in numpy.asarray(axes).ravel())
        try:
            res = numpy.sum(data, axis=axes, keepdims=self.keepdims)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                f"Unable to reduce shape {data.shape!r} with axes={axes!r}.") from e
        if (isinstance(res, numpy.ndarray) and len(res.shape) > 0 and
                numpy.issubdtype(res.dtype, numpy.floating)):
            # In-place logarithm saves an allocation, but it is only valid
            # when the buffer can hold floating values.
            return (numpy.log(res, out=res), )
        # Integer sums (numpy.log(int_array, out=int_array) raises a
        # casting error) and 0-d results: compute the logarithm into a new
        # buffer and cast back to the input type.
        return (numpy.log(res).astype(data.dtype), )
# Expose the implementation matching the opset supported by the installed
# onnx package: opset 18 moved *axes* from an attribute to an input.
if onnx_opset_version() < 18:  # pragma: no cover
    ReduceLogSum = ReduceLogSum_1
else:
    ReduceLogSum = ReduceLogSum_18