Coverage for mlprodict/testing/einsum/einsum_impl.py: 97%
254 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
"""
@file
@brief Main functions decomposing einsum computation into
simpler functions.
"""
6import numpy
7from .einsum_impl_classes import EinsumSubOp, GraphEinsumSubOp
def analyse_einsum_equation(equation):
    """
    Analyses an einsum equation.

    :param equation: :epkg:`numpy:einsum` equation such as ``"ab,bc->ac"``
    :return: four results:

        * the sorted letters of the equation as a string,
        * a matrix (see below),
        * the number of letters of every input followed by the number
          of letters of the output,
        * for every input then the output, ``None`` if the term has no
          repeated letter, otherwise a dictionary mapping every letter
          of the term to the list of its positions in the term

    :raises NotImplementedError: if the equation does not contain exactly
        one ``->`` with a non-empty left side and a non-empty right side
    :raises ValueError: if an input character is not an ASCII letter
        or if the output contains a letter no input contains

    The returned matrix has one row per input, plus a last row for the
    output, and one column per letter. It is defined as follows:

    .. math::

        m_{ij}=\\left\\{\\begin{array}{ll}-1 &
        \\text{if letter j is not involved in term i} \\\\
        p & \\text{p is position of letter j in term i}
        \\end{array}\\right.
    """
    spl = equation.strip(' ,').split("->")
    if len(spl) != 2 or len(spl[1].strip()) == 0 or len(spl[0]) == 0:
        raise NotImplementedError(
            "The function only implements the case when there are "
            "two sides in the equation: %r." % equation)
    inputs = [s.strip() for s in spl[0].split(',')]
    # Spaces around the output (``"ab -> ba"``) are not letters,
    # they are removed exactly like the spaces around the inputs.
    output = spl[1].strip()
    terms = inputs + [output]

    # Set of letters, sorted alphabetically.
    letters = sorted(set("".join(inputs)))
    for c in letters:
        if not (('a' <= c <= 'z') or ('A' <= c <= 'Z')):
            raise ValueError(
                "Equation %r must only contain lower or upper letters "
                "but %r is not." % (equation, c))

    # Column index of every letter.
    rev = {c: i for i, c in enumerate(letters)}
    for c in output:
        if c not in rev:
            raise ValueError(
                "Output contains one unexpected letter %r in "
                "equation %r." % (c, equation))

    # One row per input and a last row for the output; when a letter is
    # repeated in a term, the row keeps its last position.
    # Positions are stored as int8.
    mat = numpy.full((len(inputs) + 1, len(letters)), -1, dtype=numpy.int8)
    for i, term in enumerate(terms):
        for k, c in enumerate(term):
            mat[i, rev[c]] = k
    lengths = [len(term) for term in terms]

    # Look for duplicates (letters appearing several times in a term).
    duplicates = []
    for term in terms:
        if len(term) == len(set(term)):
            duplicates.append(None)
            continue
        counts = {}
        for pos, c in enumerate(term):
            counts.setdefault(c, []).append(pos)
        duplicates.append(counts)

    return "".join(letters), mat, lengths, duplicates
