Coverage for mlprodict/testing/einsum/blas_lapack.py: 97%
129 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Direct calls to libraries :epkg:`BLAS` and :epkg:`LAPACK`.
4"""
5import numpy
6from scipy.linalg.blas import sgemm, dgemm # pylint: disable=E0611
7from .direct_blas_lapack import ( # pylint: disable=E0401,E0611
8 dgemm_dot, sgemm_dot)
def pygemm(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Pure python implementation of GEMM.

    Computes ``C = alpha * op(A) @ op(B) + beta * C`` following the
    :epkg:`BLAS` *gemm* convention, where the matrices are given as flat
    (1D) buffers and *lda*, *ldb*, *ldc* are the leading dimensions used
    to build the index strides.

    :param transA: is A transposed?
    :param transB: is B transposed?
    :param M: number of rows of op(A) and C
    :param N: number of columns of op(B) and C
    :param K: shared inner dimension
    :param alpha: scalar multiplying ``op(A) @ op(B)``
    :param A: first matrix as a flat array of size ``M * K``
    :param lda: leading dimension of A
    :param B: second matrix as a flat array of size ``N * K``
    :param ldb: leading dimension of B
    :param beta: scalar multiplying the initial content of C
    :param C: output as a flat array of size ``M * N``, modified in place
    :param ldc: leading dimension of C
    :raises ValueError: if an input is not a vector or has the wrong size
    :raises IndexError: if a computed index falls outside a buffer
        (inconsistent leading dimensions)
    """
    if len(A.shape) != 1:
        raise ValueError(  # pragma: no cover
            "A must be a vector.")
    if len(B.shape) != 1:
        raise ValueError(  # pragma: no cover
            "B must be a vector.")
    if len(C.shape) != 1:
        raise ValueError(
            "C must be a vector.")
    if A.shape[0] != M * K:
        raise ValueError(
            f"Dimension mismatch for A.shape={A.shape!r} M={M!r} N={N!r} K={K!r}.")
    if B.shape[0] != N * K:
        raise ValueError(
            f"Dimension mismatch for B.shape={B.shape!r} M={M!r} N={N!r} K={K!r}.")
    if C.shape[0] != N * M:
        raise ValueError(  # pragma: no cover
            f"Dimension mismatch for C.shape={C.shape!r} M={M!r} N={N!r} K={K!r}.")

    # Strides translate (i, k), (j, k), (i, j) coordinates into flat offsets.
    if transA:
        a_i_stride = lda
        a_k_stride = 1
    else:
        a_i_stride = 1
        a_k_stride = lda

    if transB:
        b_j_stride = 1
        b_k_stride = ldb
    else:
        b_j_stride = ldb
        b_k_stride = 1

    c_i_stride = 1
    c_j_stride = ldc

    n_loop = 0
    for j in range(N):
        for i in range(M):
            total = 0
            for k in range(K):
                n_loop += 1
                a_index = i * a_i_stride + k * a_k_stride
                if a_index >= A.shape[0]:
                    raise IndexError(
                        "A: i=%d a_index=%d >= %d "
                        "(a_i_stride=%d a_k_stride=%d)" % (
                            i, a_index, A.shape[0], a_i_stride, a_k_stride))
                a_val = A[a_index]

                b_index = j * b_j_stride + k * b_k_stride
                if b_index >= B.shape[0]:
                    # Fixed message: the labels used to read a_i_stride /
                    # a_k_stride although B's strides were printed.
                    raise IndexError(
                        "B: j=%d b_index=%d >= %d "
                        "(b_j_stride=%d b_k_stride=%d)" % (
                            j, b_index, B.shape[0], b_j_stride, b_k_stride))
                b_val = B[b_index]

                mult = a_val * b_val
                total += mult

            c_index = i * c_i_stride + j * c_j_stride
            if c_index >= C.shape[0]:
                raise IndexError("C: %d >= %d" % (c_index, C.shape[0]))
            C[c_index] = alpha * total + beta * C[c_index]

    # Sanity check: every (i, j, k) triplet must be visited exactly once.
    if n_loop != M * N * K:
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of loops: %d != %d = (%d * %d * %d) "
            "lda=%d ldb=%d ldc=%d" % (
                n_loop, M * N * K, M, N, K, lda, ldb, ldc))
