Coverage for mlprodict/testing/einsum/einsum_impl_ext.py: 96%

271 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Functions implementing einsum computation for two 

4matrices having the same dimensions. 

5""" 

6import numpy 

7 

8 

def numpy_diagonal(m, axis, axes):
    """
    Extracts the diagonal coefficients of an array along a group of axes.

    All the axes listed in *axes* are walked with the same index,
    only *axis* survives in the result, the other diagonal axes
    are removed.

    :param m: input array
    :param axis: kept axis among the diagonal ones
    :param axes: diagonal axes (axis must be one of them)
    :return: output
    :raises RuntimeError: if *axis* does not belong to *axes*

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_diagonal

        mat = numpy.arange(8).reshape((2, 2, 2))
        print(mat)
        diag = numpy_diagonal(mat, 1, [1, 2])
        print(diag)
    """
    if axis not in axes:
        raise RuntimeError(
            f"axis {axis!r} must be in axes {axes!r}.")

    # Shape of the scratch buffer: dropped diagonal axes are kept
    # with a length of 1 so that it has as many dimensions as *m*.
    full_shape = [
        1 if (d in axes and d != axis) else size
        for d, size in enumerate(m.shape)]
    # Final shape: the dropped diagonal axes disappear.
    squeezed_shape = [
        size for d, size in enumerate(m.shape)
        if d not in axes or d == axis]

    # Extracts coefficients.
    buffer = numpy.empty(tuple(full_shape), dtype=m.dtype)
    src = [slice(size) for size in m.shape]
    dst = list(src)
    for k in range(full_shape[axis]):
        for a in axes:
            src[a] = k
            dst[a] = k if a == axis else 0
        buffer[tuple(dst)] = m[tuple(src)]

    # Removes the axes of length 1 left by the diagonal.
    return buffer.reshape(tuple(squeezed_shape))

57 

58 

def _numpy_extended_dot_equation(m1_dim, m2_dim, axes, left, right):
    """
    Returns the equation equivalent to an extended version
    of an aligned matrix multiplication
    (see @see fn numpy_extended_dot).

    Every dimension is given one letter. The letter is uppercase in an
    operand when the dimension is a *left* (first operand) or *right*
    (second operand) axis and lowercase otherwise. A summation axis is
    always lowercase and only appears in the output if it is also
    a *right* axis.

    :param m1_dim: number of dimensions of the first matrix
    :param m2_dim: number of dimensions of the second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: equation
    :raises RuntimeError: if *m1_dim* and *m2_dim* differ
    :raises ValueError: if more axes are involved than there are
        dimensions or if one axis is outside ``[0, m1_dim)``

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum.einsum_impl_ext import (
            numpy_extended_dot_python, _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    if m1_dim != m2_dim:
        raise RuntimeError(
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))
    involved = set(axes) | set(left) | set(right)
    if len(involved) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (involved, axes, left, right, m1_dim))

    # Every axis must be a valid, non negative dimension index.
    for group in (axes, left, right):
        for a in group:
            if a < 0 or a >= m1_dim:
                raise ValueError(
                    "One axis %d (in %r) is negative or above the maximum "
                    "dimension %d." % (a, group, m1_dim))

    summed, on_left, on_right = set(axes), set(left), set(right)
    first, second, result = [], [], []
    for d in range(m1_dim):
        low = chr(97 + d)
        up = low.upper()
        if d in summed:
            # A summation axis shares the same label in both operands,
            # it is only kept in the output if it is a right axis too.
            first.append(low)
            second.append(low)
            if d in on_right:
                result.append(low)
        else:
            first.append(up if d in on_left else low)
            second.append(up if d in on_right else low)
            result.append(
                up if (d in on_left or d in on_right) else low)
    return f"{''.join(first)},{''.join(second)}->{''.join(result)}"

132 

133 

