Coverage for mlprodict/testing/einsum/blas_lapack.py: 97%

129 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Direct calls to libraries :epkg:`BLAS` and :epkg:`LAPACK`. 

4""" 

5import numpy 

6from scipy.linalg.blas import sgemm, dgemm # pylint: disable=E0611 

7from .direct_blas_lapack import ( # pylint: disable=E0401,E0611 

8 dgemm_dot, sgemm_dot) 

9 

10 

def pygemm(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Pure python implementation of GEMM.

    Computes ``C = alpha * op(A) @ op(B) + beta * C`` in place. Every matrix
    is stored as a flat vector in column-major order: element ``(i, j)`` of
    *C* lives at ``C[i + j * ldc]``. ``op(X)`` is the transpose of the
    matrix stored in *X* when the matching *trans* flag is true, and the
    stored matrix itself otherwise.

    :param transA: use the transpose of the matrix stored in *A*
    :param transB: use the transpose of the matrix stored in *B*
    :param M: number of rows of ``op(A)`` and *C*
    :param N: number of columns of ``op(B)`` and *C*
    :param K: number of columns of ``op(A)`` (= rows of ``op(B)``)
    :param alpha: scale applied to the product ``op(A) @ op(B)``
    :param A: 1D array holding exactly ``M * K`` elements
    :param lda: leading dimension of *A*
    :param B: 1D array holding exactly ``N * K`` elements
    :param ldb: leading dimension of *B*
    :param beta: scale applied to the initial content of *C*
    :param C: 1D array holding exactly ``M * N`` elements, updated in place
    :param ldc: leading dimension of *C*
    :raises ValueError: if a buffer is not 1D or does not have the
        expected number of elements
    :raises IndexError: if a leading dimension sends an index outside
        its buffer
    """
    if len(A.shape) != 1:
        raise ValueError(  # pragma: no cover
            "A must be a vector.")
    if len(B.shape) != 1:
        raise ValueError(  # pragma: no cover
            "B must be a vector.")
    if len(C.shape) != 1:
        raise ValueError(
            "C must be a vector.")
    if A.shape[0] != M * K:
        raise ValueError(
            f"Dimension mismatch for A.shape={A.shape!r} M={M!r} N={N!r} K={K!r}.")
    if B.shape[0] != N * K:
        raise ValueError(
            f"Dimension mismatch for B.shape={B.shape!r} M={M!r} N={N!r} K={K!r}.")
    if C.shape[0] != N * M:
        raise ValueError(  # pragma: no cover
            f"Dimension mismatch for C.shape={C.shape!r} M={M!r} N={N!r} K={K!r}.")

    # Column-major storage: op(A)[i, k] is A[i + k * lda] when A is used
    # as is, and A[k + i * lda] when it is transposed.
    if transA:
        a_i_stride = lda
        a_k_stride = 1
    else:
        a_i_stride = 1
        a_k_stride = lda

    # Same reasoning for op(B)[k, j].
    if transB:
        b_j_stride = 1
        b_k_stride = ldb
    else:
        b_j_stride = ldb
        b_k_stride = 1

    c_i_stride = 1
    c_j_stride = ldc

    n_loop = 0
    for j in range(N):
        for i in range(M):
            total = 0
            for k in range(K):
                n_loop += 1
                a_index = i * a_i_stride + k * a_k_stride
                if a_index >= A.shape[0]:
                    raise IndexError(
                        "A: i=%d a_index=%d >= %d "
                        "(a_i_stride=%d a_k_stride=%d)" % (
                            i, a_index, A.shape[0], a_i_stride, a_k_stride))
                a_val = A[a_index]

                b_index = j * b_j_stride + k * b_k_stride
                if b_index >= B.shape[0]:
                    # Label the strides with the names of the values that
                    # are actually reported (B's strides, not A's).
                    raise IndexError(
                        "B: j=%d b_index=%d >= %d "
                        "(b_j_stride=%d b_k_stride=%d)" % (
                            j, b_index, B.shape[0], b_j_stride, b_k_stride))
                b_val = B[b_index]

                total += a_val * b_val

            c_index = i * c_i_stride + j * c_j_stride
            if c_index >= C.shape[0]:
                raise IndexError("C: %d >= %d" % (c_index, C.shape[0]))
            C[c_index] = alpha * total + beta * C[c_index]

    # Sanity check kept from the original implementation: every (i, j, k)
    # triple must have been visited exactly once.
    if n_loop != M * N * K:
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of loops: %d != %d = (%d * %d * %d) "
            "lda=%d ldb=%d ldc=%d" % (
                n_loop, M * N * K, M, N, K, lda, ldb, ldc))

