Coverage for mlprodict/onnxrt/ops_cpu/op_reduce_log_sum_exp.py: 96%
47 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunReduceNumpy, OpRun
def _reduce_log_sum_exp(data, axes, keepdims):
    """
    Computes ``log(sum(exp(data)))`` over *axes* in a numerically stable
    way: the maximum along the reduced axes is subtracted before calling
    the exponential and added back after the logarithm.

    :param data: input array
    :param axes: tuple of axes to reduce, None to reduce over all axes
    :param keepdims: keeps the reduced axes as dimensions of size 1
        if true, removes them otherwise
    :return: reduced array
    """
    kwargs = {}
    if numpy.issubdtype(data.dtype, numpy.floating):
        # ONNX specifies that a reduction over an empty set returns -inf,
        # numpy.max fails on an empty slice unless an initial value is given.
        kwargs['initial'] = -numpy.inf
    mx = numpy.max(data, axis=axes, keepdims=True, **kwargs)
    # A non finite maximum (only -inf values, any +inf or nan, an empty
    # slice) would lead to inf - inf = nan below, the shift is replaced
    # by 0 in that case (same strategy as scipy.special.logsumexp).
    mx = numpy.where(numpy.isfinite(mx), mx, numpy.zeros_like(mx))
    # asarray: subtracting 0-d arrays returns a numpy scalar which cannot
    # be used as the output buffer of numpy.exp.
    sub = numpy.asarray(numpy.subtract(data, mx))
    with numpy.errstate(over='ignore'):
        # exp can only overflow when the shift was replaced by 0 because
        # the slice holds +inf or nan, the result is +inf or nan anyway.
        exp = numpy.exp(sub, out=sub)
    mxs = numpy.sum(exp, axis=axes, keepdims=True, dtype=data.dtype)
    with numpy.errstate(divide='ignore'):
        # log(0) = -inf is the expected result when every value is -inf
        # or when the reduced slice is empty.
        res = numpy.log(mxs) + mx
    if not keepdims:
        res = numpy.squeeze(res, axis=axes)
    return numpy.asarray(res)


class ReduceLogSumExp_1(OpRunReduceNumpy):
    """
    ReduceLogSumExp for opsets < 18: *axes* is an attribute.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceLogSumExp_1.atts,
                                  **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data* along the axes stored in attribute *axes*,
        over all axes when that attribute is empty.
        """
        tax = tuple(self.axes) if self.axes else None
        return (_reduce_log_sum_exp(data, tax, self.keepdims), )


class ReduceLogSumExp_18(OpRun):
    """
    ReduceLogSumExp for opset >= 18: *axes* is an optional input
    instead of an attribute.
    """

    atts = {'keepdims': 1, 'noop_with_empty_axes': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ReduceLogSumExp_18.atts,
                       **options)

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Reduces *data* along *axes*. A missing or empty *axes* means
        every axis unless attribute *noop_with_empty_axes* is set,
        in which case *data* is returned unchanged.

        :raises TypeError: when numpy cannot reduce *data* with *axes*
        """
        # axes may be missing, an empty, 0-d or 1-d tensor (or a plain int
        # or sequence); ravel() flattens every case into a list of axes
        # and an empty list is normalized into None.
        if axes is None:
            tax = None
        else:
            tax = tuple(numpy.asarray(axes).ravel().tolist()) or None
        if tax is None and self.noop_with_empty_axes:
            return (data, )
        try:
            return (_reduce_log_sum_exp(data, tax, self.keepdims), )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                f"Unable to reduce shape {data.shape!r} with axes={tax!r}.") from e
# Public name of the operator: the implementation taking axes as an input
# when the installed onnx package supports opset 18, the attribute-based
# one otherwise.
ReduceLogSumExp = (
    ReduceLogSumExp_18 if onnx_opset_version() >= 18
    else ReduceLogSumExp_1)