def decompose_einsum_equation(equation, *shapes, strategy="simple",
                              clean=False, verbose=False):
    """
    Splits an equation meant for :epkg:`numpy:einsum` into a graph of
    simpler operations, given the shapes of the inputs.

    :param equation: einsum equation, a string
    :param shapes: shapes of the inputs, each one must be a tuple,
        the sequence may be empty
    :param strategy: decomposition to use, ``'simple'`` or ``'numpy'``
        (see below)
    :param clean: if True, appends a final *id* node, then simplifies
        the graph and removes the nodes which are not used
    :param verbose: verbosity
    :return: instance of @see cl GraphEinsumSubOp
    :raises TypeError: if one shape is not a tuple
    :raises ValueError: if *strategy* is unknown

    Available strategies:

    * `'simple'`: align all dimensions in the alphabetical order,
      some generic matrix multiplication remains implemented with
      :epkg:`numpy:einsum` but only with two matrices aligned on
      the same dimension (see @see fn numpy_extended_dot)
    * `'numpy'`: same as `simple` but the decomposition does not use
      :epkg:`numpy:einsum` anymore but only multiplication or
      matrix multiplication merged into a single operator called
      *batch_dot* (see @see fn numpy_extended_dot_matrix)

    The graph nodes are instances of @see cl EinsumSubOp. The available
    operations are *expand_dims*, *transpose*, *matmul*, *reduce_sum*,
    *id*, *squeeze*, *diagonal*.

    .. runpython::
        :showcode:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation("bac,cd,def->ebc")
        for op in seq:
            print(op)

    The same graph can be displayed as follows.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation(
            "bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))
        print("DOT-SECTION", seq.to_dot())

    See notebook :ref:`einsumdecompositionrst`.
    """
    for shape in shapes:
        if not isinstance(shape, tuple):
            raise TypeError(
                f"All shapes must be tuples for {shape!r} is not.")

    if strategy not in ("simple", "numpy"):
        raise ValueError(f"Unknown strategy {strategy!r}.")
    # Both strategies share the same decomposition, only the operator
    # used for the matrix multiplication differs.
    matmul_name = 'matmul' if strategy == "simple" else 'batch_dot'
    graph = _decompose_einsum_equation_simple(
        equation, *shapes, verbose=verbose, op_matmul=matmul_name)

    if not clean:
        graph.mark_last_node()
        return graph

    # Last step: clean unused nodes behind a final identity node.
    final = graph.last_added_op
    graph.append(EinsumSubOp(final.full_dim, 'id', final))
    graph.mark_last_node()
    graph.simplify_mm_nodes(verbose=verbose)
    graph.remove_duplicate_transpose(verbose=verbose)
    graph.clean_unused_nodes(verbose=verbose)
    return graph
def apply_einsum_sequence(seq, *inputs, verbose=False, **kwargs):
    """
    Runs a sequence of operations on the given inputs. The sequence
    is usually built by @see fn decompose_einsum_equation, and this
    function delegates the computation to its method *apply_sequence*.

    :param seq: sequence of operations
    :param inputs: inputs
    :param verbose: verbosity, forwarded to *apply_sequence*
    :param kwargs: additional parameters,
        see :meth:`apply_sequence
        <mlprodict.testing.einsum.einsum_impl_classes.
        GraphEinsumSubOp.apply_sequence>`.
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import (
            decompose_einsum_equation, apply_einsum_sequence)

        m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = numpy.arange(4).reshape((2, 2)) + 100
        m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

        seq = decompose_einsum_equation("bac,cd,def->ebc")
        res = apply_einsum_sequence(seq, m1, m2, m3)
        print(res)

    See notebook :ref:`einsumdecompositionrst`.
    """
    options = {'verbose': verbose}
    options.update(kwargs)
    return seq.apply_sequence(*inputs, **options)
def is_transpose_identity(perm):
    """
    Tells whether the permutation *perm* keeps every axis in place
    (identity permutation).

    :param perm: permutation
    :return: True if ``perm[i] == i`` for every position *i*
    """
    positions = range(len(perm))
    values = list(perm)
    return len(values) == len(positions) and all(
        value == position for position, value in zip(positions, values))
