Coverage for mlprodict/onnxrt/ops_cpu/op_gemm.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class Gemm(OpRun):
    """
    Runtime implementation of the Gemm operator:
    ``Y = alpha * op(A) . op(B) + beta * C``, where ``op`` transposes its
    argument when the matching attribute ``transA`` / ``transB`` is set.
    The bias term is skipped when *C* is missing or *beta* is 0.
    """

    # Attributes expected on the node, with their default values.
    atts = {'alpha': 1., 'beta': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Gemm.atts,
                       **options)
        # The transposition variant is chosen once, at construction time.
        # Only a plain function is stored (no lambda closing over ``self``):
        # this avoids a self-reference cycle, keeps the instance picklable,
        # and lets copies of the operator use their own alpha / beta.
        if self.transA:
            self._gemm_func = (Gemm._gemm11 if self.transB
                               else Gemm._gemm10)
        else:
            self._gemm_func = (Gemm._gemm01 if self.transB
                               else Gemm._gemm00)

    def _meth(self, a, b, c):
        """
        Applies the variant selected in ``__init__``.

        *alpha* and *beta* are read from the instance at call time, as the
        former closure-based implementation did.

        :param a: first input
        :param b: second input
        :param c: optional bias (None to skip it)
        :return: result of the selected ``_gemmXY`` function
        """
        return self._gemm_func(a, b, c, self.alpha, self.beta)

    @staticmethod
    def _gemm(a, b, c, alpha, beta, trans_a=False, trans_b=False):
        """
        Computes ``alpha * op(a) . op(b) + beta * c``.

        :param a: left operand, transposed first if *trans_a* is true
        :param b: right operand, transposed first if *trans_b* is true
        :param c: bias added after the product, or None
        :param alpha: scale applied to the product
        :param beta: scale applied to *c*; a value of 0 skips the bias
        :param trans_a: transpose *a* before the product
        :param trans_b: transpose *b* before the product
        :return: a new object; no input is modified
        """
        if trans_a:
            a = a.T
        if trans_b:
            b = b.T
        # ``numpy.dot(...) * alpha`` produces a new object, so the in-place
        # addition below never writes into a caller-provided array.
        res = numpy.dot(a, b) * alpha
        if c is not None and beta != 0:
            res += c * beta
        return res

    @staticmethod
    def _gemm00(a, b, c, alpha, beta):
        "``alpha * a . b + beta * c``"
        return Gemm._gemm(a, b, c, alpha, beta)

    @staticmethod
    def _gemm01(a, b, c, alpha, beta):
        "``alpha * a . b.T + beta * c``"
        return Gemm._gemm(a, b, c, alpha, beta, trans_b=True)

    @staticmethod
    def _gemm10(a, b, c, alpha, beta):
        "``alpha * a.T . b + beta * c``"
        return Gemm._gemm(a, b, c, alpha, beta, trans_a=True)

    @staticmethod
    def _gemm11(a, b, c, alpha, beta):
        "``alpha * a.T . b.T + beta * c``"
        return Gemm._gemm(a, b, c, alpha, beta, trans_a=True, trans_b=True)

    def _run(self, a, b, c=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # The result is wrapped in a one-element tuple (a single output).
        return (self._meth(a, b, c), )