Coverage for mlprodict/testing/einsum/einsum_impl_ext.py: 96%
271 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Functions implementing einsum computation for two
4matrices having the same dimensions.
5"""
6import numpy
def numpy_diagonal(m, axis, axes):
    """
    Extracts diagonal coefficients from an array.

    :param m: input array
    :param axis: kept axis among the diagonal ones
    :param axes: diagonal axes (axis must be one of them)
    :return: output, the diagonal axes other than *axis* are removed
    :raises RuntimeError: if *axis* is not one of *axes*

    Negative values in *axis* and *axes* are interpreted relative to
    the number of dimensions of *m* (``-1`` is the last axis).

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_diagonal

        mat = numpy.arange(8).reshape((2, 2, 2))
        print(mat)
        diag = numpy_diagonal(mat, 1, [1, 2])
        print(diag)
    """
    ndim = len(m.shape)

    def _positive(a):
        # Negative axes count from the end, as numpy does.
        return a + ndim if a < 0 else a

    kept = _positive(axis)
    diag_axes = {_positive(a) for a in axes}
    if kept not in diag_axes:
        raise RuntimeError(
            f"axis {axis!r} must be in axes {axes!r}.")

    # *shape* keeps every dimension (size 1 for dropped diagonal axes) so
    # that the coefficients can be written in place, *new_shape* is the
    # final shape once the dropped axes are removed.
    shape = []
    new_shape = []
    for i, s in enumerate(m.shape):
        if i in diag_axes and i != kept:
            shape.append(1)
        else:
            shape.append(s)
            new_shape.append(s)

    # Extracts coefficients.
    output = numpy.empty(tuple(shape), dtype=m.dtype)
    index_in = [slice(s) for s in m.shape]
    index_out = [slice(s) for s in m.shape]
    for i in range(shape[kept]):
        for a in diag_axes:
            index_in[a] = i
            index_out[a] = i if a == kept else 0
        output[tuple(index_out)] = m[tuple(index_in)]

    # Removes axis.
    return output.reshape(tuple(new_shape))
def _numpy_extended_dot_equation(m1_dim, m2_dim, axes, left, right):
    """
    Builds the einsum equation equivalent to an extended version
    of an aligned matrix multiplication
    (see @see fn numpy_extended_dot).

    :param m1_dim: number of dimensions of the first matrix
    :param m2_dim: number of dimensions of the second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: equation
    :raises RuntimeError: if both numbers of dimensions differ
    :raises ValueError: if more distinct axes are involved than there are
        dimensions, or if one axis is negative or too big

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum.einsum_impl_ext import (
            numpy_extended_dot_python, _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    if m1_dim != m2_dim:
        raise RuntimeError(
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))
    total = set(axes) | set(left) | set(right)
    if len(total) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (total, axes, left, right, m1_dim))

    # Every axis must be a valid non negative dimension index.
    for group in (axes, left, right):
        for a in group:
            if not 0 <= a < m1_dim:
                raise ValueError(
                    "One axis %d (in %r) is negative or above the maximum "
                    "dimension %d." % (a, group, m1_dim))

    # One letter per dimension: upper case marks a left or right axis,
    # lower case marks a summation (or untouched) axis.
    alphabet = [chr(97 + i) for i in range(m1_dim)]
    first = list(alphabet)
    second = list(alphabet)
    result = list(alphabet)
    for a in left:
        first[a] = first[a].upper()
        result[a] = result[a].upper()
    for a in right:
        second[a] = second[a].upper()
        result[a] = result[a].upper()
    for a in axes:
        first[a] = first[a].lower()
        second[a] = second[a].lower()
        # A summation axis only stays in the output if it is a right axis.
        result[a] = result[a].lower() if a in right else None

    lhs1 = ''.join(first)
    lhs2 = ''.join(second)
    rhs = ''.join(letter for letter in result if letter)
    return f"{lhs1},{lhs2}->{rhs}"
def _common_check_numpy_extended_dot(m1, m2, axes, left, right):
    """
    Verifications shared by all implementations of
    @see fn numpy_extended_dot.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :raises TypeError: if both matrices do not share the same dtype
    :raises RuntimeError: if both matrices do not have the same
        number of dimensions
    :raises ValueError: if more distinct axes are involved than
        there are dimensions
    """
    if m1.dtype != m2.dtype:
        raise TypeError(
            f"Both matrices should share the same dtype {m1.dtype!r} != {m2.dtype!r}.")

    m1_dim, m2_dim = len(m1.shape), len(m2.shape)
    if m1_dim != m2_dim:
        raise RuntimeError(  # pragma: no cover
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))

    total = set(axes).union(left, right)
    if len(total) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (total, axes, left, right, m1_dim))
