Coverage for mlprodict/onnxrt/ops_cpu/op_broadcast_gradient_args.py: 100%
48 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ._new_ops import OperatorSchema
class BroadcastGradientArgs(OpRun):
    """
    Runtime for the custom operator *BroadcastGradientArgs*.

    Given two shapes, it returns, for each input, the axes of the
    broadcast output shape along which that input was broadcast
    (presumably so a gradient can be reduced over them — confirm
    against the training graphs that use it).
    """

    atts = {}  # the operator declares no attributes

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator type being looked up
        :return: a new @see cl BroadcastGradientArgsSchema
        :raises RuntimeError: if *op_name* is not ``'BroadcastGradientArgs'``
        """
        if op_name == "BroadcastGradientArgs":
            return BroadcastGradientArgsSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, a_shape, b_shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the broadcast axes of both inputs.

        Shapes are aligned on their trailing dimensions. An axis of the
        broadcast output is reported for an input when:

        * that input has dimension 1 there and the other input does not, or
        * that input has fewer dimensions and the axis is one of the
          leading axes it does not have.

        :param a_shape: shape of the first input (indexable sequence of ints)
        :param b_shape: shape of the second input (indexable sequence of ints)
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: tuple ``(a_axes, b_axes)`` of 1-D int64 arrays. Values are
            axis indices in the broadcast shape, listed in decreasing order.
        :raises RuntimeError: if two aligned dimensions differ and
            neither of them is 1
        """
        a_rank = len(a_shape)
        b_rank = len(b_shape)
        ndim = max(a_rank, b_rank)
        a_axes = []
        b_axes = []

        # Walk both shapes from their last dimension. ``offset`` counts
        # positions from the right, so the output axis is ndim - 1 - offset;
        # zip stops once the shorter shape is exhausted.
        pairs = zip(reversed(a_shape), reversed(b_shape))
        for offset, (a_dim, b_dim) in enumerate(pairs):
            if a_dim != b_dim:
                axis = ndim - 1 - offset
                if a_dim == 1:
                    a_axes.append(axis)
                elif b_dim == 1:
                    b_axes.append(axis)
                else:
                    raise RuntimeError(
                        "Broadcast is not possible between inputs of "
                        f"shapes: {a_shape!r} and {b_shape!r}.")

        # Leading axes exist only in the longer shape, so the shorter input
        # is broadcast along all of them. With equal ranks the range is empty.
        n_leading = ndim - min(a_rank, b_rank)
        shorter_axes = a_axes if a_rank <= b_rank else b_axes
        shorter_axes.extend(range(n_leading - 1, -1, -1))

        return (numpy.array(a_axes, dtype=numpy.int64),
                numpy.array(b_axes, dtype=numpy.int64))
class BroadcastGradientArgsSchema(OperatorSchema):
    """
    Schema of the custom operator @see cl BroadcastGradientArgs,
    one of the operators added by this package.
    """

    def __init__(self):
        super().__init__('BroadcastGradientArgs')
        # Reuse the attribute table declared on the runtime class itself
        # (the same dict object, not a copy).
        self.attributes = BroadcastGradientArgs.atts