201def _basic_verification(lengths, shapes, equation):
202 if len(lengths) - 1 != len(shapes):
203 raise ValueError(
204 "Equation %r has %d inputs but %d shapes are given."
205 "" % (equation, len(lengths), len(shapes)))
206 for i, (le, sh) in enumerate(zip(lengths, shapes)):
207 if le != len(sh):
208 raise ValueError(
209 "Inputs %d has %d dimensions but shapes %r has %d "
210 " in equation %r." % (i, le, sh, len(sh), equation))
213def _apply_transpose_reshape(op, row):
214 """
215 Put all dimensions in the same order.
217 :param op: integer (for one input) or an operator
218 :param row: letter involved in this input (as a vector of binaries)
219 :return: last created operator
220 """
221 axes = []
222 p = 0
223 perm = []
224 for i, r in enumerate(row):
225 if r == -1:
226 axes.append((p, i))
227 else:
228 p += 1
229 perm.append((r, i))
230 op = EinsumSubOp(len(row), 'expand_dims', op, axes=tuple(axes))
231 yield op
232 perm.sort()
233 p = 0
234 new_perm = numpy.arange(len(row))
235 for i, r in enumerate(row):
236 if r == -1:
237 continue
238 new_perm[perm[p][1]] = i
239 p += 1
240 if not is_transpose_identity(new_perm):
241 op = EinsumSubOp(len(row), 'transpose', op, perm=tuple(new_perm))
242 yield op
245def _apply_squeeze_transpose(op, row_last, row_output):
246 """
247 Puts output dimension in the expected order.
248 """
249 perm = []
250 sq = []
251 for i, d in enumerate(row_output):
252 if d == -1:
253 sq.append(i)
254 else:
255 perm.append((d, i))
256 perm.sort()
257 new_perm = numpy.arange(len(row_last))
258 p = 0
259 for i, d in enumerate(row_output):
260 if d == -1:
261 continue
262 new_perm[i] = perm[p][1]
263 p += 1
264 perm = [p[1] for p in perm]
265 if not is_transpose_identity(new_perm):
266 op = EinsumSubOp(len(row_last), 'transpose', op,
267 perm=tuple(new_perm))
268 yield op
269 if len(sq) > 0:
270 op = EinsumSubOp(len(row_last), 'squeeze', op, axes=tuple(sq))
271 yield op
def _apply_einsum_matmul(fd, op1, op2, axes, left, right, ndim,
                         op_matmul, row1, row2, verbose=False):
    """
    Decomposes the generic matrix multiplication into numpy operations
    depending on the operator to use for matrix multiplication
    *op_matmul* (see @see fn decompose_einsum_equation).

    :param fd: first argument given to every created @see cl EinsumSubOp
    :param op1: left operand
    :param op2: right operand
    :param axes: dimensions summed by the multiplication
    :param left: dimensions of the left operand which are not summed
    :param right: dimensions of the right operand which are not summed
    :param ndim: number of dimensions of the aligned operands
    :param op_matmul: ``'matmul'`` creates a single *matmul* operator,
        ``'batch_dot'`` and ``'dot'`` both go through the *mul* /
        *batch_dot* decomposition below
    :param row1: row of the left operand, entries >= 0 mark the
        dimensions it has
    :param row2: row of the right operand, entries >= 0 mark the
        dimensions it has
    :param verbose: prints details about the decomposition
    :return: generator of the created operators, the last one yielded
        is the result of the multiplication
    :raises ValueError: if *op_matmul* is unknown
    :raises NotImplementedError: if *axes* has dimensions in common
        with *left* or *right* (non *matmul* operators only)
    """
    allowed = {'matmul', 'batch_dot', 'dot'}
    if op_matmul not in allowed:
        raise ValueError(  # pragma: no cover
            f"Unknown operator op_matmul={op_matmul!r} not in {allowed!r}.")
    if op_matmul == 'matmul':
        # A single generic operator handles the whole multiplication.
        if verbose:  # pragma: no cover
            print(
                f" -- MATMUL -> matmul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(fd, 'matmul', op1, op2,
                          axes=axes, left=left, right=right, ndim=ndim)

    elif len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Nothing is summed and no dimension is kept on both sides:
        # a single *mul* operator is enough.
        if verbose:  # pragma: no cover
            print(
                f" -- MATMUL -> mul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(fd, 'mul', op1, op2)

    elif (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between axes and left or right:
        # matrix multiplication
        if verbose:  # pragma: no cover
            print(" -- MATMUL -> batch_dot axes=%r left=%r right=%r"
                  "" % (axes, left, right))

        # Batch dimensions: kept on both sides, or used by no one.
        all_axes = set(left) | set(right) | set(axes)
        common_axes = list(set(left) & set(right))
        for i in range(ndim):
            if i not in all_axes:
                common_axes.append(i)
        common_axes.sort()

        # ReduceSum*
        # Dimensions op1 has, which only the right side keeps
        # (neither in left nor summed), are summed out of op1 first.
        has_dim = set(i for i in range(len(row1)) if row1[i] >= 0)
        right_no_left = (set(right) & has_dim) - \
            (set(right) & (set(left) | set(axes)))
        if right_no_left:
            if verbose:  # pragma: no cover
                print(
                    f' -- MATMUL reduce1 has_dim={has_dim!r} axes={right_no_left!r}')
            op1 = EinsumSubOp(fd, 'reduce_sum_mm', op1, op2,
                              axes=tuple(sorted(right_no_left)))
            yield op1

        # Symmetric case: dimensions op2 has, which only the left side
        # keeps, are summed out of op2.
        has_dim = set(i for i in range(len(row2)) if row2[i] >= 0)
        left_no_right = (set(left) & has_dim) - \
            (set(left) & (set(right) | set(axes)))
        if left_no_right:
            if verbose:  # pragma: no cover
                print(
                    f' -- MATMUL reduce2 has_dim={has_dim!r} axes={left_no_right!r}')
            op2 = EinsumSubOp(fd, 'reduce_sum', op2,
                              axes=tuple(sorted(left_no_right)))
            yield op2

        # Transpose
        # Sort key: -1 for batch axes, 0 for kept axes, 1 for summed
        # axes, so batch axes come first and summed axes come last.
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(ndim)]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        if not is_transpose_identity(perm):
            op1 = EinsumSubOp(fd, 'transpose_mm', op1, op2, perm=tuple(perm))
            yield op1
            op2 = EinsumSubOp(fd, 'transpose', op2, perm=tuple(perm))
            yield op2

        # Reshape
        # After the transposition, batch axes are the first
        # len(common_axes) ones and summed axes the last len(axes) ones.
        all_axes = list(range(0, ndim))
        new_axes = all_axes[-len(axes):] if len(axes) > 0 else []
        new_common_axes = all_axes[:len(common_axes)]
        not_in_both = []
        for i in range(0, ndim):
            if i not in left and i not in right and i not in common_axes:
                not_in_both.append(i)

        op = EinsumSubOp(fd, 'batch_dot', op1, op2,
                         batch_axes=tuple(new_common_axes),
                         keep_axes=None, sum_axes=tuple(new_axes),
                         left=tuple(perm_left), right=tuple(perm_right),
                         ndim=ndim)
        yield op

        # Transpose again
        # NOTE(review): assumes *batch_dot* lays out its result as batch
        # axes, left-only axes, right-only axes, then the others;
        # rev_perm is the inverse of that order — TODO confirm.
        ordered_axes = (common_axes +
                        list(i for i in left if i not in right) +
                        list(i for i in right if i not in left) +
                        not_in_both)
        rev_perm = [(a, i) for i, a in enumerate(ordered_axes)]
        rev_perm.sort()
        rev_perm = [p[1] for p in rev_perm]

        if not is_transpose_identity(rev_perm):
            # Yielded but not used as the input of the next operator.
            op_unused = EinsumSubOp(fd, 'transpose_mm', op1,
                                    op, perm=tuple(rev_perm))
            yield op_unused
            op = EinsumSubOp(fd, 'transpose', op, perm=tuple(rev_perm))
            yield op
    else:
        raise NotImplementedError(  # pragma: no cover
            "axes and right or left have axes in common, "
            "axes=%r left=%r right=%r ndim=%r." % (
                axes, left, right, ndim))
def _decompose_einsum_equation_simple(equation, *shapes, verbose=False,
                                      op_matmul='matmul'):
    """
    Applies strategy `simple`, `numpy`
    defined by function @see fn decompose_einsum_equation.

    Inputs are processed one after another. Every input is aligned on
    all the letters of the equation, dimensions no longer needed are
    summed, and the result is multiplied with the result accumulated
    so far. The last steps reduce, reorder and squeeze the dimensions
    to match the output.

    :param equation: einsum equation
    :param shapes: input shapes. If empty, every dimension is set to 2;
        here they are only used to check the number of inputs and of
        dimensions (see @see fn _basic_verification) and for printing
    :param verbose: prints every step
    :param op_matmul: which operator to use for matrix multiplication,
        a single operator *matmul*, or *batch_dot* with *transposes*,
        *reduce_sum*, or just *dot*
    :return: instance of @see cl GraphEinsumSubOp
    """
    letters, mat, lengths, duplicates = analyse_einsum_equation(equation)
    if len(letters) != mat.shape[1]:
        raise RuntimeError(  # pragma: no cover
            f"Unexpected number of letters {letters!r}, shape={mat.shape!r}.")
    if len(shapes) == 0:
        shapes = [(2, ) * le for le in lengths[:-1]]
    _basic_verification(lengths, shapes, equation)

    # last_row, current_row (row = shape)
    # rows[0]: row of the result accumulated so far,
    # rows[1]: row of the input being processed; -1 marks a missing letter.
    rows = numpy.full((2, mat.shape[1]), -1)
    graph = GraphEinsumSubOp(letters, mat, lengths, duplicates)
    fd = mat.shape[1]
    if verbose:
        print(f"EQUATION={equation!r}")
        print(f"LETTERS={letters!r}", f"LENGTHS={lengths!r}")
        print(f"DUPLICATES={duplicates!r}")

    for i, sh in enumerate(shapes):
        if verbose:
            print()
            print("######### ROW %d shape=%r row=%r" % (i, sh, rows[1, :]))
        graph.append(i)

        # Input matrix aligned to the same dimensions.
        op = EinsumSubOp(fd, 'id', i)
        op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
        marked = graph.append(op)

        duplicate = duplicates[i]
        if duplicate is not None:
            # Diagonal
            # One entry per repeated letter: (first position, all positions).
            diag = []
            for _, v in duplicate.items():
                if len(v) == 1:
                    continue
                diag.append((v[0], tuple(v)))
            op = EinsumSubOp(fd, 'diagonal', op, diag=diag)
            op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
            # NOTE(review): tr_row is a numpy view on rows[1, :]; the loop
            # below passes rows[1, :] to compute_output_row while the
            # generator still reads tr_row — confirm this is intended.
            tr_row = rows[1, :]
            marked = graph.append(op)
        else:
            diag = None
            tr_row = mat[i]

        for op in _apply_transpose_reshape(op, tr_row):
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        # Reduction? (a dimension not used later)
        # mat[i + 1:] covers the next inputs and the output row: the
        # letter is in the current input, not in the accumulated result,
        # and no later term uses it, so it can be summed right now.
        red = []
        for d in range(0, mat.shape[1]):
            if (mat[i + 1:, d].max() == -1 and rows[1, d] != -1 and
                    rows[0, d] == -1):
                red.append(d)
        if len(red) > 0:
            if verbose:
                print(" -- REDUCE1 row=%d axes=%r" % (i, red))
                print(mat)
                print(' -')
                print(rows)
            op = EinsumSubOp(fd, 'reduce_sum',
                             graph.last_added_op, axes=tuple(red))
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        if graph.last_op is not None:
            # Matrix multiplication?
            # A letter present in both rows is kept on both sides if a
            # later term uses it, otherwise it is summed (common_dims).
            # A letter present in one row only goes to that side.
            common_dims = []
            left = []
            right = []
            for d in range(0, mat.shape[1]):
                if rows[:, d].min() >= 0:
                    if mat[i + 1:, d].max() >= 0:
                        left.append(d)
                        right.append(d)
                    else:
                        common_dims.append(d)
                else:
                    if rows[0, d] >= 0:
                        left.append(d)
                    if rows[1, d] >= 0:
                        right.append(d)
            if verbose:
                print(f" -- MATMUL common_dims={common_dims!r}")
                print(rows)
            for iop in _apply_einsum_matmul(
                    fd, graph.last_op, op, axes=tuple(common_dims),
                    left=tuple(left), right=tuple(right),
                    ndim=rows.shape[1], op_matmul=op_matmul,
                    row1=rows[0, :], row2=rows[1, :], verbose=verbose):
                op = iop
                op.compute_output_row(rows[0, :], rows[1, :],
                                      ab=True, verbose=verbose)
                marked = graph.append(op)

        # End
        # The current row becomes the row of the accumulated result.
        graph.mark(i, marked)
        rows[0, :] = rows[1, :]

    # Final output
    if verbose:
        print()
        print(f"######### FIN row={rows[1, :]!r}")

    # The output row always has a letter: analyse_einsum_equation rejects
    # an empty output.
    if mat[len(shapes), :].max() >= 0:
        rows[1, :] = mat[len(shapes), :]
        red = []
        for d in range(0, mat.shape[1]):
            # NOTE(review): ``> 0`` skips a dimension whose value is 0 in
            # rows[0] while the error branch below tests ``== -1``;
            # confirm whether ``>= 0`` was intended.
            if rows[0, d] > 0 and rows[1, d] == -1:
                red.append(d)
            elif rows[0, d] == -1 and rows[1, d] >= 0:
                raise RuntimeError(  # pragma: no cover
                    "Issue in equation %r, variable %d, last_result is %r, "
                    "output is %r." % (equation, d, rows[0, :], rows[1, :]))
        if len(red) > 0:
            if verbose:  # pragma: no cover
                print(f"-- REDUCE2 axes={red!r}")
                print(mat)
            op = EinsumSubOp(fd, 'reduce_sum', op, axes=tuple(red))
            graph.append(op)
            op.compute_output_row(rows[1, :], verbose=verbose)

        # Removes empty axes.
        for op in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
            op.compute_output_row(rows[1, :], verbose=verbose)
            graph.append(op)
    return graph