Coverage for mlprodict/onnxrt/ops_cpu/op_fused_matmul.py: 88%

59 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ._op import OpRun
from ._new_ops import OperatorSchema



class FusedMatMul(OpRun):

    atts = {'alpha': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FusedMatMul.atts,
                       **options)
        if self.transA:
            _meth = (FusedMatMul._fmatmul11 if self.transB
                     else FusedMatMul._fmatmul10)
        else:
            _meth = (FusedMatMul._fmatmul01 if self.transB
                     else FusedMatMul._fmatmul00)
        self._meth_ = _meth
        self._meth = lambda a, b: _meth(a, b, self.alpha)
        # more recent versions of the operator add transBatchA and
        # transBatchB; default them when the node does not define them
        if not hasattr(self, "transBatchA"):
            self.transBatchA = 0
        if not hasattr(self, "transBatchB"):
            self.transBatchB = 0


    def _find_custom_operator_schema(self, op_name):
        if op_name == "FusedMatMul":
            return FusedMatMulSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")


    @staticmethod
    def _fmatmul00(a, b, alpha):
        return numpy.matmul(a, b) * alpha

    @staticmethod
    def _fmatmul01(a, b, alpha):
        return numpy.matmul(a, b.T) * alpha

    @staticmethod
    def _fmatmul10(a, b, alpha):
        return numpy.matmul(a.T, b) * alpha

    @staticmethod
    def _fmatmul11(a, b, alpha):
        return numpy.matmul(a.T, b.T) * alpha


    @staticmethod
    def _transpose(x, trans, transBatch):
        if trans:
            # swap the last two (matrix) dimensions
            n = len(x.shape)
            perm = list(range(n - 2)) + [n - 1, n - 2]
            x = numpy.transpose(x, perm)
        if transBatch:
            # rotate the leading batch dimension to just before the last axis
            n = len(x.shape)
            perm = list(range(1, n - 1)) + [0, n - 1]
            x = numpy.transpose(x, perm)
        return x

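    # For example, with x.shape == (2, 3, 4): trans=1 swaps the last two
    # axes and yields shape (2, 4, 3); transBatch=1 rotates the leading
    # batch axis before the row axis and yields shape (3, 2, 4).
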

    def _run(self, a, b, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        if self.transBatchA or self.transBatchB or len(a.shape) != 2 or len(b.shape) != 2:
            # general case: batched inputs or batch transpositions
            ta = self._transpose(a, self.transA, self.transBatchA)
            tb = self._transpose(b, self.transB, self.transBatchB)
            try:
                return (numpy.matmul(ta, tb) * self.alpha, )
            except ValueError as e:
                raise ValueError(
                    f"Unable to multiply shapes {a.shape}x{b.shape} "
                    f"({ta.shape}x{tb.shape}) "
                    f"with transA={self.transA}, "
                    f"transB={self.transB}, "
                    f"transBatchA={self.transBatchA}, "
                    f"transBatchB={self.transBatchB}, "
                    f"meth={self._meth_}.") from e
        try:
            # fast path: plain 2D matrices, method precompiled in __init__
            return (self._meth(a, b), )
        except ValueError as e:
            raise ValueError(
                f"Unable to multiply shapes {a.shape}x{b.shape} "
                f"with transA={self.transA}, "
                f"transB={self.transB}, "
                f"transBatchA={self.transBatchA}, "
                f"transBatchB={self.transBatchB}, "
                f"meth={self._meth_}.") from e
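
# Illustration only: a hypothetical one-expression equivalent of the four
# `_fmatmul*` methods above for 2D inputs (`_ref_fused_matmul` is not part
# of the original module).
def _ref_fused_matmul(a, b, alpha=1., transA=0, transB=0):
    "Sketch of FusedMatMul semantics for 2D inputs."
    ta = a.T if transA else a
    tb = b.T if transB else b
    return numpy.matmul(ta, tb) * alpha
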

class FusedMatMulSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl FusedMatMul.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'FusedMatMul')
        self.attributes = FusedMatMul.atts
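
For reference, a minimal sketch of how the operator could be exercised directly, assuming FusedMatMul can be instantiated from a plain onnx NodeProto and executed through OpRun.run like the other CPU runtime operators in this package (verify against the installed mlprodict version):

import numpy
from onnx.helper import make_node
from mlprodict.onnxrt.ops_cpu.op_fused_matmul import FusedMatMul

# Assumption: direct instantiation from a NodeProto parses the attributes.
node = make_node('FusedMatMul', ['A', 'B'], ['Y'],
                 alpha=2.0, transA=1, domain='com.microsoft')
op = FusedMatMul(node)

a = numpy.arange(6).reshape((3, 2)).astype(numpy.float32)
b = numpy.arange(12).reshape((3, 4)).astype(numpy.float32)
y, = op.run(a, b)

# transA=1 transposes A first, so the result matches 2.0 * (A.T @ B).
assert numpy.allclose(y, 2.0 * (a.T @ b))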