def numpy_extended_dot(m1, m2, axes, left, right, verbose=False):
    """
    Extended version of a matrix multiplication (:epkg:`numpy:dot`)
    with two matrices *m1*, *m2* of the same dimensions.
    Loops over *left* axes for *m1* and *right* axes for *m2*,
    summation is done over *axes*.
    Other axes must be empty.
    This multiplication combines matrix multiplication (dot)
    and broadcasted multiplication term by term.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, it has as many dimensions as the inputs,
        summation axes which are not right axes have size 1

    The dot product is equivalent to:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2)

        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    Empty axes should be squeezed to get identical results.
    Dot product when the second matrix is transposed.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2.T)

        dm1 = m1.reshape((2, 1, 2))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    An example when right axes include the summation axis.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2],
                                 verbose=True)
        print(dot)

    Example in higher dimension:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(8).reshape((2, 2, 2))
        m2 = m1 + 10

        dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
        print(dot)

    The current implementation still uses :epkg:`numpy:einsum`
    but this should be replaced.
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    eq = _numpy_extended_dot_equation(
        len(m1.shape), len(m2.shape), axes, left, right)
    if verbose:
        print(f" [numpy_extended_dot] {eq}: {m1.shape!r} @ {m2.shape!r}")
    output = numpy.einsum(eq, m1, m2)

    # einsum dropped every summation axis which is not a right axis.
    # They are restored as dimensions of size 1. The insertion must follow
    # increasing positions (and handle each axis once), otherwise an
    # earlier insertion shifts the position of a later one.
    new_shape = list(output.shape)
    for a in sorted(set(axes)):
        if a not in right:
            new_shape.insert(a, 1)
    if verbose:
        print(
            f" [numpy_extended_dot] {output.shape!r} reshaped into {new_shape!r} ")
    return output.reshape(tuple(new_shape))
def numpy_extended_dot_ouput_shape(m1, m2, axes, left, right):
    """
    Computes the output shape of results produced by function
    :func:`numpy_extended_dot
    <mlprodict.testing.einsum_impl_ext.numpy_extended_dot>` or
    :func:`numpy_extended_dot_python
    <mlprodict.testing.einsum_impl_ext.numpy_extended_dot_python>`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: output shape as a numpy array of int64, one value per
        dimension, 1 for every axis which is neither left nor right
    :raises RuntimeError: if an axis is both left and right and the
        two matrices have different sizes along it, none being 1
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    new_shape = numpy.full(m1_dim, 1, dtype=numpy.int64)
    for i in left:
        new_shape[i] = m1.shape[i]
    for i in right:
        shared = i in left
        if (shared and m1.shape[i] != m2.shape[i] and
                m1.shape[i] != 1 and m2.shape[i] != 1):
            raise RuntimeError(  # pragma: no cover
                "Matrices should have the same dimension for dimension %d, "
                "shapes=%r @ %r." % (i, m1.shape, m2.shape))
        if shared and m2.shape[i] == 1:
            # m2 is broadcast along this axis: the size already taken
            # from m1 is the size of the output.
            continue
        new_shape[i] = m2.shape[i]
    return new_shape
def _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right):
    """
    Builds the letters describing each dimension of the first input,
    the second input and the output.

    :param m1_dim: number of dimensions
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: three lists of *m1_dim* letters (first input, second input,
        output); an upper case letter marks a left or right axis,
        ``"-"`` in the output marks a summation axis which is not
        a right axis
    """
    alphabet = [chr(97 + i) for i in range(m1_dim)]
    first = list(alphabet)
    second = list(alphabet)
    result = list(alphabet)

    for a in left:
        first[a] = first[a].upper()
        result[a] = result[a].upper()
    for a in right:
        second[a] = second[a].upper()
        result[a] = result[a].upper()
    for a in axes:
        first[a] = first[a].lower()
        second[a] = second[a].lower()
        result[a] = result[a].lower() if a in right else "-"
    return first, second, result
