Coverage for mlprodict/testing/einsum/einsum_impl_classes.py: 96% (889 statements)
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
# pylint: disable=C0302
"""
@file
@brief Classes representing the sequence of matrix operations to
implement einsum computation.
"""
import numpy
from onnx import helper, numpy_helper
from ...onnx_tools.onnx2py_helper import guess_proto_dtype
from ...npy.xop_variable import guess_numpy_type
from ... import __max_supported_opset__, get_ir_version
from .blas_lapack import gemm_dot
from .einsum_impl_ext import (
    numpy_extended_dot, numpy_diagonal,
    _numpy_extended_dot_equation,
    numpy_extended_dot_python,
    numpy_extended_dot_matrix)

def single_axes(axes):
    """
    Each value in *axes* is either the position of the axis
    in the original matrix (a positive value) or -1, meaning
    this axis is a single dimension added to align all the
    dimensions based on the einsum equation.

    :param axes: axes described above
    :return: list of integers in set `{1, 2}`, 1 for
        a single axis, 2 otherwise
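
    A minimal sketch (illustrative)::

        >>> single_axes([0, -1, 2])
        [2, 1, 2]
        >>> single_axes(None) is None
        True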
30 """
31 if axes is None:
32 return axes
33 return [(1 if a == -1 else 2) for a in axes]

class EinsumSubOp:
    """
    Defines a sub operation used in Einsum decomposition.

    :param name: name (expand_dims, transpose, reduce_sum, matmul, id,
        squeeze, diagonal, mul, batch_dot, transpose_mm, reduce_sum_mm)
    :param inputs: inputs
    :param kwargs: arguments

    Operators suffixed by `_mm` (*transpose_mm*, *reduce_sum_mm*)
    are equivalent to the same operator without the suffix
    but take two inputs and only change the first one.

    Attribute `_info` summarizes the known information
    about dimensions: value `1` marks a single dimension inserted
    to align shapes, value `2` a plain dimension.
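
    A hedged usage sketch (hand-built arguments, not taken from the
    library's tests)::

        import numpy
        op = EinsumSubOp(2, 'transpose', 0, perm=(1, 0))
        res = op.apply({0: numpy.arange(6).reshape((2, 3))})
        # res.shape == (3, 2)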
52 """
53 _allowed = {'expand_dims', 'transpose', 'reduce_sum', 'matmul', 'id',
54 'squeeze', 'diagonal', 'mul', 'batch_dot',
55 'transpose_mm', 'reduce_sum_mm'}

    def __init__(self, full_dim, name, *inputs, **kwargs):
        self.full_dim = full_dim
        self.name = name
        self.inputs = inputs
        self.kwargs = kwargs
        self._info = {}
        if name not in EinsumSubOp._allowed:
            raise ValueError(
                f"Unexpected name {name!r}. It should be in {EinsumSubOp._allowed!r}.")
        if len(inputs) not in (1, 2):
            raise RuntimeError(
                f"Inputs must contain 1 or 2 inputs not {len(inputs)}.")
        if name == 'matmul' and len(inputs) != 2:
            raise RuntimeError(
                "Inputs must contain 2 inputs not %d for operator 'matmul'."
                "" % len(inputs))
        for i, inp in enumerate(inputs):
            if not isinstance(inp, (int, EinsumSubOp)):
                raise TypeError(
                    "Input %d has type %r, int or EinsumSubOp is expected."
                    "" % (i, type(inp)))
        self._check_()

    def _check_(self):
        if self.name == 'transpose':
            self._check_arg_('perm', tuple)
            perm = self.kwargs['perm']
            if len(perm) != len(set(perm)):
                raise RuntimeError(  # pragma: no cover
                    f"perm has duplicated values {perm!r} (name={self.name!r}).")
            if list(perm) == list(range(len(perm))):
                raise ValueError(  # pragma: no cover
                    f"Transpose = identity perm={perm}. It must be removed.")
        elif self.name == 'matmul':
            self._check_arg_('axes', tuple)
            self._check_arg_('left', tuple)
            self._check_arg_('right', tuple)
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            for a in axes:
                if a in left and a in right:
                    raise RuntimeError(  # pragma: no cover
                        "One axis belongs to every set (axes, left, right). "
                        "axes=%r, left=%r, right=%r." % (axes, left, right))

    def __repr__(self):
        inps = ", ".join(map(str, self.inputs))
        kw = ", ".join(f"{k}={w!r}" for k, w in self.kwargs.items())
        m = f"{self.__class__.__name__}({self.name!r}, {inps}, {kw})"
        return m

    def dot_label(self):
        """
        Displays some information useful to understand the operator.
        """
        if self.name == "matmul":
            ndim = self.kwargs['ndim']
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            eq = _numpy_extended_dot_equation(ndim, ndim, axes, left, right)
            eq = eq.replace(">", "\\\\>")
            return "~" + eq
        return None

    def _check_arg_(self, name, typ, empty=False):
        if name not in self.kwargs:
            raise RuntimeError(  # pragma: no cover
                f"Parameter {name!r} not found for operator {self.name!r}.")
        if empty and self.kwargs[name] is None:
            return
        if not isinstance(self.kwargs[name], typ):
            raise TypeError(  # pragma: no cover
                "Unexpected type %r for parameter %r of operator %r."
                "" % (type(self.kwargs[name]), name, self.name))

    def _check_row_(self, row, inp=False, verbose=False):
        """
        Checks input or output is valid.
        """
        if verbose:
            if inp:
                print('<<' if inp else '>>', self.name, row, self.kwargs)
            else:
                print('<<' if inp else '>>', self.name, row)

    def _compute_output_row_id(self, row, row2=None, ab=False, verbose=False):
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        row[:] = row2[:]
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_transpose(self, row, row2=None, ab=False, verbose=False):
        if ab:
            self._compute_output_row_transpose(row2, verbose=verbose)
            return
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('perm', tuple)
        if len(self.kwargs['perm']) != len(row):
            raise RuntimeError(  # pragma: no cover
                f"Unexpected permutation {self.kwargs['perm']!r} (row={row!r}).")
        perm = self.kwargs['perm']
        cpy = row.copy()
        for i, p in enumerate(perm):
            row[i] = cpy[p]
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_transpose_mm(self, row, row2=None, ab=False, verbose=False):
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "transpose_mm expects a second input.")
        self._compute_output_row_transpose(row, row2=None, verbose=verbose)

    def _compute_output_row_expand_dims(self, row, row2=None, ab=False, verbose=False):
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('axes', tuple)
        axes = self.kwargs['axes']
        for axis in axes:
            if not isinstance(axis, tuple):
                raise TypeError(  # pragma: no cover
                    "Parameter axes of expand_dims should be a tuple of "
                    "tuple, axes=%r." % axes)
            if row[axis[1]] != -1:
                raise RuntimeError(  # pragma: no cover
                    "Dimension should be -1 in row %r axis=%r." % (
                        row, axis))
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_reduce_sum(self, row, row2=None, ab=False, verbose=False):
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('axes', tuple)
        for a in self.kwargs['axes']:
            row[a] = -1
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_reduce_sum_mm(self, row, row2=None, ab=False, verbose=False):
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row2, True, verbose=verbose)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "reduce_sum_mm expects a second input.")
        self._compute_output_row_reduce_sum(row, row2=None, verbose=verbose)

    def _compute_output_row_squeeze(self, row, row2=None, ab=False, verbose=False):
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('axes', tuple)
        for a in self.kwargs['axes']:
            row[a] = -1
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_diagonal(self, row, row2=None, ab=False, verbose=False):
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('diag', list)
        to_remove = []
        for choice, choices in self.kwargs['diag']:
            for ch in choices:
                if ch != choice:
                    to_remove.append(ch)
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] in choices:
                    if row[i] != choice:
                        row[i] = choice
        to_remove.sort()
        for r in to_remove:
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] == r:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected result r=%r row=%r to_remove=%r "
                        "diag=%r." % (
                            r, row, to_remove, self.kwargs['diag']))
                if row[i] > r:
                    row[i] -= 1
        self._check_row_(row, verbose=verbose)

    def _compute_output_row_matmul(self, row, row2=None, ab=False, verbose=False):
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        self._check_arg_('axes', tuple)
        self._check_arg_('left', tuple)
        self._check_arg_('right', tuple)
        self._check_arg_('ndim', int)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "matmul expects two inputs.")
        if verbose:
            ndim = self.kwargs['ndim']
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            print(" MATMUL %r @ %r axes=%r left=%r right=%r - eq=%s" % (
                row, row2, axes, left, right,
                _numpy_extended_dot_equation(ndim, ndim, axes, left, right)))
        row2[:] = numpy.maximum(row, row2)
        for a in self.kwargs['axes']:
            if a not in self.kwargs['right']:
                row2[a] = -1
        self._check_row_(row2, verbose=verbose)

    def _compute_output_row_batch_dot(self, row, row2=None, ab=False, verbose=False):
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        self._check_arg_('batch_axes', tuple)
        self._check_arg_('keep_axes', tuple, empty=True)
        self._check_arg_('sum_axes', tuple)
        self._check_arg_('left', tuple)
        self._check_arg_('right', tuple)
        self._check_arg_('ndim', int)
        if row2 is None:
            raise RuntimeError(
                "batch_dot expects two inputs.")  # pragma: no cover
        if verbose:
            batch_axes = self.kwargs['batch_axes']
            keep_axes = self.kwargs['keep_axes']
            sum_axes = self.kwargs['sum_axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            ndim = self.kwargs['ndim']
            print(" BATCH_DOT batch_axes=%r keep_axes=%r sum_axes=%r "
                  "left=%r right=%r eq=%r" % (
                      batch_axes, keep_axes, sum_axes, left, right,
                      _numpy_extended_dot_equation(ndim, ndim, sum_axes, left, right)))
        row2[:] = numpy.maximum(row, row2)
        for a in self.kwargs['sum_axes']:
            if a not in self.kwargs['right']:
                row2[a] = -1
        self._check_row_(row2, verbose=verbose)

    def _compute_output_row_mul(self, row, row2=None, ab=False, verbose=False):
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        if row2 is None:
            raise RuntimeError("mul expects two inputs.")  # pragma: no cover
        if verbose:
            print(  # pragma: no cover
                f" MUL {row!r} @ {row2!r}")
        row2[:] = numpy.maximum(row, row2)
        self._check_row_(row2, verbose=verbose)

    def compute_output_row(self, row, row2=None, ab=False, verbose=False):
        """
        Updates *row* based on the operator.
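
        A short sketch (assuming *row* is a numpy array of axis positions)::

            import numpy
            op = EinsumSubOp(2, 'transpose', 0, perm=(1, 0))
            row = numpy.array([0, 1])
            op.compute_output_row(row)
            # row is now array([1, 0])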
318 """
319 method_name = f"_compute_output_row_{self.name}"
320 meth = getattr(self, method_name, None)
321 if meth is None:
322 raise NotImplementedError( # pragma: no cover
323 f"compute_output_row not implemented for {self.name!r}.")
324 if verbose and ab:
325 print(" -- called as a binary operator")
326 self.add_info(i_row=single_axes(row), i_row2=single_axes(row2))
327 meth(row, row2=row2, ab=ab, verbose=verbose)
328 self.add_info(o_row=single_axes(row), o_row2=single_axes(row2))

    def add_info(self, **kwargs):
        """
        Adds information to the node.

        :param kwargs: dictionary
        """
        for k, v in kwargs.items():
            if k in self._info:
                raise KeyError(  # pragma: no cover
                    f"Key {k!r} already added (operator {self.name!r}).")
            self._info[k] = v

    def _check_inputs_(self, n_expected, check_dim=False):
        if len(self.inputs) != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Number of inputs must be %d not %d for operator %r."
                "" % (n_expected, len(self.inputs), self.name))

    def _check_shape_(self, m):
        if len(m.shape) != self.full_dim:
            raise RuntimeError(  # pragma: no cover
                "Number of dimensions %r is different from expected value "
                "%d." % (m.shape, self.full_dim))

    def _get_data(self, data, key):
        if isinstance(key, int):
            if key not in data:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find key %d in %r." % (
                        key, list(sorted(data))))
            return data[key]
        if isinstance(key, EinsumSubOp):
            if id(key) not in data:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find key %d in %r." % (
                        id(key), list(sorted(data))))
            return data[id(key)]
        raise TypeError(  # pragma: no cover
            f"Unexpected input type {type(key)!r}.")

    def _apply_id(self, data, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        output = self._get_data(data, inp)
        return output

    def _apply_diagonal(self, data, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        if verbose:
            print(  # pragma: no cover
                f"- {self.name}, shape={m.shape!r} diag={self.kwargs['diag']!r}")
        diag = self.kwargs['diag']
        if len(diag) != 1:
            raise NotImplementedError(  # pragma: no cover
                f"Not implemented with more than one duplicated index {diag!r}.")
        diag0 = diag[0]
        output = numpy_diagonal(m, axis=diag0[0], axes=diag0[1])
        return output

    def _apply_expand_dims(self, data, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        if verbose:
            print(
                f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
        output = m
        for axis in reversed(self.kwargs['axes']):
            output = numpy.expand_dims(output, axis[0])
        return output

    def _apply_transpose(self, data, verbose=False, **kwargs):
        self._check_inputs_(1, True)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        self._check_shape_(m)
        if verbose:
            print(
                f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}")
        output = numpy.transpose(m, self.kwargs['perm'])
        self._check_shape_(output)
        return output

    def _apply_transpose_mm(self, data, verbose=False, **kwargs):
        self._check_inputs_(2, True)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        self._check_shape_(m)
        if verbose:
            print(  # pragma: no cover
                f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}")
        output = numpy.transpose(m, self.kwargs['perm'])
        self._check_shape_(output)
        return output

    def _apply_matmul(self, data, verbose=False, **kwargs):
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        axes = self.kwargs['axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r axes=%r left=%r right=%r" % (
                self.name, m1.shape, m2.shape, axes, left, right))

        impl = kwargs.get('matmul_impl', None)
        if impl == 'pyf':
            output = numpy_extended_dot_matrix(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl == 'py':
            output = numpy_extended_dot_python(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl is None:
            output = numpy_extended_dot(m1, m2, axes, left, right,
                                        verbose=verbose)
        else:
            raise ValueError(
                f"Unknown implementation of numpy_extended_dot ({impl}).")
        self._check_shape_(output)
        return output

    def _apply_mul(self, data, verbose=False, **kwargs):
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)

        if verbose:
            print(  # pragma: no cover
                f"- {self.name}, shapes={m1.shape!r} @ {m2.shape!r}")

        output = m1 * m2
        self._check_shape_(output)
        return output

    def _apply_batch_dot(self, data, verbose=False, **kwargs):
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r batch_axes=%r keep_axes=%r "
                  "sum_axes=%r" % (
                      self.name, m1.shape, m2.shape, batch_axes, keep_axes, sum_axes))

        if len(m1.shape) != len(m2.shape):
            raise RuntimeError(  # pragma: no cover
                "batch_dot only works with two tensors with the same number "
                "of dimensions not %r @ %r." % (m1.shape, m2.shape))

        dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        dimb = int(-1 if keep_axes is None else numpy.prod(
            [m1.shape[i] for i in keep_axes]))
        dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if verbose:
            print(f"- {self.name}, reshape={m1.shape!r} into {dim0, dimb, dim1!r}")
            print(f"- {self.name}, reshape={m2.shape!r} into {dim0b, dimb, dim2!r}")
        m1sh = m1.reshape((dim0, dimb, dim1))
        m2sh = m2.reshape((dim0b, dimb, dim2))

        batch_kind = self.get_dot_kind()
        if batch_kind in ('11', 'N1', '1N'):
            m1sh = m1sh.reshape((-1, m1sh.shape[-1]))
            m2sh = m2sh.reshape((-1, m2sh.shape[-1]))
            if verbose:
                print("- %s, use gemm with shape %r, %r" % (
                    self.name, m1sh.shape, m2sh.shape))
            dot = gemm_dot(m1sh, m2sh, False, True)
        else:
            dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))

        # new shape
        new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
                     [m1.shape[i] for i in left if i not in batch_axes] +
                     [m2.shape[i] for i in right if i not in batch_axes])
        while len(new_shape) < len(m1.shape):
            new_shape.append(1)

        if verbose:
            taken = set(batch_axes) | set(sum_axes)
            ax = [i for i in range(len(m1.shape)) if i not in taken]
            print("- %s, shapes=%r @ %r -> %r" % (
                self.name, m1sh.shape, m2sh.shape, dot.shape))
            print("- %s, batch_axes=%r ax=%r new_shape=%r left=%r right=%r" % (
                self.name, batch_axes, ax, new_shape, left, right))

        output = dot.reshape(tuple(new_shape))
        self._check_shape_(output)
        return output

    def _apply_reduce_sum(self, data, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        self._check_shape_(m)
        axes = self.kwargs['axes']
        if verbose:
            print(
                f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
        output = numpy.sum(m, axis=axes, keepdims=True)
        self._check_shape_(output)
        return output

    def _apply_reduce_sum_mm(self, data, verbose=False, **kwargs):
        self._check_inputs_(2, True)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        self._check_shape_(m)
        if verbose:
            print(
                f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
        output = numpy.sum(m, self.kwargs['axes'])
        self._check_shape_(output)
        return output

    def _apply_squeeze(self, data, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        axes = self.kwargs['axes']
        if verbose:
            print(
                f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
        output = m
        for a in axes[::-1]:
            output = numpy.squeeze(output, axis=a)
        return output

    def apply(self, data, verbose=False, **kwargs):
        """
        Applies one operator on the data.

        :param data: dictionary storing the results
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters, see
            methods `_apply*`
        :return: output

        Known additional parameters:

        * 'matmul_impl': if None, calls :epkg:`numpy:einsum` through
          @see fn numpy_extended_dot (default); 'py' calls
          @see fn numpy_extended_dot_python and 'pyf' calls
          @see fn numpy_extended_dot_matrix instead.
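
        A minimal sketch (inputs assumed)::

            import numpy
            data = {0: numpy.arange(6).reshape((2, 3))}
            op = EinsumSubOp(2, 'reduce_sum', 0, axes=(1,))
            res = op.apply(data)
            # res.shape == (2, 1), also stored as data[id(op)]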
596 """
597 if verbose:
598 print()
599 print("apply %r (%s)." % (
600 self.name, ", ".join(map(lambda s: str(id(s)), self.inputs))))
602 method_name = f"_apply_{self.name}"
603 meth = getattr(self, method_name, None)
604 if meth is None:
605 raise NotImplementedError( # pragma: no cover
606 f"apply not implemented for {self.name!r}.")
607 output = meth(data, verbose, **kwargs)
609 data[id(self)] = output
610 if verbose:
611 print("+ %s, shape=%r -- %d" % (self.name, output.shape, id(self)))
612 return output

    def _onnx_name(self):
        return 'einsum%d_%s' % (id(self), self.name[:2])

    def _check_onnx_opset_(self, opset, limit):
        if opset is not None and opset < limit:
            raise RuntimeError(  # pragma: no cover
                f"Opset ({opset!r}) must be >= {limit!r} for operator {self.name!r}.")

    def _to_onnx_id(self, names, opset, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        yield helper.make_node('Identity', [name], [self._onnx_name()])

    def _to_onnx_expand_dims(self, names, opset, verbose=False, **kwargs):
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = name + '_axes'
        yield numpy_helper.from_array(
            numpy.array([a[1] for a in axes], dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, [a[1] for a in axes]))
        yield helper.make_node(
            'Unsqueeze', [name, name_axes], [self._onnx_name()],
            name='Unsqueeze%s_%d' % (s_axes, id(self)))

    def _to_onnx_squeeze(self, names, opset, verbose=False, **kwargs):
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = name + '_axes'
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'Squeeze', [name, name_axes], [self._onnx_name()],
            name='Squeeze%s_%d' % (s_axes, id(self)))

    def _to_onnx_transpose(self, names, opset, verbose=False, **kwargs):
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        perm = self.kwargs['perm']
        s_perm = "".join(map(str, perm))
        yield helper.make_node(
            'Transpose', [name], [self._onnx_name()], perm=perm,
            name='Transpose%s_%d' % (s_perm, id(self)))

    def _to_onnx_reduce_sum(self, names, opset, verbose=False, **kwargs):
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = self._onnx_name() + '_axes'
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'ReduceSum', [name, name_axes], [self._onnx_name()], keepdims=1,
            name='ReduceSum%s_%d' % (s_axes, id(self)))

    def _to_onnx_mul(self, data, verbose=False, **kwargs):
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        yield helper.make_node('Mul', [m1, m2], [self._onnx_name()])

    def _to_onnx_batch_dot(self, names, opset, verbose=False, **kwargs):  # pylint: disable=R0914
        self._check_inputs_(2)
        self._check_onnx_opset_(opset, 13)
        inp1, inp2 = self.inputs[:2]  # pylint: disable=W0632
        name1 = self._get_data(names, inp1)
        name2 = self._get_data(names, inp2)

        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']
        root = self._onnx_name()

        def return_name_one():
            name_one = root + "_1"
            return name_one, numpy_helper.from_array(
                numpy.array([1], dtype=numpy.int64), name=name_one)

        name_one = None
        name_shape1 = root + "_shape1"
        name_shape2 = root + "_shape2"
        concat_left = []
        concat_right = []
        yield helper.make_node('Shape', [name1], [name_shape1])
        yield helper.make_node('Shape', [name2], [name_shape2])

        if len(batch_axes) > 0:
            name_batch_axes = root + "_batch_axes"
            yield numpy_helper.from_array(
                numpy.array(batch_axes, dtype=numpy.int64), name=name_batch_axes)

        if len(sum_axes) > 0:
            name_sum_axes = root + "_sum_axes"
            yield numpy_helper.from_array(
                numpy.array(sum_axes, dtype=numpy.int64), name=name_sum_axes)

        # dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        # dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        if len(batch_axes) > 1:
            name_dim0 = root + "_dim0"
            name_dim0b = root + "_dim0b"
            name_dim0g = name_dim0 + 'g'
            name_dim0bg = name_dim0b + 'g'
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)
            yield helper.make_node(
                'Gather', [name_shape1, name_batch_axes], [name_dim0g])
            yield helper.make_node(
                'Gather', [name_shape2, name_batch_axes], [name_dim0bg])
            yield helper.make_node(
                'ReduceProd', [name_dim0g], [name_dim0], keepdims=1)
            yield helper.make_node(
                'ReduceProd', [name_dim0bg], [name_dim0b], keepdims=1)
        elif len(batch_axes) == 1:
            name_dim0g = root + "_dim0g"
            name_dim0bg = root + "_dim0bg"
            name_dim0 = name_dim0g
            name_dim0b = name_dim0bg
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)
            yield helper.make_node(
                'Gather', [name_shape1, name_batch_axes], [name_dim0g])
            yield helper.make_node(
                'Gather', [name_shape2, name_batch_axes], [name_dim0bg])
        else:
            if name_one is None:
                name_one, cst_init = return_name_one()
                yield cst_init
            name_dim0 = name_one
            name_dim0b = name_one
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)

        # dimb = int(-1 if keep_axes is None else numpy.prod(
        #     [m1.shape[i] for i in keep_axes]))
        if keep_axes in (-1, None) or len(keep_axes) == 0:
            name_dimb = root + "__1"
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array([-1], dtype=numpy.int64), name=name_dimb)
        elif len(keep_axes) == 1:
            name_keep_axes = root + "_keep_axes"
            name_dimb = root + "_dimb"
            name_dimbg = name_dimb
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes)
            yield helper.make_node(
                'Gather', [name_shape1, name_keep_axes], [name_dimbg])
        else:
            name_keep_axes = root + "_keep_axes"
            name_dimb = root + "_dimb"
            name_dimbg = name_dimb + 'g'
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes)
            yield helper.make_node(
                'Gather', [name_shape1, name_keep_axes], [name_dimbg])
            yield helper.make_node(
                'ReduceProd', [name_dimbg], [name_dimb], keepdims=1)

        # dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        # dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if len(sum_axes) == 0:
            if name_one is None:
                name_one, cst_init = return_name_one()
                yield cst_init
            name_dim1 = name_one
            name_dim2 = name_one
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
        elif len(sum_axes) == 1:
            name_dim1 = root + "_dim1"
            name_dim2 = root + "_dim2"
            name_dim1g = name_dim1
            name_dim2g = name_dim2
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
            yield helper.make_node(
                'Gather', [name_shape1, name_sum_axes], [name_dim1g])
            yield helper.make_node(
                'Gather', [name_shape2, name_sum_axes], [name_dim2g])
        else:
            name_dim1 = root + "_dim1"
            name_dim2 = root + "_dim2"
            name_dim1g = name_dim1 + 'g'
            name_dim2g = name_dim2 + 'g'
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
            yield helper.make_node(
                'Gather', [name_shape1, name_sum_axes], [name_dim1g])
            yield helper.make_node(
                'Gather', [name_shape2, name_sum_axes], [name_dim2g])
            yield helper.make_node(
                'ReduceProd', [name_dim1g], [name_dim1], keepdims=1)
            yield helper.make_node(
                'ReduceProd', [name_dim2g], [name_dim2], keepdims=1)
        batch_kind = self.get_dot_kind()
        if batch_kind in ('11', 'N1', '1N'):
            # *shape1, *shape2
            name_minus_one = root + "__01"
            yield numpy_helper.from_array(
                numpy.array([-1], dtype=numpy.int64), name=name_minus_one)
            name_agg_shape1_2 = root + f"_resh1_{batch_kind}"
            name_agg_shape2_2 = root + f"_resh2_{batch_kind}"
            yield helper.make_node(
                'Concat', [name_minus_one, name_dim1], [name_agg_shape1_2], axis=0)
            yield helper.make_node(
                'Concat', [name_minus_one, name_dim2], [name_agg_shape2_2], axis=0)

            # m1sh = m1.reshape((-1, dim1))
            # m2sh = m2.reshape((-1, dim2))
            name_agg1_2 = root + "_aresh1"
            name_agg2_2 = root + "_aresh2"
            yield helper.make_node('Reshape', [name1, name_agg_shape1_2], [name_agg1_2])
            yield helper.make_node('Reshape', [name2, name_agg_shape2_2], [name_agg2_2])

            # dot = gemm(m1sh, m2sh, False, True)
            name_dot = root + "_gemm"
            yield helper.make_node(
                'Gemm', [name_agg1_2, name_agg2_2], [name_dot],
                alpha=1., beta=0., transA=0, transB=1)
        else:
            # *shape1, *shape2
            name_agg_shape1 = root + "_resh1"
            name_agg_shape2 = root + "_resh2"
            yield helper.make_node(
                'Concat', concat_left, [name_agg_shape1], axis=0)
            yield helper.make_node(
                'Concat', concat_right, [name_agg_shape2], axis=0)

            # m1sh = m1.reshape((dim0, dimb, dim1))
            # m2sh = m2.reshape((dim0b, dimb, dim2))
            name_agg1 = root + "_aresh1"
            name_agg2 = root + "_aresh2"
            yield helper.make_node('Reshape', [name1, name_agg_shape1], [name_agg1])
            yield helper.make_node('Reshape', [name2, name_agg_shape2], [name_agg2])

            # dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))
            name_agg2_tr = root + "_aresh2_tr"
            yield helper.make_node(
                'Transpose', [name_agg2], [name_agg2_tr], perm=[0, 2, 1],
                name=f"Transpose021_{id(self)}")

            name_dot = root + "_dot"
            yield helper.make_node(
                'MatMul', [name_agg1, name_agg2_tr], [name_dot])

        # new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
        #              [m1.shape[i] for i in left if i not in batch_axes] +
        #              [m2.shape[i] for i in right if i not in batch_axes])
        concat_final = []
        if len(batch_axes) > 0:
            name_max_dim = root + "_max_dim"
            concat_final.append(name_max_dim)
            yield helper.make_node(
                'Max', [name_dim0g, name_dim0bg], [name_max_dim])

        left_set = list(sorted(set(left) - (set(batch_axes) & set(left))))
        if len(left_set) > 0:
            name_left_dim = root + "_left_dim"
            name_left_set = root + "_left_set"
            yield numpy_helper.from_array(
                numpy.array(left_set, dtype=numpy.int64), name=name_left_set)
            yield helper.make_node(
                'Gather', [name_shape1, name_left_set], [name_left_dim])
            concat_final.append(name_left_dim)

        right_set = list(sorted(set(right) - (set(batch_axes) & set(right))))
        if len(right_set) > 0:
            name_right_dim = root + "_right_dim"
            name_right_set = root + "_right_set"
            yield numpy_helper.from_array(
                numpy.array(right_set, dtype=numpy.int64), name=name_right_set)
            yield helper.make_node(
                'Gather', [name_shape2, name_right_set], [name_right_dim])
            concat_final.append(name_right_dim)

        name_new_shape = root + '_new_shape'
        diff = (
            self.full_dim -
            (len(batch_axes) + len(left_set) + len(right_set)))
        if diff > 0:
            names_ones = root + "_ones"
            yield numpy_helper.from_array(
                numpy.array([1 for i in range(diff)], dtype=numpy.int64),
                name=names_ones)
            concat_final.append(names_ones)

        yield helper.make_node(
            'Concat', concat_final, [name_new_shape], axis=0)

        name_final = root + '_final'
        yield helper.make_node(
            'Reshape', [name_dot, name_new_shape], [name_final])

    def to_onnx(self, names, opset=None, verbose=False, **kwargs):
        """
        Converts this node into ONNX. Enumerates all the ONNX nodes
        which participate in the conversion. The last one
        is the final output.

        :param names: dictionary where to find already converted names
        :param opset: opset
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters for the conversion
        :return: output
        """
        if opset is None:
            opset = __max_supported_opset__  # pragma: no cover
        if verbose:
            print()
            print("to_onnx %r (%s) opset=%r." % (
                self.name,
                ", ".join(map(lambda s: str(id(s)), self.inputs)),
                opset))

        method_name = f"_to_onnx_{self.name}"
        meth = getattr(self, method_name, None)
        if meth is None:
            if self.name.endswith("_mm"):
                raise NotImplementedError(
                    "to_onnx not implemented for %r. "
                    "You should call method simplify_mm_nodes "
                    "to remove it." % self.name)
            raise NotImplementedError(
                f"to_onnx not implemented for {self.name!r}.")
        for node in meth(names, verbose=verbose, opset=opset, **kwargs):
            if hasattr(node, 'output'):
                names[id(self)] = node.output[0]
                if verbose:
                    print("+ OP %r -- (%s - %d)" %
                          (node.output[0], self.name, id(self)))
            elif verbose:
                # Initializer
                print("+ CT %r -- (%s - %d)" %
                      (node.name, self.name, id(self)))
            yield node

    def get_dot_kind(self):
        """
        Every matrix multiplication can be either:

        * a simple multiplication (`M`) (undetected)
        * a 2D matrix multiplication (`11`)
        * a broadcasted matrix multiplication (`N1` or `1N`)
        * a batch matrix multiplication (`NN`)

        This method returns which kind it is.
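
        For example, a batched dot where both inputs keep a real batch
        dimension is kind ``NN``; if only the left side has one, the
        kind is ``N1``.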
983 """
984 batch_axes = self.kwargs['batch_axes']
985 # keep_axes = self.kwargs['keep_axes']
986 # sum_axes = self.kwargs['sum_axes']
987 # left = self.kwargs['left']
988 # right = self.kwargs['right']
989 info = self._info
990 row_left = info['i_row']
991 row_right = info['i_row2']
993 batch_left = [row_left[k] for k in batch_axes]
994 batch_right = [row_right[k] for k in batch_axes]
995 n_left = len(batch_left) > 0 and max(batch_left) == 2
996 n_right = len(batch_right) > 0 and max(batch_right) == 2
997 return f"{'N' if n_left else '1'}{'N' if n_right else '1'}"

class GraphEinsumSubOp:
    """
    Class gathering all the nodes produced to make einsum
    operators explicit.

    :param letters: list of distinct letters
    :param mat: matrix, see @see fn analyse_einsum_equation
    :param lengths: lengths of every input
    :param duplicates: see @see fn analyse_einsum_equation
    """

    def __init__(self, letters, mat, lengths, duplicates):
        self._nodes = {}
        self._mark = {}
        self._ops = []
        self._inputs = {}
        self.last_op = None
        self.last_added_op = None
        self.metadata = dict(
            letters=letters, mat=mat, lengths=lengths,
            mat0=mat.copy(), duplicates=duplicates)

    def append(self, op):
        """
        Adds one input or result.

        :param op: integer (an input) or an instance of @see cl EinsumSubOp.
        :return: op or None if op is an integer
        """
        if isinstance(op, int):
            if op in self._nodes:
                raise RuntimeError(  # pragma: no cover
                    "Key %d already added." % op)
            self._nodes[op] = op
            self.last_added_op = op
            self._inputs[op] = op
            return None
        if isinstance(op, EinsumSubOp):
            if id(op) in self._nodes:
                raise RuntimeError(  # pragma: no cover
                    "Key %d already added, op=%r." % (id(op), op))
            self._nodes[id(op)] = op
            self._ops.append(op)
            self.last_added_op = op
            return op
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(op)!r}.")

    def mark_last_node(self):
        """
        Marks the last node as the final output.
        """
        if self.last_added_op is None:
            raise RuntimeError("last_added_op is None.")  # pragma: no cover
        self.mark(-1, self.last_added_op)

    def mark(self, i, op):
        """
        Marks one input or result as an intermediate result
        after a full einsum step.

        :param i: input index, or -1 for the final output
        :param op: integer (an input) or an instance of @see cl EinsumSubOp.
        """
        if not isinstance(i, int):
            raise TypeError(  # pragma: no cover
                f"i must be an integer not {type(i)!r}.")
        if i != -1 and i not in self._inputs:
            raise RuntimeError(  # pragma: no cover
                "Input %d was not registered in %r." % (i, self._inputs))
        if isinstance(op, EinsumSubOp):
            if id(op) not in self._nodes:
                raise RuntimeError(  # pragma: no cover
                    "Key %d not found, op=%r." % (id(op), op))
            self._mark[i] = op
            self._mark[id(op)] = i
            self.last_op = op
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(op)!r}.")

    def __iter__(self):
        "Iterates on nodes."
        for op in self._ops:
            yield op

    def to_dot(self, **kwargs):
        """
        Produces a graph in :epkg:`dot`.

        :param kwargs: additional graph options
        :return: string
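
        A hedged sketch (assuming the graph comes from
        ``decompose_einsum_equation`` in this package)::

            from mlprodict.testing.einsum import decompose_einsum_equation
            seq = decompose_einsum_equation("ab,bc->ac")
            print(seq.to_dot())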
1091 """
1092 options = {
1093 'orientation': 'portrait',
1094 'ranksep': '0.25',
1095 'nodesep': '0.05',
1096 'width': '0.5',
1097 'height': '0.1',
1098 'size': '5',
1099 'node': '[shape=record]',
1100 }
1101 options.update(kwargs)
1103 def d2s(d):
1104 it = []
1105 for k, v in sorted(d.items()):
1106 it.append(f"{k}={v}")
1107 return " ".join(it)
1109 def d2sd(d):
1110 it = []
1111 for k, v in sorted(d.items()):
1112 if len(v) > 1:
1113 it.append(f"{k}={','.join(map(str, v))}")
1114 return " ".join(it)
1116 rows = ["digraph{"]
1117 for k, v in options.items():
1118 if isinstance(v, str) and "[" in v:
1119 rows.append(f"{k} {v};")
1120 else:
1121 rows.append(f"{k}={v};")
1122 for k, v in self._nodes.items():
1123 if isinstance(v, int):
1124 let = [(r, self.metadata['letters'][i])
1125 for i, r in enumerate(self.metadata['mat0'][v])
1126 if r != -1]
1127 dup = self.metadata['duplicates'][v]
1128 if dup is None:
1129 dup = ""
1130 else:
1131 dup = f" - {d2sd(dup)}"
1132 let.sort()
1133 letters = "".join(_[1] for _ in let)
1134 lab = "input %d\\\\n%s\\\\n%s%s" % (
1135 v, letters, str(self.metadata['mat0'][v]), dup)
1136 sk = v
1137 extended_lab = ""
1138 else:
1139 lab = f"{v.name}\\\\n{d2s(v.kwargs)}"
1140 sk = id(v)
1141 extended_lab = v.dot_label()
1142 if extended_lab:
1143 extended_lab = "\\\\n" + extended_lab
1145 if sk in self._mark and isinstance(self._mark[sk], int):
1146 la = self._mark[sk]
1147 lab = lab.replace("\\\\n", " - I%d\\\\n" % la)
1148 s = ('%d [label="%s%s" style=filled '
1149 'fillcolor=red];' % (k, lab, extended_lab))
1150 else:
1151 s = '%d [label="%s%s"];' % (k, lab, extended_lab)
1152 rows.append(s)
1153 if not hasattr(v, 'inputs'):
1154 continue
1155 for i in v.inputs:
1156 vid = i if isinstance(i, int) else id(i)
1157 s = "%d -> %d;" % (vid, k)
1158 rows.append(s)
1159 rows.append("}")
1160 return "\n".join(rows)

    def apply_sequence(self, *inputs, verbose=False, **kwargs):
        """
        Applies a sequence of operations on a list of inputs.

        :param inputs: inputs
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters,
            see :meth:`apply
            <mlprodict.testing.einsum.einsum_impl_classes.EinsumSubOp.apply>`.
        :return: output
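
        A hedged sketch (assuming the graph was built with
        ``decompose_einsum_equation`` from this package)::

            import numpy
            from mlprodict.testing.einsum import decompose_einsum_equation
            seq = decompose_einsum_equation("ab,bc->ac")
            m1 = numpy.arange(4).reshape((2, 2)).astype(numpy.float32)
            m2 = numpy.arange(4).reshape((2, 2)).astype(numpy.float32)
            res = seq.apply_sequence(m1, m2)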
1172 """
1173 if verbose:
1174 print('######### apply_sequence')
1175 data = {i: inp for i, inp in enumerate(inputs)}
1176 last = None
1177 for op in self:
1178 last = op.apply(data, verbose=verbose, **kwargs)
1179 if last is None:
1180 raise RuntimeError( # pragma: no cover
1181 "Sequence of operations is empty.")
1182 return last

    def clean_unused_nodes(self, verbose=False):
        """
        Cleans nodes with unused outputs.

        :param verbose: display intermediate information
        """

        def iteration(it):
            # Walks through all nodes.
            is_used = {}
            for node in self._ops:
                if not isinstance(node, EinsumSubOp):
                    continue
                if id(node) not in is_used:
                    is_used[id(node)] = []
                for inp in node.inputs:
                    if not isinstance(inp, EinsumSubOp):
                        continue
                    idn = id(inp)
                    if idn not in is_used:
                        is_used[idn] = []
                    is_used[idn].append(id(node))

            # Removes unused nodes.
            removed = []
            for k, v in is_used.items():
                if len(v) == 0:
                    removed.append(k)
            removed = set(removed)
            i_rem = []
            for i, op in enumerate(self._ops):
                if not isinstance(op, EinsumSubOp):
                    continue
                if id(op) in removed and id(op) not in self._mark:
                    i_rem.append((i, id(op)))
            for i, idn in reversed(i_rem):
                if verbose:
                    print("[GraphEinsumSubOp.clean_nodes] remove node "
                          "i=%d: %d - id=%d" % (it, i, idn))
                del self._ops[i]
                del self._nodes[idn]
            return len(i_rem) > 0

        it = 1
        while iteration(it):
            it += 1

        self.last_op = None
        self.last_added_op = None

    def simplify_mm_nodes(self, verbose=False):
        """
        Nodes whose name is suffixed by `_mm` are an artifact needed
        to keep the graph consistent while building it. They can
        now be replaced by the equivalent node without the suffix.

        :param verbose: display intermediate information
        """
        for op in self:
            if not isinstance(op, EinsumSubOp):
                continue
            if op.name.endswith('_mm'):
                if verbose:
                    print("[GraphEinsumSubOp.simplify_mm_nodes] node %r"
                          " - id=%d" % (op.name, id(op)))
                if len(op.inputs) != 2:
                    raise RuntimeError(  # pragma: no cover
                        "Expecting 2 inputs for node %r not %r id=%r." % (
                            op.name, len(op.inputs), id(op)))
                op.name = op.name[:-3]
                op.inputs = op.inputs[:1]

    def _get_forward_nodes(self):
        """
        Returns the forward nodes.
        """
        forward = {}
        for op in self:
            if isinstance(op, int):
                continue
            for inp in op.inputs:
                key = inp if isinstance(inp, int) else id(inp)
                if key in forward:
                    forward[key].append(op)
                else:
                    forward[key] = [op]
        return forward

    def _pprint_forward(self):
        rows = []
        for op in self:
            line = "%r <- %s(%s)" % (
                id(op), op.name,
                ", ".join(map(str, [id(_) for _ in op.inputs])))
            rows.append(line)
        return "\n".join(rows)

    def _replace_node_sequence(self, added, deleted):
        """
        Removes a sequence of nodes. The method does not check
        that the graph remains consistent.
        """
        forward = self._get_forward_nodes()
        key = id(deleted[-1])
        if key not in forward:
            raise RuntimeError(  # pragma: no cover
                "Key {} missing in all forward nodes (other keys {}), "
                "all keys:\n{}".format(
                    key, [id(_) for _ in deleted],
                    self._pprint_forward()))

        # deletion
        mark_input = None
        for d in deleted:
            del self._nodes[id(d)]
            if id(d) in self._mark:
                del self._mark[id(d)]
                dels = []
                for k, v in self._mark.items():
                    if id(v) == id(d):
                        mark_input = k
                        dels.append(k)
                if len(dels) != 1:
                    raise RuntimeError(  # pragma: no cover
                        "Input %d has more than one marked operator "
                        "(%r)." % (id(d), dels))
                del self._mark[dels[0]]

        dels = set(id(o) for o in deleted)
        rem = []
        for i, op in enumerate(self._ops):
            if id(op) in dels:
                rem.append(i)
        if len(rem) != len(deleted):
            raise RuntimeError(  # pragma: no cover
                f"Mismatched length {rem!r}, {dels!r}, len={len(deleted)!r}.")
        for i in reversed(rem):
            del self._ops[i]
        self.last_added_op = None

        # insertion
        if added is not None:
            self._ops.insert(rem[0], added)
            self._nodes[id(added)] = added
            for op in forward[key]:
                new_inputs = list(op.inputs)
                for i in range(len(op.inputs)):  # pylint: disable=C0200
                    if id(op.inputs[i]) == key:
                        new_inputs[i] = added
                op.inputs = tuple(new_inputs)
            if mark_input is not None:
                self.mark(mark_input, added)
        else:
            inps = deleted[0].inputs
            if len(inps) != 1:
                raise RuntimeError(  # pragma: no cover
                    "More than one input. Call another method.")
            inp = inps[0]
            for op in forward[key]:
                new_inputs = list(op.inputs)
                for i in range(len(op.inputs)):  # pylint: disable=C0200
                    if id(op.inputs[i]) == key:
                        new_inputs[i] = inp
                op.inputs = tuple(new_inputs)
            if mark_input is not None:
                self.mark(mark_input, inp)

    def remove_duplicate_transpose(self, verbose=False):
        """
        Removes consecutive transposes by merging them.

        :param verbose: display intermediate information
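
        For example, two consecutive transposes with ``perm1=(1, 0, 2)``
        followed by ``perm2=(1, 0, 2)`` compose into the identity
        permutation and are both removed.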
1356 """
1357 modif = 1
1358 while modif > 0:
1359 modif = 0
1360 candidates = []
1361 forward = self._get_forward_nodes()
1362 for op in self:
1363 if op.name == "transpose":
1364 inp = op.inputs[0]
1365 if (isinstance(inp, EinsumSubOp) and
1366 inp.name == 'transpose' and
1367 len(forward[id(inp)]) == 1):
1368 candidates.append(op)
1370 if len(candidates) > 0:
1371 modif = 1
1372 # Not efficient to take the first one and to
1373 # start again but the graph should not be too big.
1374 cand = candidates[0]
1375 op2 = cand
1376 op1 = cand.inputs[0]
1377 perm1 = op1.kwargs['perm']
1378 perm2 = op2.kwargs['perm']
1379 if len(perm1) != len(perm2):
1380 raise RuntimeError( # pragma: no cover
1381 "Transposition should have the same length "
1382 "%r, %r." % (perm1, perm2))
1383 perm = list(perm1)
1384 for i in range(len(perm)): # pylint: disable=C0200
1385 perm[i] = perm1[perm2[i]]
1386 if list(range(len(perm))) == perm:
1387 # identity, everything needs to be removed
1388 new_op = None
1389 else:
1390 new_op = op2.__class__(
1391 op2.full_dim, op2.name, op1.inputs[0],
1392 perm=tuple(perm))
1393 self._replace_node_sequence(new_op, [op1, op2])
1394 if verbose:
1395 print( # pragma: no cover
1396 "[GraphEinsumSubOp.remove_duplicate_transpose] remove nodes %r"
1397 " - id=%d,%d + %d perm1=%r perm2=%r -> perm=%r" % (
1398 op2.name, id(op1), id(op2),
1399 id(new_op) if new_op is not None else -1,
1400 perm1, perm2, perm))

    def to_onnx(self, output, *inputs, dtype=None, verbose=False,
                opset=None, **kwargs):
        """
        Converts the graph into ONNX.

        :param output: output name
        :param inputs: input names
        :param dtype: type used for all operators
        :param opset: desired opset, None for the last one
        :param verbose: display intermediate operators
        :param kwargs: additional parameters to use when building
            the ONNX graph, list of supported parameters:
            *name*, *ir_version*, *producer_name*,
            *producer_version*, *initializer*
        :return: ONNX graph

        Not all graphs can be converted into ONNX. Only graphs produced
        with `strategy='numpy'` can be converted; otherwise, the following
        error shows up:

        ::

            NotImplementedError: to_onnx not implemented for 'matmul'.
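
        A minimal conversion sketch (names assumed, equation chosen
        arbitrarily)::

            from mlprodict.testing.einsum import decompose_einsum_equation
            seq = decompose_einsum_equation("ab,bc->ac", strategy="numpy")
            onx = seq.to_onnx("Y", "X1", "X2")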
1425 """
1426 from ...onnx_tools.optim import onnx_remove_node_unused
1428 # inputs
1429 if opset is None:
1430 opset = __max_supported_opset__
1431 if verbose:
1432 print("[GraphEinsumSubOp.to_onnx] %r -> %s opset=%r "
1433 "dtype=%r" % (inputs, output, opset, dtype))
1434 onx_inputs = []
1435 proto = guess_proto_dtype(
1436 numpy.float32 if dtype is None else dtype)
1437 lengths = self.metadata['lengths']
1438 names = {}
1439 for inp, le in zip(inputs, lengths):
1440 if isinstance(inp, tuple):
1441 name, typ = inp
1442 if le != len(typ.shape):
1443 raise ValueError( # pragma: no cover
1444 "Irreconcialable shapes for input %r: "
1445 "%r != len(%r)." % (name, le, typ.shape))
1446 proto = guess_proto_dtype(guess_numpy_type(typ))
1447 onx_inputs.append(
1448 helper.make_tensor_value_info(name, proto, typ.shape))
1449 names[len(names)] = name
1450 else:
1451 onx_inputs.append(
1452 helper.make_tensor_value_info(
1453 inp, proto, [None for i in range(le)]))
1454 names[len(names)] = inp
1456 # output
1457 onx_output = helper.make_tensor_value_info(
1458 output, proto, [None for i in range(lengths[-1])])
1460 # nodes
1461 nodes = []
1462 inits = []
1463 if "initializer" in kwargs:
1464 inits.extend(kwargs['initializer'])
1465 for op in self:
1466 for onx_node in op.to_onnx(names, verbose=verbose, opset=opset):
1467 if hasattr(onx_node, 'output'):
1468 nodes.append(onx_node)
1469 else:
1470 inits.append(onx_node)
1472 # last node
1473 last_node = nodes[-1]
1474 nodes.append(helper.make_node(
1475 'Identity', [last_node.output[0]], [output]))
1477 # Builds the graph
1478 model = helper.make_model(
1479 opset_imports=[helper.make_operatorsetid('', opset)],
1480 ir_version=kwargs.get('ir_version', get_ir_version(opset)),
1481 producer_name=kwargs.get('producer_name', 'mlprodict'),
1482 producer_version=kwargs.get('producer_version', "0.0.dev"),
1483 graph=helper.make_graph(
1484 name=kwargs.get('name', 'einsum'),
1485 inputs=onx_inputs, outputs=[onx_output],
1486 initializer=inits, nodes=nodes))
1488 return onnx_remove_node_unused(model)
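
# ---------------------------------------------------------------------
# End-to-end usage sketch (illustrative, not part of the module): the
# classes above are usually driven by ``decompose_einsum_equation`` from
# ``mlprodict.testing.einsum``, assumed importable here.
#
#     from mlprodict.testing.einsum import decompose_einsum_equation
#     seq = decompose_einsum_equation("bac,cd->ad", strategy="numpy")
#     seq.simplify_mm_nodes()
#     seq.clean_unused_nodes()
#     onx = seq.to_onnx("Y", "X1", "X2")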