Coverage for mlprodict/testing/experimental.py: 100%

179 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Experimental implementation. 

4""" 

5from collections import OrderedDict 

6import numpy 

7 

8 

def custom_pad(arr, paddings, constant=0, verbose=False):
    """
    Implements function
    `pad <https://numpy.org/doc/stable/reference/
    generated/numpy.pad.html>`_ in python,
    only the constant version.

    :param arr: array
    :param paddings: 2D array of non-negative integers with one row per
        dimension of *arr*, column 0 is the number of values added
        before the data, column 1 the number of values added after
    :param constant: value written in the padded area
    :param verbose: if True, the output buffer is initialized with -1
        instead of being left uninitialized (easier to debug)
    :return: padded array
    :raises ValueError: if *paddings* does not have one row per dimension
    :raises NotImplementedError: if one padding is negative
    """
    if paddings.shape[0] != len(arr.shape):
        raise ValueError(  # pragma: no cover
            f"Input shape {arr.shape} and paddings {paddings} are inconsistent.")
    if min(paddings.ravel()) < 0:
        raise NotImplementedError("Negative paddings is not implemented yet.")
    if not arr.flags['C_CONTIGUOUS']:
        arr = numpy.ascontiguousarray(arr)  # pragma: no cover

    new_shape = tuple(
        a + s for a, s in zip(
            arr.shape, numpy.sum(paddings, axis=1, keepdims=False)))

    # cumulative_copy[i] is the number of elements of the output covered
    # by one index step along axis i - 1: cumulative_copy[0] is the total
    # size of the output, cumulative_copy[-1] is 1.
    cumulative_copy = [1]
    for a in reversed(new_shape):
        cumulative_copy.insert(0, a * cumulative_copy[0])

    if cumulative_copy[0] == 0:
        # Nothing to pad and nothing to copy. Returning here also avoids
        # a strided slice with a null step when the last dimension is 0.
        return numpy.empty(new_shape, dtype=arr.dtype)

    input_arr = arr.ravel()
    if verbose:
        res = numpy.zeros(cumulative_copy[0], dtype=arr.dtype) - 1
    else:
        res = numpy.empty(cumulative_copy[0], dtype=arr.dtype)

    # preparation
    # first_index is the flat position of the first input value
    # in the output buffer.
    first_index = sum(
        p * c for p, c in zip(paddings[:, 0], cumulative_copy[1:]))
    dh_input = arr.shape[-1]
    dh_copy = new_shape[-1]

    # constant
    # Every potential row start receives a marker which must differ from
    # the padding value *as stored in the buffer*. The comparison is done
    # after the cast to the buffer dtype, otherwise a constant such as 0.5
    # written in an integer array would be stored as 0 and be mistaken
    # for the marker.
    probe = numpy.empty(1, dtype=res.dtype)
    probe[0] = constant
    no_constant = 1 if probe[0] == 0 else 0
    res[first_index:cumulative_copy[0]:dh_copy] = no_constant

    # padding
    # For every axis, the left and right borders of every block are
    # overwritten with the constant. Markers located in a padded area
    # are erased by this step.
    for i, sh in enumerate(new_shape):
        upper_number = cumulative_copy[0] // cumulative_copy[i]
        contiguous = cumulative_copy[i + 1]
        big_index = 0
        p_left = paddings[i, 0] * contiguous
        p_right = paddings[i, 1] * contiguous
        dp = sh * contiguous - p_right
        for _ in range(upper_number):
            if p_left > 0:
                res[big_index:big_index + p_left] = constant
            if p_right > 0:
                index = big_index + dp
                res[index:index + p_right] = constant
            big_index += cumulative_copy[i]

    # copy
    # A marker which survived the padding step is the beginning
    # of a row of input values.
    index_input = 0
    index_copy = first_index
    while index_copy < cumulative_copy[0]:
        if res[index_copy] == no_constant:
            res[index_copy:index_copy + dh_input] = \
                input_arr[index_input:index_input + dh_input]
            index_input += dh_input
        index_copy += dh_copy

    # final
    return res.reshape(new_shape)

83 

84 

def custom_einsum(equation, x, y, verbose=False):
    """
    Experimental implementation of operator Einsum
    when it does a matrix multiplication.
    Case: ``bsnh,btnh->bnts`` with shapes
    `(1,512,12,64)` and `(1,512,12,64)`.

    :param equation: equation, two operands and an explicit output
        (``left,right->output``)
    :param x: first matrix
    :param y: second matrix
    :param verbose: display internal information
    :return: result of *einsum*
    :raises RuntimeError: if *x* and *y* do not share the same dtype
    :raises ValueError: if the equation and the shapes are inconsistent
    :raises NotImplementedError: if a letter is repeated in one operand
        or if the number of summation indices is not one

    This implementation does not do any transpose,
    it does a direct computation of the final result.
    It does not implement diagonal summation (square product).
    """
    def _check_eq(eq, sh):
        # One letter per dimension and every letter is used only once.
        if len(eq) != len(sh):
            raise ValueError(
                f"Unable to map equation {eq!r} to shape {sh!r}.")
        if len(set(eq)) != len(eq):
            # A repeated letter would be merged into a single entry by
            # _split and silently produce wrong strides.
            raise NotImplementedError(
                f"Repeated indices in {eq!r} (diagonal) are not implemented.")

    def _split(eq, sh):
        # letter -> (dimension, position of the letter in the operand)
        return OrderedDict(
            (e, (v, i)) for i, (e, v) in enumerate(zip(eq, sh)))

    def _interpret(dx, dy, eqr):
        # Classifies every letter: shared by both operands and kept
        # (c_trp), owned by one operand only (c_uni), or shared and
        # absent from the output, then summed (c_sum).
        c_uni = []
        c_trp = []
        c_sum = []
        for r in eqr:
            if r in dx:
                if r in dy:
                    if dx[r][0] != dy[r][0]:
                        raise ValueError(
                            f"Dimension mismatch for letter {r!r} dx={dx!r} dy={dy!r}.")
                    c_trp.append(r)
                else:
                    c_uni.append((r, None))
            elif r in dy:
                c_uni.append((None, r))
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected letter {r!r} in result {eqr!r}.")
        for c in dx:
            if c not in eqr:
                if c not in dy:
                    raise ValueError(  # pragma: no cover
                        f"Unable to guess what to do with column {c!r} (left side)")
                if dx[c][0] != dy[c][0]:
                    raise ValueError(  # pragma: no cover
                        f"Dimension mismatch for letter {c!r} dx={dx!r} dy={dy!r}.")
                c_sum.append(c)
        for c in dy:
            if c not in eqr and c not in dx:
                raise ValueError(  # pragma: no cover
                    f"Unable to guess what to do with column {c!r} (right side)")
        # letter -> (dimension, position of the letter in the output)
        shape = OrderedDict()
        for i, r in enumerate(eqr):
            if r in c_trp:
                shape[r] = (dx[r][0], i)
            else:
                for a, b in c_uni:
                    if a == r:
                        shape[r] = (dx[r][0], i)
                        break
                    if b == r:
                        shape[r] = (dy[r][0], i)
                        break
        if len(shape) != len(eqr):
            raise RuntimeError(  # pragma: no cover
                "Unable to compute the output shape "
                "dx=%r dy=%r eqr=%r got shape=%r." % (dx, dy, eqr, shape))
        return shape, c_trp, c_uni, c_sum

    def _inc(d):
        # letter -> (stride in the flattened operand, position)
        t = 1
        drev = list(reversed(d.items()))
        res = []
        for c, (sh, p) in drev:
            res.append((c, (t, p)))
            t *= sh
        return OrderedDict(reversed(res))

    def prod(seq):
        p = 1
        for s in seq:
            p *= s
        return p

    def get_index(cd, shape, index, col_sum):
        # Flat position in one operand for a given output index and
        # the stride to follow along the summation axis.
        ind = 0
        for c, i in zip(shape, index):
            if c in cd:
                inc = cd[c][0]
                ind += inc * i
        return ind, cd[col_sum][0]

    def get_incs(cd, shape):
        # Stride in one operand for every output axis,
        # 0 if the operand does not have this axis.
        return [cd[c][0] if c in cd else 0 for c in shape]

    if x.dtype != y.dtype:
        raise RuntimeError("x and y must have the same dtype.")
    eqx = equation.split(',')[0]
    eqy = equation.split(',')[-1].split('->')[0]
    eqr = equation.split('->')[-1]
    _check_eq(eqx, x.shape)
    _check_eq(eqy, y.shape)
    dx = _split(eqx, x.shape)
    dy = _split(eqy, y.shape)
    shape, __, _, c_sum = _interpret(dx, dy, eqr)
    cdx = _inc(dx)
    cdy = _inc(dy)
    xrav = x.ravel()
    yrav = y.ravel()
    full_size = prod(v[0] for v in shape.values())
    zrav = numpy.empty((full_size, ), dtype=x.dtype)

    # loop
    if len(c_sum) != 1:
        raise NotImplementedError(
            f"More than one summation indices {c_sum!r} in equation {equation!r}.")
    zeros = numpy.zeros((1, ), dtype=x.dtype)
    shape_dims = [v[0] for v in shape.values()]
    index = [0 for s in shape]
    len_index = len(index)
    loop_size = dx[c_sum[0]][0]

    i_left_loop, inc_left = get_index(cdx, shape, index, c_sum[0])
    i_right_loop, inc_right = get_index(cdy, shape, index, c_sum[0])
    left_incs = get_incs(cdx, shape)
    right_incs = get_incs(cdy, shape)

    if verbose:
        def MakeString(*args):
            return "".join(map(str, args))

        print(MakeString("equation=", equation))
        print(MakeString("c_sum=", c_sum))
        print(MakeString("full_size=", full_size))
        print(MakeString("loop_size=", loop_size))
        print(MakeString("i_left_loop=", i_left_loop))
        print(MakeString("i_right_loop=", i_right_loop))
        print(MakeString("inc_left=", inc_left))
        print(MakeString("inc_right=", inc_right))
        print(MakeString("left_incs=", left_incs))
        print(MakeString("right_incs=", right_incs))
        print(MakeString("shape=", shape))
        print(MakeString("cdx=", cdx))
        print(MakeString("cdy=", cdy))

    for i in range(0, full_size):

        i_left = i_left_loop
        i_right = i_right_loop

        # summation
        add = zeros[0]
        for _ in range(loop_size):
            add += xrav[i_left] * yrav[i_right]
            i_left += inc_left
            i_right += inc_right
        zrav[i] = add

        if verbose:
            print(MakeString(
                " -- index=", index, " ii=", i,
                " i_left_loop=", i_left_loop, " i_right_loop=", i_right_loop,
                " add=", add))

        if len_index == 0:
            # Scalar output: there is no output index to move forward.
            continue

        # increment
        # The output index is a counter in row-major order, both operand
        # offsets are updated incrementally instead of being recomputed.
        pos = len_index - 1
        index[pos] += 1
        i_left_loop += left_incs[pos]
        i_right_loop += right_incs[pos]
        while pos > 0 and index[pos] >= shape_dims[pos]:
            i_left_loop -= left_incs[pos] * index[pos]
            i_right_loop -= right_incs[pos] * index[pos]
            index[pos] = 0
            pos -= 1
            index[pos] += 1
            i_left_loop += left_incs[pos]
            i_right_loop += right_incs[pos]

    new_shape = tuple(v[0] for v in shape.values())
    return zrav.reshape(new_shape)