def _numpy_extended_dot_python_intermediate(m1_shape, m2_shape, l1, l2, l3):
    """
    Computes intermediate information about the letters used by
    both inputs.

    :param m1_shape: shape of the first matrix
    :param m2_shape: shape of the second matrix
    :param l1: letters of the first input (list)
    :param l2: letters of the second input (list)
    :param l3: letters of the output
    :return: tuple *(names, kind, cols, common, broadcast, pos)*,
        *names* is the sorted list of distinct letters of *l1* and *l2*,
        *kind* is a bit mask per letter (1: in *l1*, 2: in *l2*,
        4: in *l3*), *cols* maps a letter to its position
        (position in *l2* wins if the letter is in both),
        *pos* is the same information as an array,
        *common* tells if a letter is in both inputs,
        *broadcast* tells if a common letter has different sizes
    """
    names = sorted(set(l1) | set(l2))
    kind = numpy.zeros(len(names), dtype=numpy.int64)
    cols = {}

    for rank, name in enumerate(names):
        mask = 0
        if name in l1:
            mask += 1
            cols[name] = l1.index(name)
        if name in l2:
            mask += 2
            cols[name] = l2.index(name)
        if name in l3:
            mask += 4
        kind[rank] = mask

    pos = numpy.array([cols[name] for name in names], dtype=numpy.int64)
    common = [(k & 3) == 3 for k in kind]
    broadcast = [c and m1_shape[p] != m2_shape[p]
                 for c, p in zip(common, pos)]

    return names, kind, cols, common, broadcast, pos
def _numpy_extended_dot_python_update_broadcast(
        m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
        kind, common, verbose=False):
    """
    Splits the letter of every broadcasted axis into two distinct
    letters (lower and upper case) so that both inputs can be indexed
    independently along that axis. Lists *l1*, *l2*, *l3* are
    modified in place.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param l1: letters of the first input
    :param l2: letters of the second input
    :param l3: letters of the output
    :param names: distinct letters
    :param broadcast: tells which letters are broadcasted
    :param cols: maps a letter to its position
    :param kind: bit mask per letter (1: first input, 2: second input,
        4: output)
    :param common: tells which letters are shared by both inputs
        (only displayed)
    :param verbose: display intermediate information
    :return: *l1*, *l2*, *l3*
    :raises RuntimeError: if a broadcasted letter is not in both inputs
    """
    def dispb(c):
        return "".join("o" if b else "." for b in c)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] before broadcast %s,%s->%s      or %s" % (
                "".join(l1), "".join(l2), "".join(l3),
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
        print(  # pragma: no cover
            "[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
                "".join(names), kind.tolist(),
                dispb(common), dispb(broadcast)))

    for rank, splits in enumerate(broadcast):
        if splits and (kind[rank] & 3) != 3:
            raise RuntimeError(  # pragma: no cover
                "Broadcast should only happen on common axes, "
                "axes=%r left=%r right=%r shape1=%r shape2=%r."
                "" % (axes, left, right, m1.shape, m2.shape))
        if not splits:
            continue

        # We split letters.
        p = cols[names[rank]]
        dim = (m1.shape[p], m2.shape[p])
        let = [l1[p], l2[p], l3[p]]
        # The letter is switched for the second input if the first one
        # has size 1 along this axis, for the first input otherwise.
        inp = 1 if dim[0] == 1 else 0
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] name=%s dim=%r let=%r inp=%r p=%r" % (
                    names[rank], dim, let, inp, p))
            print(  # pragma: no cover
                f" B0 l1={l1!r}, l2={l2!r} l3={l3!r}")

        # Switches the case of the letter to get a new distinct name.
        current = let[inp]
        let[inp] = (current.upper() if current.lower() == current
                    else current.lower())

        in_output = (kind[rank] & 4) > 0
        if in_output:
            # The axis is part of the output, the output follows
            # the switched letter.
            l3[p] = let[inp]
        if inp == 1:
            l2[p] = let[inp]
        else:
            l1[p] = let[inp]

        if verbose:
            if in_output:
                print(  # pragma: no cover
                    f" B1 l1={l1!r}, l2={l2!r} l3={l3!r}")
            else:
                print(f" B2 l1={l1!r}, l2={l2!r} l3={l3!r}")

    return l1, l2, l3
