Coverage for mlprodict/testing/experimental.py: 100%
179 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Experimental implementation.
4"""
5from collections import OrderedDict
6import numpy
def custom_pad(arr, paddings, constant=0, verbose=False):
    """
    Implements function
    `pad <https://numpy.org/doc/stable/reference/
    generated/numpy.pad.html>`_ in python,
    only the constant version.

    The output is built in a flat buffer: every padded area receives
    *constant*, then every output row (a run along the last axis) whose
    outer indices fall inside the original array receives the matching
    input row.

    :param arr: array
    :param paddings: paddings, array-like of shape ``(len(arr.shape), 2)``,
        ``paddings[i] = (before, after)`` for axis *i*
    :param constant: constant used to fill the padded areas,
        it is cast to ``arr.dtype``
    :param verbose: if True, the output buffer is initialized with -1
        instead of being left uninitialized (debugging aid)
    :return: padded array
    :raises ValueError: if *paddings* does not have one row per axis
    :raises NotImplementedError: if any padding is negative
    """
    paddings = numpy.asarray(paddings)
    if paddings.shape[0] != len(arr.shape):
        raise ValueError(  # pragma: no cover
            f"Input shape {arr.shape} and paddings {paddings} are inconsistent.")
    if (paddings < 0).any():
        raise NotImplementedError("Negative paddings is not implemented yet.")
    if not arr.flags['C_CONTIGUOUS']:
        arr = numpy.ascontiguousarray(arr)  # pragma: no cover

    new_shape = tuple(
        a + s for a, s in zip(arr.shape, numpy.sum(paddings, axis=1, keepdims=0)))

    # cumulative_copy[i] == prod(new_shape[i:]): cumulative_copy[0] is the
    # output size, cumulative_copy[i + 1] the size of one step along axis i.
    cumulative_copy = [1]
    for a in reversed(new_shape):
        cumulative_copy.insert(0, a * cumulative_copy[0])
    total = cumulative_copy[0]

    input_arr = arr.ravel()
    if verbose:
        res = numpy.zeros(total, dtype=arr.dtype) - 1
    else:
        res = numpy.empty(total, dtype=arr.dtype)
    if total == 0:
        # Empty output: nothing to fill (and dh_copy may be 0 below).
        return res.reshape(new_shape)

    dh_input = arr.shape[-1]
    dh_copy = new_shape[-1]
    last = len(new_shape) - 1

    # data_row[r] is True when output row r (dh_copy contiguous elements)
    # receives an input row. It is kept apart from the data so that no
    # value of *constant* (after casting) can be mistaken for a marker.
    data_row = numpy.ones(total // dh_copy, dtype=bool)

    # padding
    for i, sh in enumerate(new_shape):
        upper_number = total // cumulative_copy[i]
        contiguous = cumulative_copy[i + 1]
        big_index = 0
        p_left = paddings[i, 0] * contiguous
        p_right = paddings[i, 1] * contiguous
        dp = sh * contiguous - p_right
        for _ in range(upper_number):
            if p_left > 0:
                res[big_index:big_index + p_left] = constant
                if i < last:
                    # For outer axes, padded ranges are whole rows.
                    data_row[big_index // dh_copy:
                             (big_index + p_left) // dh_copy] = False
            if p_right > 0:
                index = big_index + dp
                res[index:index + p_right] = constant
                if i < last:
                    data_row[index // dh_copy:
                             (index + p_right) // dh_copy] = False
            big_index += cumulative_copy[i]

    # copy: inside a row, data starts after the left padding of the last axis
    offset = paddings[-1, 0]
    index_input = 0
    for row in numpy.flatnonzero(data_row):
        index_copy = row * dh_copy + offset
        res[index_copy:index_copy + dh_input] = \
            input_arr[index_input:index_input + dh_input]
        index_input += dh_input

    # final
    return res.reshape(new_shape)
def custom_einsum(equation, x, y, verbose=False):
    """
    Experimental implementation of operator Einsum
    when it does a matrix multiplication.
    Case: ``bsnh,btnh->bnts`` with shapes
    `(1,512,12,64)` and `(1,512,12,64)`.

    :param equation: equation in explicit form
        ``'<letters of x>,<letters of y>-><letters of output>'``
    :param x: first matrix
    :param y: second matrix
    :param verbose: display internal information
    :return: result of *einsum* (a 0-d array when the output has no letter)
    :raises RuntimeError: if *x* and *y* do not share the same dtype
    :raises ValueError: if the equation does not match the input shapes
    :raises NotImplementedError: if a letter is repeated within one
        operand (diagonal) or if there is more than one summation index

    This implementation does not do any transpose,
    it does a direct computation of the final result.
    It supports at most one summation index (a letter shared by both
    inputs and absent from the output); without any, the result is an
    element-wise or outer product. It does not implement diagonal
    summation (square product).
    """
    def _check_eq(eq, sh):
        if len(eq) != len(sh):
            raise ValueError(
                f"Unable to map equation {eq!r} to shape {sh!r}.")
        if len(set(eq)) != len(eq):
            # _split keys a dictionary by letter: a repeated letter would
            # silently overwrite the previous axis and give a wrong result.
            raise NotImplementedError(
                f"Diagonal (repeated letter) in {eq!r} is not implemented.")

    def _split(eq, sh):
        # letter -> (dimension, axis position)
        dx = OrderedDict((e, (v, i)) for i, (e, v) in enumerate(zip(eq, sh)))
        return dx

    def _interpret(dx, dy, eqr):
        # c_trp: letters in both inputs and in the output,
        # c_uni: letters in one input and in the output,
        # c_sum: letters in both inputs but not in the output.
        c_uni = []
        c_trp = []
        c_sum = []
        for r in eqr:
            if r in dx:
                if r in dy:
                    if dx[r][0] != dy[r][0]:
                        raise ValueError(
                            f"Dimension mismatch for letter {r!r} dx={dx!r} dy={dy!r}.")
                    c_trp.append(r)
                else:
                    c_uni.append((r, None))
            elif r in dy:
                c_uni.append((None, r))
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected letter {r!r} in result {eqr!r}.")
        for c in dx:
            if c not in eqr:
                if c not in dy:
                    raise ValueError(  # pragma: no cover
                        f"Unable to guess what to do with column {c!r} (left side)")
                if dx[c][0] != dy[c][0]:
                    raise ValueError(  # pragma: no cover
                        f"Dimension mismatch for letter {c!r} dx={dx!r} dy={dy!r}.")
                c_sum.append(c)
        for c in dy:
            if c not in eqr and c not in dx:
                raise ValueError(  # pragma: no cover
                    f"Unable to guess what to do with column {c!r} (right side)")
        shape = OrderedDict()
        for i, r in enumerate(eqr):
            if r in c_trp:
                shape[r] = (dx[r][0], i)
            else:
                for a, b in c_uni:
                    if a == r:
                        shape[r] = (dx[r][0], i)
                        break
                    if b == r:
                        shape[r] = (dy[r][0], i)
                        break
        if len(shape) != len(eqr):
            raise RuntimeError(  # pragma: no cover
                "Unable to compute the output shape "
                "dx=%r dy=%r eqr=%r got shape=%r." % (dx, dy, eqr, shape))
        return shape, c_trp, c_uni, c_sum

    def _inc(d):
        # letter -> (C-order stride in the flattened input, axis position)
        t = 1
        drev = list(reversed(d.items()))
        res = []
        for c, (sh, p) in drev:
            res.append((c, (t, p)))
            t *= sh
        return OrderedDict(reversed(res))

    def prod(seq):
        p = 1
        for s in seq:
            p *= s
        return p

    def get_incs(cd, shape):
        # stride of every output letter in one input, 0 if the input lacks it
        incs = []
        for c in shape:
            inc = cd[c][0] if c in cd else 0
            incs.append(inc)
        return incs

    if x.dtype != y.dtype:
        raise RuntimeError("x and y must have the same dtype.")
    eqx = equation.split(',')[0]
    eqy = equation.split(',')[-1].split('->')[0]
    eqr = equation.split('->')[-1]
    _check_eq(eqx, x.shape)
    _check_eq(eqy, y.shape)
    dx = _split(eqx, x.shape)
    dy = _split(eqy, y.shape)
    shape, __, _, c_sum = _interpret(dx, dy, eqr)
    if len(c_sum) > 1:
        raise NotImplementedError(
            f"More than one summation indices {c_sum!r} in equation {equation!r}.")
    cdx = _inc(dx)
    cdy = _inc(dy)
    xrav = x.ravel()
    yrav = y.ravel()
    full_size = prod(v[0] for v in shape.values())
    zrav = numpy.empty((full_size, ), dtype=x.dtype)

    # loop
    zeros = numpy.zeros((1, ), dtype=x.dtype)
    shape_dims = [v[0] for v in shape.values()]
    index = [0] * len(shape)
    len_index = len(index)
    if c_sum:
        loop_size = dx[c_sum[0]][0]
        inc_left = cdx[c_sum[0]][0]
        inc_right = cdy[c_sum[0]][0]
    else:
        # No summation: one product per output element
        # (element-wise or outer product).
        loop_size = 1
        inc_left = 0
        inc_right = 0

    # Flat offsets in xrav / yrav of the first summation term for the
    # current output index; the index starts at all zeros.
    i_left_loop = 0
    i_right_loop = 0
    left_incs = get_incs(cdx, shape)
    right_incs = get_incs(cdy, shape)

    if verbose:
        def MakeString(*args):
            return "".join(map(str, args))

        print(MakeString("equation=", equation))
        print(MakeString("c_sum=", c_sum))
        print(MakeString("full_size=", full_size))
        print(MakeString("loop_size=", loop_size))
        print(MakeString("i_left_loop=", i_left_loop))
        print(MakeString("i_right_loop=", i_right_loop))
        print(MakeString("inc_left=", inc_left))
        print(MakeString("inc_right=", inc_right))
        print(MakeString("left_incs=", left_incs))
        print(MakeString("right_incs=", right_incs))
        print(MakeString("shape=", shape))
        print(MakeString("cdx=", cdx))
        print(MakeString("cdy=", cdy))

    for i in range(0, full_size):

        i_left = i_left_loop
        i_right = i_right_loop

        # summation
        add = zeros[0]
        for _ in range(loop_size):
            add += xrav[i_left] * yrav[i_right]
            i_left += inc_left
            i_right += inc_right
        zrav[i] = add

        if verbose:
            print(MakeString(
                " -- index=", index, " ii=", i,
                " i_left_loop=", i_left_loop, " i_right_loop=", i_right_loop,
                " add=", add))

        if len_index == 0:
            # Scalar output: a single element and no index to advance.
            break

        # increment: odometer-style advance of the output multi-index,
        # keeping the flat input offsets in sync
        pos = len_index - 1
        index[pos] += 1
        i_left_loop += left_incs[pos]
        i_right_loop += right_incs[pos]
        while pos > 0 and index[pos] >= shape_dims[pos]:
            i_left_loop -= left_incs[pos] * index[pos]
            i_right_loop -= right_incs[pos] * index[pos]
            index[pos] = 0
            pos -= 1
            index[pos] += 1
            i_left_loop += left_incs[pos]
            i_right_loop += right_incs[pos]

    new_shape = tuple(v[0] for v in shape.values())
    return zrav.reshape(new_shape)