def _common_check_numpy_extended_dot(m1, m2, axes, left, right):
    """
    Common verifications for all implementations of
    @see fn numpy_extended_dot.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :raises TypeError: if both matrices do not share the same dtype
    :raises RuntimeError: if both matrices do not have the same
        number of dimensions
    :raises ValueError: if *axes*, *left*, *right* involve more distinct
        axes than the matrices have dimensions
    """
    if m1.dtype != m2.dtype:
        raise TypeError(
            f"Both matrices should share the same dtype {m1.dtype!r} != {m2.dtype!r}.")
    rank1, rank2 = len(m1.shape), len(m2.shape)
    if rank1 != rank2:
        raise RuntimeError(  # pragma: no cover
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (rank1, rank2))
    involved = set(axes).union(left, right)
    if len(involved) > rank1:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (involved, axes, left, right, rank1))

154 

155 

def numpy_extended_dot(m1, m2, axes, left, right, verbose=False):
    """
    Extended version of a matrix multiplication (:epkg:`numpy:dot`)
    with two matrices *m1*, *m2* of the same dimensions.
    Loops over *left* axes for *m1* and *right* axes for *m2*,
    summation is done over *axes*.
    Other axes must be empty.
    This multiplication combines matrix multiplication (dot)
    and broadcasted multiplication term by term.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes, they can be given in any order
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, it has as many dimensions as the inputs,
        a summation axis which is not a right axis has a length of 1
    :raises TypeError: if both matrices do not share the same dtype
    :raises RuntimeError: if both matrices do not have the same
        number of dimensions
    :raises ValueError: if the axes are inconsistent with the number
        of dimensions

    The dot product is equivalent to:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2)

        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    Empty axes should be squeezed to get identical results.
    Dot product when the second matrix is transposed.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2.T)

        dm1 = m1.reshape((2, 1, 2))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    An example when right axes include the summation axis.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2],
                                 verbose=True)
        print(dot)

    Example in higher dimension:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(8).reshape((2, 2, 2))
        m2 = m1 + 10

        dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
        print(dot)

    The current implementation still uses :epkg:`numpy:einsum`
    but this should be replaced.
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    eq = _numpy_extended_dot_equation(
        len(m1.shape), len(m2.shape), axes, left, right)
    if verbose:
        print(f" [numpy_extended_dot] {eq}: {m1.shape!r} @ {m2.shape!r}")
    output = numpy.einsum(eq, m1, m2)

    # einsum dropped the summation axes which are not right axes.
    # They are restored with a length of 1. The insertion must follow
    # the increasing order of the axes (each one only once): inserting
    # a higher position first would shift it when a lower one is
    # inserted afterwards and the restored axes would end up
    # at the wrong place.
    new_shape = list(output.shape)
    for a in sorted(set(axes)):
        if a not in right:
            new_shape.insert(a, 1)
    if verbose:
        print(
            f" [numpy_extended_dot] {output.shape!r} reshaped into {new_shape!r} ")
    return output.reshape(tuple(new_shape))

262 

263 

def numpy_extended_dot_ouput_shape(m1, m2, axes, left, right):
    """
    Computes the output shape of results produced by function
    :func:`numpy_extended_dot
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot>` or
    :func:`numpy_extended_dot_python
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot_python>`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: shape as an array of int64, the length is 1 for every
        dimension which is neither a left axis nor a right axis
    :raises RuntimeError: if a dimension is both a left and a right axis
        and the lengths are different and both different from 1
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    shape1, shape2 = m1.shape, m2.shape

    out_shape = numpy.ones(len(shape1), dtype=numpy.int64)
    for d in left:
        out_shape[d] = shape1[d]
    # Right axes are processed last: the length coming from *m2*
    # overwrites the one coming from *m1* on shared axes.
    for d in right:
        if (d in left and shape1[d] != shape2[d] and
                shape1[d] != 1 and shape2[d] != 1):
            raise RuntimeError(  # pragma: no cover
                "Matrices should have the same dimension for dimension %d, "
                "shapes=%r @ %r." % (d, shape1, shape2))
        out_shape[d] = shape2[d]
    return out_shape

286 

287 