def numpy_extended_dot_python(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot in pure python.
    This implementation is not efficient but shows how to
    implement this operation without :epkg:`numpy:einsum`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_python
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    def dispb(c):
        return "".join("o" if b else "." for b in c)

    out_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    # output result
    res = numpy.full(tuple(out_shape), 0, dtype=m1.dtype)

    # indices
    l1, l2, l3 = _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right)
    names, kind, cols, common, broadcast, pos = (
        _numpy_extended_dot_python_intermediate(
            m1.shape, m2.shape, l1, l2, l3))

    if any(broadcast):
        # Broadcasted axes get two distinct letters, everything which
        # depends on the letters is computed again.
        l1, l2, l3 = _numpy_extended_dot_python_update_broadcast(
            m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
            kind, common, verbose=verbose)

        names, kind, cols, common, broadcast, pos = (
            _numpy_extended_dot_python_intermediate(
                m1.shape, m2.shape, l1, l2, l3))

    # One counter per letter, *pl1*, *pl2*, *plo* tell which counter
    # indexes every dimension of the inputs and the output
    # (-1: the output dimension is always indexed with 0).
    indices = numpy.zeros(len(names), dtype=numpy.int64)
    pl1 = numpy.array([names.index(c) for c in l1], dtype=numpy.int64)
    pl2 = numpy.array([names.index(c) for c in l2], dtype=numpy.int64)
    limits = numpy.array(
        [(m1.shape if (k & 1) == 1 else m2.shape)[p]
         for k, p in zip(kind, pos)], dtype=numpy.int64)
    plo = numpy.array(
        [names.index(c) if c in names else -1 for c in l3],
        dtype=numpy.int64)

    if verbose:
        print("[GENERICDOT] %s,%s->%s      or %s" % (
            "".join(l1), "".join(l2), "".join(l3),
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))
        print("[GENERICDOT] shape1=%r shape2=%r shape=%r" % (
            m1.shape, m2.shape, res.shape))
        print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
        print(f"[GENERICDOT] pl1={pl1!r} pl2={pl2!r} plo={plo!r}")
        print("[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
            "".join(names), kind.tolist(),
            dispb(common), dispb(broadcast)))
        print(f"[GENERICDOT] pos={pos.tolist()!r}")
        print(f"[GENERICDOT] cols={cols!r}")
        print(f"[GENERICDOT] limits={limits!r}")

    last = len(indices) - 1
    while indices[0] < limits[0]:

        # The function spends most of its time is these three lines.
        t1 = tuple(indices[pl1])
        t2 = tuple(indices[pl2])
        to = tuple(0 if n == -1 else indices[n] for n in plo)

        c = m1[t1] * m2[t2]

        if verbose:
            print(f" {t1!r} x {t2!r} -> {to!r} v={c!r} I={indices!r}")

        res[to] += c

        # Moves to the next combination of indices: the last counter is
        # incremented and the carry is propagated towards the first one.
        indices[last] += 1
        for i in range(last, 0, -1):
            if indices[i] < limits[i]:
                break
            indices[i] = 0
            indices[i - 1] += 1

    return res
