Coverage for mlprodict/testing/einsum/einsum_impl.py: 97%

254 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Main functions decomposing einsum computation into 

4more simple functions. 

5""" 

6import numpy 

7from .einsum_impl_classes import EinsumSubOp, GraphEinsumSubOp 

8 

9 

def analyse_einsum_equation(equation):
    """
    Analyses an einsum equation.

    :param equation: :epkg:`numpy:einsum` equation, spaces around
        the inputs and around the output are ignored
    :return: four results, the sorted letters joined into a string,
        a matrix (see below), the number of letters of every input
        followed by the number of letters of the output, and the
        duplicates: one entry per input then one for the output,
        *None* if no letter is repeated, otherwise a dictionary
        `{letter: [positions]}`
    :raises NotImplementedError: if the equation does not have
        exactly two non empty sides separated by `->`
    :raises ValueError: if a character is not a lower or upper
        letter, or if the output uses a letter no input holds

    The returned matrix has one row per input plus a last row
    for the output, one column per letter. It is defined as follows:

    .. math::

        m_{ij}=\\left\\{\\begin{array}{ll}-1 &
        \\text{if letter j is not involved in input i} \\\\
        p & \\text{p is position of letter j in equation i}
        \\end{array}\\right.

    If a letter is repeated in the same term, the matrix keeps
    its last position.
    """
    spl = equation.strip(' ,').split("->")
    if len(spl) != 2 or len(spl[1]) == 0 or len(spl[0]) == 0:
        raise NotImplementedError(
            "The function only implements the case when there are "
            "two sides in the equation: %r." % equation)
    inputs = [s.strip() for s in spl[0].split(',')]
    # The output is stripped the same way inputs are, otherwise
    # a space after '->' is reported as an unexpected letter.
    output = spl[1].strip()

    # Set of letters
    letters = sorted(set().union(*inputs))
    for c in letters:
        if not (('a' <= c <= 'z') or ('A' <= c <= 'Z')):
            raise ValueError(
                "Equation %r must only contain lower or upper letters "
                "but %r is not." % (equation, c))

    rev = {c: i for i, c in enumerate(letters)}
    for c in output:
        if c not in rev:
            raise ValueError(
                "Output contains one unexpected letter %r in "
                "equation %r." % (c, equation))

    # One row per input, the last row describes the output.
    terms = inputs + [output]
    mat = numpy.full((len(terms), len(letters)), -1, dtype=numpy.int8)
    for i, term in enumerate(terms):
        for k, c in enumerate(term):
            mat[i, rev[c]] = k
    lengths = [len(term) for term in terms]

    # Look for duplicates
    duplicates = []
    for term in terms:
        if len(term) == len(set(term)):
            duplicates.append(None)
            continue
        # Some letters are repeated, every position is recorded.
        counts = {}
        for i, c in enumerate(term):
            counts.setdefault(c, []).append(i)
        duplicates.append(counts)

    return "".join(letters), mat, lengths, duplicates

78 

79 

def decompose_einsum_equation(equation, *shapes, strategy="simple",
                              clean=False, verbose=False):
    """
    Turns an equation written for :epkg:`numpy:einsum` into a sequence
    of simpler operations, the input shapes being known. Running
    these operations in order computes the same result.

    :param equation: a string
    :param shapes: sequence of input shapes, every one must be a tuple
    :param strategy: the way the equation is decomposed (see below)
    :param clean: removes the unnecessary nodes from the graph
    :param verbose: verbosity
    :return: instance of @see cl GraphEinsumSubOp

    About *strategy*:

    * `'simple'`: align all dimensions in the alphabetical order,
      some generic matrix multiplication remains implemented with
      :epkg:`numpy:einsum` but only with two matrices aligned on
      the same dimension (see @see fn numpy_extended_dot)
    * `'numpy'`: same as `simple` but the decomposition does not use
      :epkg:`numpy:einsum` anymore but only multiplication or
      matrix multiplication merged into a single operator called
      *batch_dot* (see @see fn numpy_extended_dot_matrix)

    Available operations: *expand_dims*, *transpose*, *matmul*, *reduce_sum*,
    *id*, *squeeze*, *diagonal*. The equation is analysed and a graph is
    produced, its nodes are instances of class @see cl EinsumSubOp.

    .. runpython::
        :showcode:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation("bac,cd,def->ebc")
        for op in seq:
            print(op)

    It can be better displayed as the following.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation(
            "bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))
        print("DOT-SECTION", seq.to_dot())

    See notebook :ref:`einsumdecompositionrst`.
    """
    # Shapes are validated first, before anything is built.
    for shape in shapes:
        if not isinstance(shape, tuple):
            raise TypeError(
                f"All shapes must be tuples for {shape!r} is not.")

    if strategy not in ("simple", "numpy"):
        raise ValueError(f"Unknown strategy {strategy!r}.")

    # Both strategies share the same decomposition, only the operator
    # standing for the matrix multiplication changes.
    matmul_operator = {'simple': 'matmul', 'numpy': 'batch_dot'}[strategy]
    graph = _decompose_einsum_equation_simple(
        equation, *shapes, verbose=verbose, op_matmul=matmul_operator)

    if not clean:
        graph.mark_last_node()
        return graph

    # Last step: an identity node is added after the last operator,
    # then the graph is simplified and the unused nodes are removed.
    final_node = graph.last_added_op
    identity = EinsumSubOp(final_node.full_dim, 'id', final_node)
    graph.append(identity)
    graph.mark_last_node()
    graph.simplify_mm_nodes(verbose=verbose)
    graph.remove_duplicate_transpose(verbose=verbose)
    graph.clean_unused_nodes(verbose=verbose)
    return graph

155 

156 

def apply_einsum_sequence(seq, *inputs, verbose=False, **kwargs):
    """
    Runs a sequence of operations over a list of inputs.
    Function @see fn decompose_einsum_equation produces
    such a sequence.

    :param seq: sequence of operations
    :param inputs: inputs
    :param verbose: verbosity, forwarded to method *apply_sequence*
    :param kwargs: additional parameters,
        see :meth:`apply_sequence
        <mlprodict.testing.einsum.einsum_impl_classes.
        GraphEinsumSubOp.apply_sequence>`.
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import (
            decompose_einsum_equation, apply_einsum_sequence)

        m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = numpy.arange(4).reshape((2, 2)) + 100
        m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

        seq = decompose_einsum_equation("bac,cd,def->ebc")
        res = apply_einsum_sequence(seq, m1, m2, m3)
        print(res)

    See notebook :ref:`einsumdecompositionrst`.
    """
    # The whole work is delegated to the sequence itself.
    run_sequence = seq.apply_sequence
    output = run_sequence(*inputs, verbose=verbose, **kwargs)
    return output

189 

190 

def is_transpose_identity(perm):
    """
    Checks whether the permutation *perm* leaves every axis
    at its place (identity).

    :param perm: permutation
    :return: boolean
    """
    expected = range(len(perm))
    return all(axis == place for axis, place in zip(perm, expected))

199 

200 

201def _basic_verification(lengths, shapes, equation): 

202 if len(lengths) - 1 != len(shapes): 

203 raise ValueError( 

204 "Equation %r has %d inputs but %d shapes are given." 

205 "" % (equation, len(lengths), len(shapes))) 

206 for i, (le, sh) in enumerate(zip(lengths, shapes)): 

207 if le != len(sh): 

208 raise ValueError( 

209 "Inputs %d has %d dimensions but shapes %r has %d " 

210 " in equation %r." % (i, le, sh, len(sh), equation)) 

211 

212 

def _apply_transpose_reshape(op, row):
    """
    Yields the operators putting all dimensions in the same order:
    an *expand_dims* first, then a *transpose* unless the
    permutation is the identity.

    :param op: integer (for one input) or an operator
    :param row: one value per letter, -1 if the letter is missing,
        any other value is used as a key to sort the existing
        dimensions
    :return: generator, the last yielded operator is the last
        created operator
    """
    ndim = len(row)

    # First pass: missing letters become new axes, the others are
    # remembered with the key used to order them.
    inserted = []
    located = []
    n_seen = 0
    for col, key in enumerate(row):
        if key == -1:
            inserted.append((n_seen, col))
            continue
        n_seen += 1
        located.append((key, col))
    op = EinsumSubOp(ndim, 'expand_dims', op, axes=tuple(inserted))
    yield op

    # Second pass: *row* is read again only once the generator
    # is resumed.
    located.sort()
    kept = [col for col, key in enumerate(row) if not key == -1]
    new_perm = numpy.arange(ndim)
    for rank, col in enumerate(kept):
        new_perm[located[rank][1]] = col
    if is_transpose_identity(new_perm):
        return
    op = EinsumSubOp(ndim, 'transpose', op, perm=tuple(new_perm))
    yield op

243 

244 

def _apply_squeeze_transpose(op, row_last, row_output):
    """
    Yields the operators putting the output dimensions in the
    expected order: a *transpose* unless the permutation is the
    identity, then a *squeeze* if some dimensions are not part
    of the output.

    :param op: operator the new operators are applied to
    :param row_last: only its length is used, it gives the number
        of dimensions
    :param row_output: one value per dimension, -1 if the output
        does not hold it, any other value is used as a key to sort
        the remaining dimensions
    :return: generator of the created operators
    """
    ndim = len(row_last)
    dropped = [axis for axis, key in enumerate(row_output) if key == -1]
    kept = [axis for axis, key in enumerate(row_output) if not key == -1]
    by_key = sorted((key, axis) for axis, key in enumerate(row_output)
                    if not key == -1)

    # Dropped axes stay where they are, the others are permuted
    # among themselves following their key.
    new_perm = numpy.arange(ndim)
    for axis, (_, source) in zip(kept, by_key):
        new_perm[axis] = source

    if not is_transpose_identity(new_perm):
        op = EinsumSubOp(ndim, 'transpose', op, perm=tuple(new_perm))
        yield op
    if dropped:
        op = EinsumSubOp(ndim, 'squeeze', op, axes=tuple(dropped))
        yield op

272 

273 

def _apply_einsum_matmul(fd, op1, op2, axes, left, right, ndim,
                         op_matmul, row1, row2, verbose=False):
    """
    Decomposes the generic matrix multiplication into numpy operations
    depending on the operator to use for matrix multiplication
    *op_matmul* (see @see fn decompose_einsum_equation).
    The function is a generator, it yields every created operator
    in the order it was created.

    :param fd: first argument given to every created
        @see cl EinsumSubOp
    :param op1: left operand
    :param op2: right operand
    :param axes: axes given as *axes* to operator *matmul*, they are
        placed last and become *sum_axes* of operator *batch_dot*
    :param left: axes given as *left*
    :param right: axes given as *right*
    :param ndim: number of dimensions
    :param op_matmul: `'matmul'`, `'batch_dot'` or `'dot'`,
        anything but `'matmul'` follows the same decomposition
    :param row1: indexed to find which dimensions have a value >= 0
        when looking for a reduction of *op1*
    :param row2: same as *row1* for *op2*, it is indexed after
        the first reduction was yielded
    :param verbose: prints the decision taken
    :raises ValueError: if *op_matmul* is unknown
    :raises NotImplementedError: if *axes* shares an axis with
        *left* or *right* (and *op_matmul* is not `'matmul'`)
    """
    allowed = {'matmul', 'batch_dot', 'dot'}
    if op_matmul not in allowed:
        raise ValueError(  # pragma: no cover
            f"Unknown operator op_matmul={op_matmul!r} not in {allowed!r}.")
    if op_matmul == 'matmul':
        # A single operator does everything.
        if verbose:  # pragma: no cover
            print(
                f" -- MATMUL -> matmul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(fd, 'matmul', op1, op2,
                          axes=axes, left=left, right=right, ndim=ndim)

    elif len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Nothing to sum, nothing in common: operator mul is enough.
        if verbose:  # pragma: no cover
            print(
                f" -- MATMUL -> mul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(fd, 'mul', op1, op2)

    elif (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between axes and right: matrix multiplication
        if verbose:  # pragma: no cover
            print(" -- MATMUL -> batch_dot axes=%r left=%r right=%r"
                  "" % (axes, left, right))

        # common_axes: axes in both left and right, plus the axes
        # which do not appear in left, right or axes.
        all_axes = set(left) | set(right) | set(axes)
        common_axes = list(set(left) & set(right))
        for i in range(ndim):
            if i not in all_axes:
                common_axes.append(i)
        common_axes.sort()

        # ReduceSum*
        # NOTE(review): has_dim comes from row1 while the filtered
        # axes come from *right* (and row2 / *left* below) -
        # confirm this pairing is the intended one.
        has_dim = set(i for i in range(len(row1)) if row1[i] >= 0)
        right_no_left = (set(right) & has_dim) - \
            (set(right) & (set(left) | set(axes)))
        if right_no_left:
            if verbose:  # pragma: no cover
                print(
                    f' -- MATMUL reduce1 has_dim={has_dim!r} axes={right_no_left!r}')
            op1 = EinsumSubOp(fd, 'reduce_sum_mm', op1, op2,
                              axes=tuple(sorted(right_no_left)))
            yield op1

        # row2 is indexed here, after the previous yield: the order
        # of the statements must be preserved.
        has_dim = set(i for i in range(len(row2)) if row2[i] >= 0)
        left_no_right = (set(left) & has_dim) - \
            (set(left) & (set(right) | set(axes)))
        if left_no_right:
            if verbose:  # pragma: no cover
                print(
                    f' -- MATMUL reduce2 has_dim={has_dim!r} axes={left_no_right!r}')
            op2 = EinsumSubOp(fd, 'reduce_sum', op2,
                              axes=tuple(sorted(left_no_right)))
            yield op2

        # Transpose
        # Sort key: -1 for common axes (first), 0 for the other ones,
        # 1 for the summed axes (last).
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(ndim)]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        # Positions, after the permutation, of the axes of left and right.
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        if not is_transpose_identity(perm):
            op1 = EinsumSubOp(fd, 'transpose_mm', op1, op2, perm=tuple(perm))
            yield op1
            op2 = EinsumSubOp(fd, 'transpose', op2, perm=tuple(perm))
            yield op2

        # Reshape
        # After the transposition, common axes come first and
        # summed axes come last.
        all_axes = list(range(0, ndim))
        new_axes = all_axes[-len(axes):] if len(axes) > 0 else []
        new_common_axes = all_axes[:len(common_axes)]
        not_in_both = []
        for i in range(0, ndim):
            if i not in left and i not in right and i not in common_axes:
                not_in_both.append(i)

        op = EinsumSubOp(fd, 'batch_dot', op1, op2,
                         batch_axes=tuple(new_common_axes),
                         keep_axes=None, sum_axes=tuple(new_axes),
                         left=tuple(perm_left), right=tuple(perm_right),
                         ndim=ndim)
        yield op

        # Transpose again
        # rev_perm[a] is the position of axis a in ordered_axes.
        ordered_axes = (common_axes +
                        list(i for i in left if i not in right) +
                        list(i for i in right if i not in left) +
                        not_in_both)
        rev_perm = [(a, i) for i, a in enumerate(ordered_axes)]
        rev_perm.sort()
        rev_perm = [p[1] for p in rev_perm]

        if not is_transpose_identity(rev_perm):
            # This operator is yielded but the next one is built
            # on *op*, not on it.
            op_unused = EinsumSubOp(fd, 'transpose_mm', op1,
                                    op, perm=tuple(rev_perm))
            yield op_unused
            op = EinsumSubOp(fd, 'transpose', op, perm=tuple(rev_perm))
            yield op
    else:
        raise NotImplementedError(  # pragma: no cover
            "axes and right or left have axes in common, "
            "axes=%r left=%r right=%r ndim=%r." % (
                axes, left, right, ndim))

386 

387 

def _decompose_einsum_equation_simple(equation, *shapes, verbose=False,
                                      op_matmul='matmul'):
    """
    Applies strategy `simple`, `numpy`
    defined by function @see fn decompose_einsum_equation.

    :param equation: equation to decompose
    :param shapes: input shapes, if none is given, every input
        gets a shape filled with 2
    :param verbose: prints intermediate results
    :param op_matmul: which operator to use for matrix multiplication,
        a single operator *matmul*, or *batch_dot* with *transposes*,
        *reduce_sum*, or just *dot*
    :return: instance of @see cl GraphEinsumSubOp
    """
    letters, mat, lengths, duplicates = analyse_einsum_equation(equation)
    if len(letters) != mat.shape[1]:
        raise RuntimeError(  # pragma: no cover
            f"Unexpected number of letters {letters!r}, shape={mat.shape!r}.")
    if len(shapes) == 0:
        # Default shapes, lengths[-1] is skipped: it is the output.
        shapes = [(2, ) * le for le in lengths[:-1]]
    _basic_verification(lengths, shapes, equation)

    # last_row, current_row (row = shape)
    # rows[0] follows the result accumulated so far, rows[1] the input
    # being processed, both are modified in place through the views
    # given to compute_output_row.
    rows = numpy.full((2, mat.shape[1]), -1)
    graph = GraphEinsumSubOp(letters, mat, lengths, duplicates)
    fd = mat.shape[1]
    if verbose:
        print(f"EQUATION={equation!r}")
        print(f"LETTERS={letters!r}", f"LENGTHS={lengths!r}")
        print(f"DUPLICATES={duplicates!r}")

    for i, sh in enumerate(shapes):
        if verbose:
            print()
            print("######### ROW %d shape=%r row=%r" % (i, sh, rows[1, :]))
        graph.append(i)

        # Input matrix aligned to the same dimensions.
        op = EinsumSubOp(fd, 'id', i)
        op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
        marked = graph.append(op)

        duplicate = duplicates[i]
        if duplicate is not None:
            # Diagonal
            # One entry (first position, all positions) per letter
            # repeated in this input.
            diag = []
            for _, v in duplicate.items():
                if len(v) == 1:
                    continue
                diag.append((v[0], tuple(v)))
            op = EinsumSubOp(fd, 'diagonal', op, diag=diag)
            op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
            # tr_row is a view on rows, not a copy.
            tr_row = rows[1, :]
            marked = graph.append(op)
        else:
            diag = None
            tr_row = mat[i]

        for op in _apply_transpose_reshape(op, tr_row):
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        # Reduction? (a dimension not used later)
        # The dimension must be absent from the following inputs and
        # from the output (rows i + 1 and after), present in the
        # current row and absent from the previous result.
        red = []
        for d in range(0, mat.shape[1]):
            if (mat[i + 1:, d].max() == -1 and rows[1, d] != -1 and
                    rows[0, d] == -1):
                red.append(d)
        if len(red) > 0:
            if verbose:
                print(" -- REDUCE1 row=%d axes=%r" % (i, red))
                print(mat)
                print(' -')
                print(rows)
            op = EinsumSubOp(fd, 'reduce_sum',
                             graph.last_added_op, axes=tuple(red))
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        if graph.last_op is not None:
            # Matrix multiplication?
            # A dimension present in both rows goes to common_dims
            # if nothing uses it later, otherwise to left and right.
            # A dimension present in one row only goes to left
            # (rows[0]) or right (rows[1]).
            common_dims = []
            left = []
            right = []
            for d in range(0, mat.shape[1]):
                if rows[:, d].min() >= 0:
                    if mat[i + 1:, d].max() >= 0:
                        left.append(d)
                        right.append(d)
                    else:
                        common_dims.append(d)
                else:
                    if rows[0, d] >= 0:
                        left.append(d)
                    if rows[1, d] >= 0:
                        right.append(d)
            if verbose:
                print(f" -- MATMUL common_dims={common_dims!r}")
                print(rows)
            # row1 and row2 are views on rows, the same views are
            # given to compute_output_row between two yields.
            for iop in _apply_einsum_matmul(
                    fd, graph.last_op, op, axes=tuple(common_dims),
                    left=tuple(left), right=tuple(right),
                    ndim=rows.shape[1], op_matmul=op_matmul,
                    row1=rows[0, :], row2=rows[1, :], verbose=verbose):
                op = iop
                op.compute_output_row(rows[0, :], rows[1, :],
                                      ab=True, verbose=verbose)
                marked = graph.append(op)

        # End
        # The current row becomes the previous result.
        graph.mark(i, marked)
        rows[0, :] = rows[1, :]

    # Final output
    if verbose:
        print()
        print(f"######### FIN row={rows[1, :]!r}")

    if mat[len(shapes), :].max() >= 0:
        # mat[len(shapes)] is the row describing the output.
        rows[1, :] = mat[len(shapes), :]
        red = []
        for d in range(0, mat.shape[1]):
            # NOTE(review): the test is `> 0` while -1 is the value
            # meaning absent everywhere else - confirm a dimension
            # with value 0 must not be reduced here.
            if rows[0, d] > 0 and rows[1, d] == -1:
                red.append(d)
            elif rows[0, d] == -1 and rows[1, d] >= 0:
                raise RuntimeError(  # pragma: no cover
                    "Issue in equation %r, variable %d, last_result is %r, "
                    "output is %r." % (equation, d, rows[0, :], rows[1, :]))
        if len(red) > 0:
            if verbose:  # pragma: no cover
                print(f"-- REDUCE2 axes={red!r}")
                print(mat)
            op = EinsumSubOp(fd, 'reduce_sum', op, axes=tuple(red))
            graph.append(op)
            op.compute_output_row(rows[1, :], verbose=verbose)

        # Removes empty axes.
        for op in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
            op.compute_output_row(rows[1, :], verbose=verbose)
            graph.append(op)
    return graph

522 op.compute_output_row(rows[1, :], verbose=verbose) 

523 graph.append(op) 

524 return graph