def _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right):
    """
    Builds the labels of every dimension for the first operand,
    the second operand and the output.

    A dimension gets an uppercase letter in an operand (and in the
    output) when it is a *left* or *right* axis. Summation axes are
    lowercase in both operands. In the output, a summation axis is
    lowercase if it is also a right axis, otherwise it is replaced
    by ``"-"``.

    :param m1_dim: number of dimensions
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: three lists of labels *l1*, *l2*, *l3*
    """
    base = [chr(97 + d) for d in range(m1_dim)]
    first, second, out = base[:], base[:], base[:]
    for a in left:
        first[a] = first[a].upper()
        out[a] = out[a].upper()
    for a in right:
        second[a] = second[a].upper()
        out[a] = out[a].upper()
    for a in axes:
        first[a] = first[a].lower()
        second[a] = second[a].lower()
        out[a] = out[a].lower() if a in right else "-"
    return first, second, out

306 

307 

def _numpy_extended_dot_python_intermediate(m1_shape, m2_shape, l1, l2, l3):
    """
    Computes the intermediate information used to loop over
    every distinct label found in *l1* and *l2*.

    :param m1_shape: shape of the first matrix
    :param m2_shape: shape of the second matrix
    :param l1: labels of the first matrix
    :param l2: labels of the second matrix
    :param l3: labels of the output
    :return: tuple *(names, kind, cols, common, broadcast, pos)*,
        *names* is the sorted list of distinct labels,
        *kind* is a bit mask for every label (1: in *l1*, 2: in *l2*,
        4: in *l3*), *cols* maps a label to its position in *l2*
        if it belongs to it, in *l1* otherwise,
        *pos* holds the same positions as an array ordered like *names*,
        *common* tells if a label is in both *l1* and *l2*,
        *broadcast* tells if a common label has different lengths
        in both shapes
    """
    names = sorted(set(l1 + l2))
    kind = numpy.zeros(len(names), dtype=numpy.int64)
    cols = {}

    for rank, letter in enumerate(names):
        flags = 0
        if letter in l1:
            flags |= 1
            cols[letter] = l1.index(letter)
        if letter in l2:
            # The position in l2 overrides the one found in l1.
            flags |= 2
            cols[letter] = l2.index(letter)
        if letter in l3:
            flags |= 4
        kind[rank] = flags

    pos = numpy.array([cols[letter] for letter in names], dtype=numpy.int64)
    common = [(k & 3) == 3 for k in kind]
    broadcast = [shared and m1_shape[p] != m2_shape[p]
                 for shared, p in zip(common, pos)]

    return names, kind, cols, common, broadcast, pos

331 

332 

def _numpy_extended_dot_python_update_broadcast(
        m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
        kind, common, verbose=False):
    """
    Splits the label of every broadcasted dimension: the operand whose
    dimension must be broadcasted gets the same letter with the opposite
    case. The output uses the new letter as well if the dimension is part
    of the output. Lists *l1*, *l2*, *l3* are modified inplace.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param l1: labels of the first matrix
    :param l2: labels of the second matrix
    :param l3: labels of the output
    :param names: distinct labels
    :param broadcast: tells which labels are broadcasted
    :param cols: position of every label
    :param kind: bit mask of every label
        (1: in *l1*, 2: in *l2*, 4: in *l3*)
    :param common: tells which labels are shared by both matrices
    :param verbose: display intermediate information
    :return: updated *l1*, *l2*, *l3*
    :raises RuntimeError: if a broadcasted label does not belong
        to both matrices
    """

    def dispb(c):
        return "".join("o" if b else "." for b in c)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] before broadcast %s,%s->%s or %s" % (
                "".join(l1), "".join(l2), "".join(l3),
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
        print(  # pragma: no cover
            "[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
                "".join(names), kind.tolist(),
                dispb(common), dispb(broadcast)))

    for i, needs_split in enumerate(broadcast):
        if not needs_split:
            continue
        if (kind[i] & 3) != 3:
            raise RuntimeError(  # pragma: no cover
                "Broadcast should only happen on common axes, "
                "axes=%r left=%r right=%r shape1=%r shape2=%r."
                % (axes, left, right, m1.shape, m2.shape))

        # We split letters.
        p = cols[names[i]]
        dim = (m1.shape[p], m2.shape[p])
        let = [l1[p], l2[p], l3[p]]
        # The label changes in the second matrix if the first matrix
        # has length 1 on this dimension, in the first matrix otherwise.
        inp = 1 if dim[0] == 1 else 0
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] name=%s dim=%r let=%r inp=%r p=%r" % (
                    names[i], dim, let, inp, p))
            print(  # pragma: no cover
                f" B0 l1={l1!r}, l2={l2!r} l3={l3!r}")

        current = let[inp]
        swapped = (current.upper() if current.lower() == current
                   else current.lower())
        in_output = (kind[i] & 4) > 0
        if in_output:
            # Summation axis is part of the output.
            l3[p] = swapped
        if inp == 1:
            l2[p] = swapped
        else:
            l1[p] = swapped

        if verbose:
            if in_output:
                print(  # pragma: no cover
                    f" B1 l1={l1!r}, l2={l2!r} l3={l3!r}")
            else:
                print(f" B2 l1={l1!r}, l2={l2!r} l3={l3!r}")

    return l1, l2, l3