def numpy_extended_dot_matrix(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot using dot product,
    multiplication, transpose and reduction
    but not a custom python implementation like
    @see fn numpy_extended_dot_python.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output
    :raises RuntimeError: if the configuration of axes is not supported
        (summation axes shared with left axes)

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_matrix
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_matrix(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] shape1=%r shape2=%r axes=%r "
            "left=%r right=%r -- %s" % (
                m1.shape, m2.shape, axes, left, right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))

    if len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Simple multiplication
        res = m1 * m2
        if verbose:
            print(  # pragma: no cover
                f"[GENERICDOT] Mul {m1.shape!r} @ {m2.shape!r} -> {res.shape!r}")
        return res

    if (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between axes and right: matrix multiplication
        # ReduceSum
        right_no_left = set(right) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            red1 = m1.sum(axis=tuple(sorted(right_no_left)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumL=%r, %r -> %r" % (
                    right_no_left, m1.shape, red1.shape))
        else:
            red1 = m1

        left_no_right = set(left) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            red2 = m2.sum(axis=tuple(sorted(left_no_right)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumR=%r, %r -> %r" % (
                    left_no_right, m2.shape, red2.shape))
        else:
            red2 = m2

        # Transpose: common axes first, summation axes last,
        # the other axes in between.
        common_axes = sorted(set(left) & set(right))
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(len(m1.shape))]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        trm1 = numpy.transpose(red1, axes=perm)
        trm2 = numpy.transpose(red2, axes=perm)
        if verbose:
            print(
                f"[GENERICDOT] transposeL={perm!r}, {red1.shape!r} -> {trm1.shape!r}")
            print(
                f"[GENERICDOT] transposeR={perm!r}, {red2.shape!r} -> {trm2.shape!r}")
        final_shape = numpy_extended_dot_ouput_shape(
            m1, m2, axes, left, right)
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        perm_common_axes = [i for i in range(len(perm))
                            if perm[i] in common_axes]

        if verbose:
            print("[GENERICDOT] MatMul %r @ %r -> %r  --  %s" % (
                m1.shape, m2.shape, final_shape,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
            print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
            print("[GENERICDOT] perm=%r perm_left=%r "
                  "perm_right=%r perm_common_axes=%r" % (
                      perm, perm_left, perm_right, perm_common_axes))

        # Reshape into 3D tensors (common axes, free axes, summation axes).
        dim0 = int(numpy.prod([trm1.shape[i] for i in perm_common_axes]))
        dim0b = int(numpy.prod([trm2.shape[i] for i in perm_common_axes]))
        if len(axes) > 0:
            # Summation axes were moved to the end by the transposition.
            all_axes = list(range(0, len(m1.shape)))
            new_axes = all_axes[-len(axes):]
        else:
            new_axes = []
        dim1 = int(numpy.prod([trm1.shape[i] for i in new_axes]))
        dim2 = int(numpy.prod([trm2.shape[i] for i in new_axes]))
        if dim1 != dim2:
            raise RuntimeError(  # pragma: no cover
                "Summation axis do not have the same length %d != %d, "
                "trshape1=%r trshape2=%r "
                "p_axes=%r p_left=%r p_right=%r p_common=%r"
                "." % (dim1, dim2, trm1.shape, trm2.shape,
                       new_axes, perm_left, perm_right, perm_common_axes))
        shm1 = trm1.reshape((dim0, -1, dim1))
        shm2 = trm2.reshape((dim0b, -1, dim2))

        if verbose:
            # The second operand is reshaped with dim0b (not dim0).
            print("[GENERICDOT] Reshape %r @ %r -> %r @ %r" % (
                (dim0, -1, dim1), (dim0b, -1, dim2),
                shm1.shape, shm2.shape))
            print("[GENERICDOT] matmul")

        # Multiplication (this should be done in a different way.
        res = shm1 @ numpy.transpose(shm2, axes=(0, 2, 1))

        if verbose:
            print(f"[GENERICDOT] Shape after multiplication {res.shape}")

        # Transpose again
        not_in_both = [i for i in range(0, len(m1.shape))
                       if i not in left and i not in right]
        ordered_axes = (common_axes +
                        [i for i in left if i not in right] +
                        [i for i in right if i not in left] +
                        not_in_both)

        perm_not_in_both = [i for i in range(len(perm))
                            if perm[i] in not_in_both]
        current_shape = ([max(trm1.shape[i], trm2.shape[i])
                          for i in sorted(perm_common_axes)] +
                         [trm1.shape[i] for i in sorted(perm_left)
                          if i not in perm_common_axes] +
                         [trm2.shape[i] for i in sorted(perm_right)
                          if i not in perm_common_axes] +
                         [1 for i in perm_not_in_both])

        if verbose:
            print("[GENERICDOT] current_shape=%r final_shape=%r "
                  "last_shape=%r" % (current_shape, final_shape, res.shape))

        if len(current_shape) != len(final_shape):
            raise RuntimeError(  # pragma: no cover
                "Shapes mismatch %r > %r, "
                "shape1=%r shape2=%r axes=%r left=%r right=%r." % (
                    current_shape, final_shape,
                    m1.shape, m2.shape, axes, left, right))

        res = res.reshape(current_shape)

        # Permutation which moves every axis back to its original position.
        perm = [(a, i) for i, a in enumerate(ordered_axes)]
        perm.sort()
        perm = [p[1] for p in perm]

        if verbose:
            print(f"[GENERICDOT] ordered_axes={ordered_axes!r} perm={perm!r}")

        return numpy.transpose(res, axes=perm)

    # Multiplication and Matrix multiplication at the same time.
    l_axes = set(left) & set(axes)
    r_axes = set(right) & set(axes)
    if r_axes and not l_axes:
        # Summation axes which are also right axes are moved to the left
        # axes, the function is then called again with this new
        # configuration.
        new_axes = [a for a in axes if a not in right]
        new_left = sorted(set(left) | r_axes)
        if verbose:  # pragma: no cover
            eq1 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)
            eq2 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), new_axes, new_left, right)
            print("[GENERICDOT] replace left %r by %r axes %r by %r, "
                  "eq %r by %r" % (
                      left, new_left, axes, new_axes, eq1, eq2))
        return numpy_extended_dot_matrix(m1, m2, new_axes, new_left, right,
                                         verbose=verbose)
    raise RuntimeError(  # pragma: no cover
        "shape1=%r shape2=%r axes=%r left=%r right=%r eq=%s." % (
            m1.shape, m2.shape, axes, left, right,
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))