88def gemm_dot(A, B, transA=False, transB=False):
89 """
90 Implements dot product with gemm when possible.
92 :param A: first matrix
93 :param B: second matrix
94 :param transA: is first matrix transposed?
95 :param transB: is second matrix transposed?
96 """
97 if A.dtype != B.dtype:
98 raise TypeError( # pragma: no cover
99 f"Matrices A and B must have the same dtype not {A.dtype!r}, {B.dtype!r}.")
100 if len(A.shape) != 2:
101 raise ValueError( # pragma: no cover
102 f"Matrix A does not have 2 dimensions but {len(A.shape)}.")
103 if len(B.shape) != 2:
104 raise ValueError( # pragma: no cover
105 f"Matrix B does not have 2 dimensions but {len(B.shape)}.")
107 def _make_contiguous_(A, B):
108 if not A.flags['C_CONTIGUOUS']:
109 A = numpy.ascontiguousarray(A)
110 if not B.flags['C_CONTIGUOUS']:
111 B = numpy.ascontiguousarray(B)
112 return A, B
114 all_dims = A.shape + B.shape
115 square = min(all_dims) == max(all_dims)
117 if transA:
118 if transB:
119 if A.dtype == numpy.float32:
120 if square:
121 C = numpy.zeros((A.shape[1], B.shape[0]), dtype=A.dtype)
122 A, B = _make_contiguous_(A, B)
123 sgemm_dot(B, A, True, True, C)
124 return C
125 else:
126 C = numpy.zeros((A.shape[1], B.shape[0]), dtype=A.dtype)
127 return sgemm(1, A, B, 0, C, 1, 1, 1)
128 if A.dtype == numpy.float64:
129 if square:
130 C = numpy.zeros((A.shape[1], B.shape[0]), dtype=A.dtype)
131 A, B = _make_contiguous_(A, B)
132 dgemm_dot(B, A, True, True, C)
133 return C
134 else:
135 C = numpy.zeros((A.shape[1], B.shape[0]), dtype=A.dtype)
136 return dgemm(1, A, B, 0, C, 1, 1, 1)
137 return A.T @ B.T
138 else:
139 if A.dtype == numpy.float32:
140 if square:
141 C = numpy.zeros((A.shape[1], B.shape[1]), dtype=A.dtype)
142 A, B = _make_contiguous_(A, B)
143 sgemm_dot(B, A, False, True, C)
144 return C
145 else:
146 C = numpy.zeros((A.shape[1], B.shape[1]), dtype=A.dtype)
147 return sgemm(1, A, B, 0, C, 1, 0, 1)
148 if A.dtype == numpy.float64:
149 if square:
150 C = numpy.zeros((A.shape[1], B.shape[1]), dtype=A.dtype)
151 A, B = _make_contiguous_(A, B)
152 dgemm_dot(B, A, False, True, C)
153 return C
154 else:
155 C = numpy.zeros((A.shape[1], B.shape[1]), dtype=A.dtype)
156 return dgemm(1, A, B, 0, C, 1, 0, 1)
157 return A.T @ B
158 else:
159 if transB:
160 if A.dtype == numpy.float32:
161 if square:
162 C = numpy.zeros((A.shape[0], B.shape[0]), dtype=A.dtype)
163 A, B = _make_contiguous_(A, B)
164 sgemm_dot(B, A, True, False, C)
165 return C
166 else:
167 C = numpy.zeros((A.shape[0], B.shape[0]), dtype=A.dtype)
168 return sgemm(1, A, B, 0, C, 0, 1, 1)
169 if A.dtype == numpy.float64:
170 if square:
171 C = numpy.zeros((A.shape[0], B.shape[0]), dtype=A.dtype)
172 A, B = _make_contiguous_(A, B)
173 dgemm_dot(B, A, True, False, C)
174 return C
175 else:
176 C = numpy.zeros((A.shape[0], B.shape[0]), dtype=A.dtype)
177 return dgemm(1, A, B, 0, C, 0, 1, 1)
178 return A @ B.T
179 else:
180 if A.dtype == numpy.float32:
181 if square:
182 C = numpy.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
183 A, B = _make_contiguous_(A, B)
184 sgemm_dot(B, A, False, False, C)
185 return C
186 else:
187 C = numpy.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
188 return sgemm(1, A, B, 0, C, 0, 0)
189 if A.dtype == numpy.float64:
190 if square:
191 C = numpy.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
192 A, B = _make_contiguous_(A, B)
193 dgemm_dot(B, A, False, False, C)
194 return C
195 else:
196 C = numpy.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
197 return dgemm(1, A, B, 0, C, 0, 0, 1)
198 return A @ B