398 

399 

def numpy_extended_dot_python(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot in pure python.
    This implementation is not efficient but shows how to
    implement this operation without :epkg:`numpy:einsum`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, its shape is given by
        @see fn numpy_extended_dot_ouput_shape

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_python
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    def dispb(c):
        return "".join("o" if b else "." for b in c)

    new_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    # output result
    res = numpy.full(tuple(new_shape), 0, dtype=m1.dtype)

    # labels of every dimension
    l1, l2, l3 = _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right)
    names, kind, cols, common, broadcast, pos = (
        _numpy_extended_dot_python_intermediate(
            m1.shape, m2.shape, l1, l2, l3))

    if any(broadcast):
        # Labels were split, the intermediate information is refreshed.
        l1, l2, l3 = _numpy_extended_dot_python_update_broadcast(
            m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
            kind, common, verbose=verbose)
        names, kind, cols, common, broadcast, pos = (
            _numpy_extended_dot_python_intermediate(
                m1.shape, m2.shape, l1, l2, l3))

    # One counter per distinct label, pl1, pl2, plo tell which counter
    # indexes every dimension of m1, m2 and the output (-1: none).
    indices = numpy.zeros(len(names), dtype=numpy.int64)
    pl1 = numpy.array([names.index(c) for c in l1], dtype=numpy.int64)
    pl2 = numpy.array([names.index(c) for c in l2], dtype=numpy.int64)
    limits = numpy.array(
        [m1.shape[p] if (k & 1) == 1 else m2.shape[p]
         for k, p in zip(kind, pos)], dtype=numpy.int64)
    plo = numpy.array(
        [names.index(c) if c in names else -1 for c in l3],
        dtype=numpy.int64)

    if verbose:
        print("[GENERICDOT] %s,%s->%s or %s" % (
            "".join(l1), "".join(l2), "".join(l3),
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))
        print("[GENERICDOT] shape1=%r shape2=%r shape=%r" % (
            m1.shape, m2.shape, res.shape))
        print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
        print(f"[GENERICDOT] pl1={pl1!r} pl2={pl2!r} plo={plo!r}")
        print("[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
            "".join(names), kind.tolist(),
            dispb(common), dispb(broadcast)))
        print(f"[GENERICDOT] pos={pos.tolist()!r}")
        print(f"[GENERICDOT] cols={cols!r}")
        print(f"[GENERICDOT] limits={limits!r}")

    last = len(indices) - 1
    while indices[0] < limits[0]:

        # The function spends most of its time is these three lines.
        t1 = tuple(indices[n] for n in pl1)
        t2 = tuple(indices[n] for n in pl2)
        to = tuple(0 if n == -1 else indices[n] for n in plo)

        c = m1[t1] * m2[t2]

        if verbose:
            print(f" {t1!r} x {t2!r} -> {to!r} v={c!r} I={indices!r}")

        res[to] += c

        # Moves to the next combination of indices: the last counter
        # is incremented, every counter reaching its limit is reset
        # and the carry goes to the previous one. The first counter
        # is never reset, it stops the loop.
        indices[last] += 1
        for i in range(last, 0, -1):
            if indices[i] < limits[i]:
                break
            indices[i] = 0
            indices[i - 1] += 1

    return res