86 

87 

def gemm_dot(A, B, transA=False, transB=False):
    """
    Implements dot product with gemm when possible.

    Computes ``op(A) @ op(B)``, where ``op(X)`` is ``X.T`` when the matching
    *trans* flag is true and ``X`` otherwise.

    * float32 and float64 inputs go through BLAS.
      * When every dimension of *A* and *B* is equal, the compiled
        ``sgemm_dot`` / ``dgemm_dot`` kernels are used.
      * Otherwise scipy's ``sgemm`` / ``dgemm`` are used.
    * Any other dtype falls back to numpy's ``@`` operator.

    :param A: first matrix (2 dimensions)
    :param B: second matrix (2 dimensions, same dtype as *A*)
    :param transA: is first matrix transposed?
    :param transB: is second matrix transposed?
    :return: matrix ``op(A) @ op(B)`` of shape ``(M, N)``
    :raises TypeError: if *A* and *B* do not share the same dtype
    :raises ValueError: if *A* or *B* does not have 2 dimensions, or if the
        inner dimensions of ``op(A)`` and ``op(B)`` differ
    """
    if A.dtype != B.dtype:
        raise TypeError(  # pragma: no cover
            f"Matrices A and B must have the same dtype not {A.dtype!r}, {B.dtype!r}.")
    if len(A.shape) != 2:
        raise ValueError(  # pragma: no cover
            f"Matrix A does not have 2 dimensions but {len(A.shape)}.")
    if len(B.shape) != 2:
        raise ValueError(  # pragma: no cover
            f"Matrix B does not have 2 dimensions but {len(B.shape)}.")

    # The compiled kernels always received literal True/False flags.
    transA = bool(transA)
    transB = bool(transB)

    # op(A) is (M, K) and op(B) is (KB, N); the product needs K == KB.
    M, K = (A.shape[1], A.shape[0]) if transA else A.shape
    KB, N = (B.shape[1], B.shape[0]) if transB else B.shape
    if K != KB:
        raise ValueError(
            f"Inner dimensions mismatch: op(A) has shape {(M, K)!r} "
            f"but op(B) has shape {(KB, N)!r}.")

    if A.dtype not in (numpy.float32, numpy.float64):
        # No BLAS kernel for this dtype: let numpy compute the product.
        return (A.T if transA else A) @ (B.T if transB else B)

    is_float32 = A.dtype == numpy.float32
    C = numpy.zeros((M, N), dtype=A.dtype)

    all_dims = A.shape + B.shape
    if min(all_dims) == max(all_dims):
        # Square case: the compiled kernels are given C-contiguous inputs
        # and are called with the operands swapped (B first, then A).
        A = numpy.ascontiguousarray(A)
        B = numpy.ascontiguousarray(B)
        direct_gemm = sgemm_dot if is_float32 else dgemm_dot
        direct_gemm(B, A, transB, transA, C)
        return C

    scipy_gemm = sgemm if is_float32 else dgemm
    # Positional arguments: alpha=1, a, b, beta=0, c, trans_a, trans_b,
    # overwrite_c=1 (now passed consistently for every dtype/trans combo).
    return scipy_gemm(1, A, B, 0, C, int(transA), int(transB), 1)