Coverage for mlprodict/onnxrt/ops_cpu/op_fused_matmul.py: 88%
59 statements
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ._new_ops import OperatorSchema
class FusedMatMul(OpRun):
    """
    Runtime for the custom operator *FusedMatMul*, which computes
    ``alpha * op_a(A) @ op_b(B)``.

    Each operand may have its last two axes swapped (*transA*, *transB*).
    Its first axis may also be moved just before its last axis
    (*transBatchA*, *transBatchB*), presumably mirroring the
    onnxruntime contrib operator of the same name (TODO confirm).
    """

    # Attributes and default values passed to OpRun as expected_attributes.
    atts = {'alpha': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FusedMatMul.atts,
                       **options)
        # self.transA / self.transB / self.alpha are presumably populated by
        # OpRun.__init__ from the node attributes (defaults from `atts`).
        # Pick the specialised 2-D kernel once, at construction time.
        if self.transA:
            _meth = (FusedMatMul._fmatmul11 if self.transB
                     else FusedMatMul._fmatmul10)
        else:
            _meth = (FusedMatMul._fmatmul01 if self.transB
                     else FusedMatMul._fmatmul00)
        self._meth_ = _meth
        # self.alpha is read when the lambda is called, not when it is built.
        self._meth = lambda a, b: _meth(a, b, self.alpha)
        # more recent versions of the operator
        if not hasattr(self, "transBatchA"):
            self.transBatchA = 0
        if not hasattr(self, "transBatchB"):
            self.transBatchB = 0

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of a custom operator.

        :param op_name: operator name, only ``'FusedMatMul'`` is known
        :return: a @see cl FusedMatMulSchema instance
        :raises RuntimeError: for any other operator name
        """
        if op_name == "FusedMatMul":
            return FusedMatMulSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    # The four kernels below are only used when both operands are 2-D
    # (see _run). In that case `.T` swaps the two axes.

    @staticmethod
    def _fmatmul00(a, b, alpha):
        "Returns ``(a @ b) * alpha``."
        return numpy.matmul(a, b) * alpha

    @staticmethod
    def _fmatmul01(a, b, alpha):
        "Returns ``(a @ b.T) * alpha``."
        return numpy.matmul(a, b.T) * alpha

    @staticmethod
    def _fmatmul10(a, b, alpha):
        "Returns ``(a.T @ b) * alpha``."
        return numpy.matmul(a.T, b) * alpha

    @staticmethod
    def _fmatmul11(a, b, alpha):
        "Returns ``(a.T @ b.T) * alpha``."
        return numpy.matmul(a.T, b.T) * alpha

    @staticmethod
    def _transpose(x, trans, transBatch):
        """
        Applies the layout flags of *FusedMatMul* to one operand.

        :param x: operand (numpy array)
        :param trans: if true, swaps the last two axes
        :param transBatch: if true, moves the first axis just before the
            last one, i.e. applies the permutation ``[1, ..., n-2, 0, n-1]``
        :return: the permuted array

        *transBatch* is applied first and *trans* second. When both flags
        are set, the result is therefore the permutation
        ``[1, ..., n-2, n-1, 0]``.
        """
        # With fewer than 3 axes there is no batch axis to move, so
        # transBatch is a no-op. The permutation below would be the identity
        # for n == 2 and invalid for n < 2.
        if transBatch and len(x.shape) >= 3:
            n = len(x.shape)
            x = numpy.transpose(x, list(range(1, n - 1)) + [0, n - 1])
        # A 1-D operand is left untouched, like `.T` on a 1-D array.
        if trans and len(x.shape) >= 2:
            x = numpy.swapaxes(x, -1, -2)
        return x

    def _run(self, a, b, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes ``alpha * op_a(a) @ op_b(b)``.

        :param a: first operand
        :param b: second operand
        :return: one-element tuple holding the result
        :raises ValueError: if the (transposed) shapes cannot be multiplied
        """
        if (self.transBatchA or self.transBatchB or
                len(a.shape) != 2 or len(b.shape) != 2):
            # General path: explicit permutations, then numpy.matmul, which
            # broadcasts the leading (batch) dimensions.
            ta = self._transpose(a, self.transA, self.transBatchA)
            tb = self._transpose(b, self.transB, self.transBatchB)
            try:
                return (numpy.matmul(ta, tb) * self.alpha, )
            except ValueError as e:
                raise ValueError(
                    f"Unable to multiply shape {a.shape}x{b.shape} "
                    f"({ta.shape}x{tb.shape}) "
                    f"with transA={self.transA}, "
                    f"transB={self.transB}, "
                    f"transBatchA={self.transBatchA}, "
                    f"transBatchB={self.transBatchB}, "
                    f"meth={self._meth_}.") from e
        # Fast path: plain 2-D operands, use the kernel chosen in __init__.
        try:
            return (self._meth(a, b), )
        except ValueError as e:
            raise ValueError(
                f"Unable to multiply shape {a.shape}x{b.shape} "
                f"with transA={self.transA}, "
                f"transB={self.transB}, "
                f"transBatchA={self.transBatchA}, "
                f"transBatchB={self.transBatchB}, "
                f"meth={self._meth_}.") from e
class FusedMatMulSchema(OperatorSchema):
    """
    Schema of @see cl FusedMatMul, one of the operators this
    package adds on top of the standard ones.
    """

    def __init__(self):
        super().__init__('FusedMatMul')
        # The schema exposes the same attribute defaults as the runtime class.
        self.attributes = FusedMatMul.atts