501 

502 

def numpy_extended_dot_matrix(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot using dot product,
    multiplication, transpose and reduction
    but not a custom python implementation like
    @see fn numpy_extended_dot_python.

    Three cases are handled:

    * no summation axis and no axis shared by *left* and *right*:
      the result is the broadcasted product ``m1 * m2``,
    * no summation axis is a left or a right axis: both matrices are
      reduced, transposed, reshaped into three dimensions, multiplied
      with a batched matrix multiplication, then the result is
      reshaped and transposed back,
    * some summation axes are right axes but none is a left axis:
      these axes are moved from *axes* to *left* and the function
      calls itself.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output
    :raises RuntimeError: if the configuration of axes falls into none
        of the cases above or if intermediate shapes are inconsistent

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_matrix
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_matrix(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] shape1=%r shape2=%r axes=%r "
            "left=%r right=%r -- %s" % (
                m1.shape, m2.shape, axes, left, right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))

    if len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Simple multiplication
        res = m1 * m2
        if verbose:
            print(  # pragma: no cover
                f"[GENERICDOT] Mul {m1.shape!r} @ {m2.shape!r} -> {res.shape!r}")
        return res

    if (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between axes and right: matrix multiplication
        # ReduceSum
        # Axes only listed in right are summed in the first matrix.
        right_no_left = set(right) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            red1 = m1.sum(axis=tuple(sorted(right_no_left)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumL=%r, %r -> %r" % (
                    right_no_left, m1.shape, red1.shape))
        else:
            red1 = m1

        # Axes only listed in left are summed in the second matrix.
        left_no_right = set(left) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            red2 = m2.sum(axis=tuple(sorted(left_no_right)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumR=%r, %r -> %r" % (
                    left_no_right, m2.shape, red2.shape))
        else:
            red2 = m2

        # Transpose
        # The sort key orders the dimensions: axes common to left and
        # right first (-1), then the other ones (0), summation axes
        # last (1).
        common_axes = sorted(set(left) & set(right))
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(len(m1.shape))]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        trm1 = numpy.transpose(red1, axes=perm)
        trm2 = numpy.transpose(red2, axes=perm)
        if verbose:
            print(
                f"[GENERICDOT] transposeL={perm!r}, {red1.shape!r} -> {trm1.shape!r}")
            print(
                f"[GENERICDOT] transposeR={perm!r}, {red2.shape!r} -> {trm2.shape!r}")
        final_shape = numpy_extended_dot_ouput_shape(
            m1, m2, axes, left, right)
        # Positions of the left, right and common axes after
        # the transposition.
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        perm_common_axes = [i for i in range(len(perm))
                            if perm[i] in common_axes]

        if verbose:
            print("[GENERICDOT] MatMul %r @ %r -> %r -- %s" % (
                m1.shape, m2.shape, final_shape,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
            print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
            print("[GENERICDOT] perm=%r perm_left=%r "
                  "perm_right=%r perm_common_axes=%r" % (
                      perm, perm_left, perm_right, perm_common_axes))

        # Reshape
        # dim0, dim0b: number of elements along the common axes,
        # dim1, dim2: number of elements along the summation axes
        # (the last dimensions after the transposition).
        dim0 = int(numpy.prod([trm1.shape[i] for i in perm_common_axes]))
        dim0b = int(numpy.prod([trm2.shape[i] for i in perm_common_axes]))
        if len(axes) > 0:
            all_axes = list(range(0, len(m1.shape)))
            new_axes = all_axes[-len(axes):]
        else:
            new_axes = []
        dim1 = int(numpy.prod([trm1.shape[i] for i in new_axes]))
        dim2 = int(numpy.prod([trm2.shape[i] for i in new_axes]))
        if dim1 != dim2:
            raise RuntimeError(  # pragma: no cover
                "Summation axis do not have the same length %d != %d, "
                "trshape1=%r trshape2=%r "
                "p_axes=%r p_left=%r p_right=%r p_common=%r"
                "." % (dim1, dim2, trm1.shape, trm2.shape,
                       new_axes, perm_left, perm_right, perm_common_axes))
        else:
            shm1 = trm1.reshape((dim0, -1, dim1))
            shm2 = trm2.reshape((dim0b, -1, dim2))

        if verbose:
            # NOTE(review): the second tuple displays dim0 whereas
            # trm2 is reshaped with dim0b, to be confirmed.
            print("[GENERICDOT] Reshape %r @ %r -> %r @ %r" % (
                (dim0, -1, dim1), (dim0, -1, dim2),
                shm1.shape, shm2.shape))
            print("[GENERICDOT] matmul")

        # Multiplication (this should be done in a different way.
        res = shm1 @ numpy.transpose(shm2, axes=(0, 2, 1))

        if verbose:
            print(f"[GENERICDOT] Shape after multiplication {res.shape}")

        # Transpose again
        # ordered_axes gives the original axis held by every dimension
        # of the result once it is reshaped into current_shape: common
        # axes, axes only in left, axes only in right, other axes.
        not_in_both = []
        for i in range(0, len(m1.shape)):
            if i not in left and i not in right:
                not_in_both.append(i)
        ordered_axes = (common_axes +
                        list(i for i in left if i not in right) +
                        list(i for i in right if i not in left) +
                        not_in_both)

        perm_not_in_both = [i for i in range(len(perm))
                            if perm[i] in not_in_both]
        current_shape = ([max(trm1.shape[i], trm2.shape[i])
                          for i in sorted(perm_common_axes)] +
                         [trm1.shape[i] for i in sorted(perm_left)
                          if i not in perm_common_axes] +
                         [trm2.shape[i] for i in sorted(perm_right)
                          if i not in perm_common_axes] +
                         [1 for i in perm_not_in_both])

        if verbose:
            print("[GENERICDOT] current_shape=%r final_shape=%r "
                  "last_shape=%r" % (current_shape, final_shape, res.shape))

        if len(current_shape) != len(final_shape):
            raise RuntimeError(  # pragma: no cover
                "Shapes mismatch %r > %r, "
                "shape1=%r shape2=%r axes=%r left=%r right=%r." % (
                    current_shape, final_shape,
                    m1.shape, m2.shape, axes, left, right))

        res = res.reshape(current_shape)

        # Sorting the pairs (axis, position) gives the permutation
        # which puts every axis back to its original position.
        perm = [(a, i) for i, a in enumerate(ordered_axes)]
        perm.sort()
        perm = [p[1] for p in perm]

        if verbose:
            print(f"[GENERICDOT] ordered_axes={ordered_axes!r} perm={perm!r}")

        return numpy.transpose(res, axes=perm)

    else:
        # Multiplication and Matrix multiplication at the same time.
        l_axes = set(left) & set(axes)
        r_axes = set(right) & set(axes)
        if r_axes and not l_axes:
            # Summation axes which are also right axes are removed from
            # the summation axes and added to the left axes, the function
            # is then called again with this new configuration.
            new_axes = list(a for a in axes if a not in right)
            new_left = list(sorted(set(left) | r_axes))
            if verbose:  # pragma: no cover
                eq1 = _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)
                eq2 = _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), new_axes, new_left, right)
                print("[GENERICDOT] replace left %r by %r axes %r by %r, "
                      "eq %r by %r" % (
                          left, new_left, axes, new_axes, eq1, eq2))
            return numpy_extended_dot_matrix(m1, m2, new_axes, new_left, right,
                                             verbose=verbose)
    # Reached when summation axes overlap with the left axes:
    # this configuration is not implemented.
    raise RuntimeError(  # pragma: no cover
        "shape1=%r shape2=%r axes=%r left=%r right=%r eq=%s." % (
            m1.shape, m2.shape, axes, left, right,
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))