Coverage for mlprodict/onnxrt/ops_cpu/op_broadcast_gradient_args.py: 100%

48 statements  

Navigation: « prev | ^ index | » next — report generated by coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ._new_ops import OperatorSchema 

10 

11 

class BroadcastGradientArgs(OpRun):
    """
    Runtime for the custom operator *BroadcastGradientArgs*.

    Given the shapes of two inputs A and B, it returns, for each input,
    the axes of the right-aligned broadcast result along which that
    input is stretched (a dimension equal to 1, or a leading dimension
    the input does not have). Presumably these are the axes a gradient
    has to be reduced over to recover each input's shape.
    """

    atts = {}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)

    def _find_custom_operator_schema(self, op_name):
        # Only one custom schema is known to this operator.
        if op_name != "BroadcastGradientArgs":
            raise RuntimeError(  # pragma: no cover
                f"Unable to find a schema for operator '{op_name}'.")
        return BroadcastGradientArgsSchema()

    def _run(self, a_shape, b_shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the broadcast axes of both inputs.

        :param a_shape: shape of input A, a 1-D sequence of dimensions
            (presumably an int64 array)
        :param b_shape: shape of input B, same kind as *a_shape*
        :param attributes: unused by this operator
        :param verbose: unused by this operator
        :param fLOG: unused by this operator
        :return: tuple ``(a_axes, b_axes)`` of int64 arrays, each
            listing in decreasing order the output axes along which the
            corresponding input is broadcast
        :raises RuntimeError: when two aligned dimensions differ and
            neither of them equals 1
        """
        rank_a = len(a_shape)
        rank_b = len(b_shape)
        rank_out = max(rank_a, rank_b)
        common = min(rank_a, rank_b)

        a_axes = []
        b_axes = []

        # Walk the right-aligned dimensions shared by both shapes,
        # starting from the last axis.
        for offset in range(1, common + 1):
            dim_a = a_shape[rank_a - offset]
            dim_b = b_shape[rank_b - offset]
            axis = rank_out - offset
            if dim_a == dim_b:
                continue
            if dim_a == 1:
                a_axes.append(axis)
            elif dim_b == 1:
                b_axes.append(axis)
            else:
                raise RuntimeError(
                    "Broadcast is not possible between inputs of "
                    "shapes: %r and %r." % (a_shape[:rank_a], b_shape[:rank_b]))

        # Leading axes exist only in the longer shape, so the shorter
        # input is broadcast along all of them. On equal ranks this
        # range is empty, so the choice of A is irrelevant.
        leading = range(rank_out - common - 1, -1, -1)
        if rank_a <= rank_b:
            a_axes.extend(leading)
        else:
            b_axes.extend(leading)

        return tuple(numpy.array(axes, dtype=numpy.int64)
                     for axes in (a_axes, b_axes))

72 

73 

class BroadcastGradientArgsSchema(OperatorSchema):
    """
    Schema for the custom operator @see cl BroadcastGradientArgs,
    which is defined in this package rather than in the ONNX standard.
    Its attributes are the ones declared by ``BroadcastGradientArgs.atts``.
    """

    def __init__(self):
        super().__init__('BroadcastGradientArgs')
        # Shares the operator's attribute mapping (same dict object).
        self.attributes = BroadcastGradientArgs.atts