Coverage for mlprodict/testing/einsum/einsum_impl_classes.py: 96%

889 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# pylint: disable=C0302 

2""" 

3@file 

4@brief Classes representing the sequence of matrix operations to 

5implement einsum computation. 

6""" 

7import numpy 

8from onnx import helper, numpy_helper 

9from ...onnx_tools.onnx2py_helper import guess_proto_dtype 

10from ...npy.xop_variable import guess_numpy_type 

11from ... import __max_supported_opset__, get_ir_version 

12from .blas_lapack import gemm_dot 

13from .einsum_impl_ext import ( 

14 numpy_extended_dot, numpy_diagonal, 

15 _numpy_extended_dot_equation, 

16 numpy_extended_dot_python, 

17 numpy_extended_dot_matrix) 

18 

19 

def single_axes(axes):
    """
    *axes* contains positive values, then it is the position
    of this axis in the original matrix, otherwise it is -1
    meaning this axis is an added single dimension to align
    all the dimensions based on the einsum equation.

    :param axes: axes described above
    :return: list of integer in set `{1, 2}`, 1 for
        a single axis, 2 otherwise
    """
    if axes is None:
        return axes
    flags = []
    for axis in axes:
        flags.append(1 if axis == -1 else 2)
    return flags

34 

35 

36class EinsumSubOp: 

37 """ 

38 Defines a sub operation used in Einsum decomposition. 

39 

40 :param name: name (reshape, transpose, reduce_sum, matmul, id, 

41 squeeze, diagonal, mul, batch_dot) 

42 :param inputs: inputs 

43 :param kwargs: arguments 

44 

45 Operator suffixed by `_mm` (*transpose_mm*, *reduce_sum_mm*) 

46 are equivalent to the same operator without the suffix 

47 but takes two inputs and only changes the first one. 

48 

49 Attributes `_info` summarizes the known information 

50 about dimensions. Many of them are empty because inserted. 

51 Value `1` means it was the case, `2` means it is a plain dimension. 

52 """ 

53 _allowed = {'expand_dims', 'transpose', 'reduce_sum', 'matmul', 'id', 

54 'squeeze', 'diagonal', 'mul', 'batch_dot', 

55 'transpose_mm', 'reduce_sum_mm'} 

56 

57 def __init__(self, full_dim, name, *inputs, **kwargs): 

58 self.full_dim = full_dim 

59 self.name = name 

60 self.inputs = inputs 

61 self.kwargs = kwargs 

62 self._info = {} 

63 if name not in EinsumSubOp._allowed: 

64 raise ValueError( 

65 f"Unexpected name {name!r}. It should be in {EinsumSubOp._allowed!r}.") 

66 if len(inputs) not in (1, 2): 

67 raise RuntimeError( 

68 f"Inputs must contains 1 or 2 inputs not {len(inputs)}.") 

69 if name == 'matmul' and len(inputs) != 2: 

70 raise RuntimeError( 

71 "Inputs must contains 2 inputs not %d for operator 'matmul'." 

72 "" % len(inputs)) 

73 for i, inp in enumerate(inputs): 

74 if not isinstance(inp, (int, EinsumSubOp)): 

75 raise TypeError( 

76 "Input %d has type %r, int or EinsumSubOp is expected." 

77 "" % (i, type(inp))) 

78 self._check_() 

79 

80 def _check_(self): 

81 if self.name == 'transpose': 

82 self._check_arg_('perm', tuple) 

83 perm = self.kwargs['perm'] 

84 if len(perm) != len(set(perm)): 

85 raise RuntimeError( # pragma: no cover 

86 f"perm has duplicated values {perm!r} (name={self.name!r}).") 

87 if list(perm) == list(range(len(perm))): 

88 raise ValueError( # pragma: no cover 

89 f"Transpose = identity perm={perm}. It must be removed.") 

90 elif self.name == 'matmul': 

91 self._check_arg_('axes', tuple) 

92 self._check_arg_('left', tuple) 

93 self._check_arg_('right', tuple) 

94 axes = self.kwargs['axes'] 

95 left = self.kwargs['left'] 

96 right = self.kwargs['right'] 

97 for a in axes: 

98 if a in left and a in right: 

99 raise RuntimeError( # pragma: no cover 

100 "One axis belongs to every set (axes, left, right). " 

101 "axes=%r, left=%r, right=%r." % (axes, left, right)) 

102 

103 def __repr__(self): 

104 inps = ", ".join(map(str, self.inputs)) 

105 kw = ", ".join(f"{k}={w!r}" for k, w in self.kwargs.items()) 

106 m = f"{self.__class__.__name__}({self.name!r}, {inps}, {kw})" 

107 return m 

108 

109 def dot_label(self): 

110 """ 

111 Displays some informations useful to understand the operator. 

112 """ 

113 if self.name == "matmul": 

114 ndim = self.kwargs['ndim'] 

115 axes = self.kwargs['axes'] 

116 left = self.kwargs['left'] 

117 right = self.kwargs['right'] 

118 eq = _numpy_extended_dot_equation(ndim, ndim, axes, left, right) 

119 eq = eq.replace(">", "\\\\>") 

120 return "~" + eq 

121 return None 

122 

123 def _check_arg_(self, name, typ, empty=False): 

124 if name not in self.kwargs: 

125 raise RuntimeError( # pragma: no cover 

126 f"Parameter {name!r} not found for operator {self.name!r}.") 

127 if empty and self.kwargs[name] is None: 

128 return 

129 if not isinstance(self.kwargs[name], typ): 

130 raise TypeError( # pragma: no cover 

131 "Unexpected type %r for parameter %r and parameter %r." 

132 "" % (type(self.kwargs[name]), name, self.name)) 

133 

134 def _check_row_(self, row, inp=False, verbose=False): 

135 """ 

136 Checks input or output is valid. 

137 """ 

138 if verbose: 

139 if inp: 

140 print('<<' if inp else '>>', self.name, row, self.kwargs) 

141 else: 

142 print('<<' if inp else '>>', self.name, row) 

143 

    def _compute_output_row_id(self, row, row2=None, ab=False, verbose=False):
        # Identity: the output row becomes a copy of *row2*.
        # Slice assignment mutates *row* in place (works for both lists
        # and numpy arrays) so the caller's reference stays valid.
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        row[:] = row2[:]
        self._check_row_(row, verbose=verbose)

150 

151 def _compute_output_row_transpose(self, row, row2=None, ab=False, verbose=False): 

152 if ab: 

153 self._compute_output_row_transpose(row2, verbose=verbose) 

154 return 

155 self._check_row_(row, True, verbose=verbose) 

156 self._check_arg_('perm', tuple) 

157 if len(self.kwargs['perm']) != len(row): 

158 raise RuntimeError( # pragma: no cover 

159 f"Unexpected permutation {self.kwargs['perm']!r} (row={row!r}).") 

160 perm = self.kwargs['perm'] 

161 cpy = row.copy() 

162 for i, p in enumerate(perm): 

163 row[i] = cpy[p] 

164 self._check_row_(row, verbose=verbose) 

165 

166 def _compute_output_row_transpose_mm(self, row, row2=None, ab=False, verbose=False): 

167 if not ab: 

168 raise RuntimeError("ab must be True.") # pragma: no cover 

169 self._check_row_(row, True, verbose=verbose) 

170 if row2 is None: 

171 raise RuntimeError( # pragma: no cover 

172 "transpose_mm expects a second input.") 

173 self._compute_output_row_transpose(row, row2=None, verbose=verbose) 

174 

175 def _compute_output_row_expand_dims(self, row, row2=None, ab=False, verbose=False): 

176 if ab: 

177 raise RuntimeError("ab option not allowed.") # pragma: no cover 

178 self._check_row_(row, True, verbose=verbose) 

179 self._check_arg_('axes', tuple) 

180 axes = self.kwargs['axes'] 

181 for axis in axes: 

182 if not isinstance(axis, tuple): 

183 raise TypeError( # pragma: no cover 

184 "Parameter axes of expand_dims should be a tuple of " 

185 "tuple, axes=%r." % axes) 

186 if row[axis[1]] != -1: 

187 raise RuntimeError( # pragma: no cover 

188 "Dimension should be -1 in row %r axis=%r." % ( 

189 row, self.kwargs['axis'])) 

190 self._check_row_(row, verbose=verbose) 

191 

192 def _compute_output_row_reduce_sum(self, row, row2=None, ab=False, verbose=False): 

193 if ab: 

194 raise RuntimeError("ab option not allowed.") # pragma: no cover 

195 self._check_row_(row, True, verbose=verbose) 

196 self._check_arg_('axes', tuple) 

197 for a in self.kwargs['axes']: 

198 row[a] = -1 

199 self._check_row_(row, verbose=verbose) 

200 

201 def _compute_output_row_reduce_sum_mm(self, row, row2=None, ab=False, verbose=False): 

202 if not ab: 

203 raise RuntimeError("ab must be true.") # pragma: no cover 

204 self._check_row_(row2, True, verbose=verbose) 

205 if row2 is None: 

206 raise RuntimeError( # pragma: no cover 

207 "reduce_sum_mm expects a second input.") 

208 self._compute_output_row_reduce_sum(row, row2=None, verbose=verbose) 

209 

210 def _compute_output_row_squeeze(self, row, row2=None, ab=False, verbose=False): 

211 if ab: 

212 raise RuntimeError("ab option not allowed.") # pragma: no cover 

213 self._check_row_(row, True, verbose=verbose) 

214 self._check_arg_('axes', tuple) 

215 for a in self.kwargs['axes']: 

216 row[a] = -1 

217 self._check_row_(row, verbose=verbose) 

218 

    def _compute_output_row_diagonal(self, row, row2=None, ab=False, verbose=False):
        # Updates *row* for a diagonal extraction: for every (choice,
        # choices) pair in kwargs['diag'], all axes in *choices* are
        # merged into *choice*, then the removed axis labels shift the
        # remaining labels down so they stay contiguous.
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('diag', list)
        to_remove = []
        for choice, choices in self.kwargs['diag']:
            # every non-chosen axis of the group disappears
            for ch in choices:
                if ch != choice:
                    to_remove.append(ch)
            # rename every occurrence of a grouped axis to the chosen one
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] in choices:
                    if row[i] != choice:
                        row[i] = choice
        to_remove.sort()
        for r in to_remove:
            for i in range(len(row)):  # pylint: disable=C0200
                # a removed label must not remain in the row at this point
                if row[i] == r:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected result r=%r row=%r to_remove=%r "
                        "diag=%r." % (
                            r, row, to_remove, self.kwargs['diag']))
                if row[i] > r:
                    # shift labels above the removed one down by one
                    row[i] -= 1
        self._check_row_(row, verbose=verbose)

244 

245 def _compute_output_row_matmul(self, row, row2=None, ab=False, verbose=False): 

246 if not ab: 

247 raise RuntimeError("ab must be True.") # pragma: no cover 

248 self._check_row_(row, True, verbose=verbose) 

249 self._check_row_(row2, True, verbose=verbose) 

250 self._check_arg_('axes', tuple) 

251 self._check_arg_('left', tuple) 

252 self._check_arg_('right', tuple) 

253 self._check_arg_('ndim', int) 

254 if row2 is None: 

255 raise RuntimeError( # pragma: no cover 

256 "matmul expects two inputs.") 

257 if verbose: 

258 ndim = self.kwargs['ndim'] 

259 axes = self.kwargs['axes'] 

260 left = self.kwargs['left'] 

261 right = self.kwargs['right'] 

262 print(" MATMUL %r @ %r axes=%r left=%r right=%r - eq=%s" % ( 

263 row, row2, axes, left, right, 

264 _numpy_extended_dot_equation(ndim, ndim, axes, left, right))) 

265 row2[:] = numpy.maximum(row, row2) 

266 for a in self.kwargs['axes']: 

267 if a not in self.kwargs['right']: 

268 row2[a] = -1 

269 self._check_row_(row2, verbose=verbose) 

270 

271 def _compute_output_row_batch_dot(self, row, row2=None, ab=False, verbose=False): 

272 if not ab: 

273 raise RuntimeError("ab must be True.") # pragma: no cover 

274 self._check_row_(row, True, verbose=verbose) 

275 self._check_row_(row2, True, verbose=verbose) 

276 self._check_arg_('batch_axes', tuple) 

277 self._check_arg_('keep_axes', tuple, empty=True) 

278 self._check_arg_('sum_axes', tuple) 

279 self._check_arg_('left', tuple) 

280 self._check_arg_('right', tuple) 

281 self._check_arg_('ndim', int) 

282 if row2 is None: 

283 raise RuntimeError( 

284 "batch_dot expects two inputs.") # pragma: no cover 

285 if verbose: 

286 batch_axes = self.kwargs['batch_axes'] 

287 keep_axes = self.kwargs['keep_axes'] 

288 sum_axes = self.kwargs['sum_axes'] 

289 left = self.kwargs['left'] 

290 right = self.kwargs['right'] 

291 ndim = self.kwargs['ndim'] 

292 print(" BATCH_DOT batch_axes=%r keep_axes=%r sum_axes=%r " 

293 "left=%r right=%r eq=%r" % ( 

294 batch_axes, keep_axes, sum_axes, left, right, 

295 _numpy_extended_dot_equation(ndim, ndim, sum_axes, left, right))) 

296 row2[:] = numpy.maximum(row, row2) 

297 for a in self.kwargs['sum_axes']: 

298 if a not in self.kwargs['right']: 

299 row2[a] = -1 

300 self._check_row_(row2, verbose=verbose) 

301 

302 def _compute_output_row_mul(self, row, row2=None, ab=False, verbose=False): 

303 if not ab: 

304 raise RuntimeError("ab must be True.") # pragma: no cover 

305 self._check_row_(row, True, verbose=verbose) 

306 self._check_row_(row2, True, verbose=verbose) 

307 if row2 is None: 

308 raise RuntimeError("mul expects two inputs.") # pragma: no cover 

309 if verbose: 

310 print( # pragma: no cover 

311 f" MUL {row!r} @ {row2!r}") 

312 row2[:] = numpy.maximum(row, row2) 

313 self._check_row_(row2, verbose=verbose) 

314 

315 def compute_output_row(self, row, row2=None, ab=False, verbose=False): 

316 """ 

317 Updates *row* based on the operator. 

318 """ 

319 method_name = f"_compute_output_row_{self.name}" 

320 meth = getattr(self, method_name, None) 

321 if meth is None: 

322 raise NotImplementedError( # pragma: no cover 

323 f"compute_output_row not implemented for {self.name!r}.") 

324 if verbose and ab: 

325 print(" -- called as a binary operator") 

326 self.add_info(i_row=single_axes(row), i_row2=single_axes(row2)) 

327 meth(row, row2=row2, ab=ab, verbose=verbose) 

328 self.add_info(o_row=single_axes(row), o_row2=single_axes(row2)) 

329 

330 def add_info(self, **kwargs): 

331 """ 

332 Adds information to the node. 

333 

334 :param kwargs: dictionary 

335 """ 

336 for k, v in kwargs.items(): 

337 if k in self._info: 

338 raise KeyError( # pragma: no cover 

339 f"Key {k!r} already added (operator {self.name!r}).") 

340 self._info[k] = v 

341 

342 def _check_inputs_(self, n_expected, check_dim=False): 

343 if len(self.inputs) != n_expected: 

344 raise RuntimeError( # pragma: no cover 

345 "Number of inputs must be %d not %d for operator %r." 

346 "" % (n_expected, len(self.inputs), self.name)) 

347 

348 def _check_shape_(self, m): 

349 if len(m.shape) != self.full_dim: 

350 raise RuntimeError( # pragma: no cover 

351 "Number of dimensions %r is different from expected value " 

352 "%d." % (m.shape, self.full_dim)) 

353 

354 def _get_data(self, data, key): 

355 if isinstance(key, int): 

356 if key not in data: 

357 raise RuntimeError( # pragma: no cover 

358 "Unable to find key %d in %r." % ( 

359 key, list(sorted(data)))) 

360 return data[key] 

361 if isinstance(key, EinsumSubOp): 

362 if id(key) not in data: 

363 raise RuntimeError( # pragma: no cover 

364 "Unable to find key %d in %r." % ( 

365 id(key), list(sorted(data)))) 

366 return data[id(key)] 

367 raise TypeError( # pragma: no cover 

368 f"Unexpected input type {type(key)!r}.") 

369 

370 def _apply_id(self, data, verbose=False, **kwargs): 

371 self._check_inputs_(1) 

372 inp = self.inputs[0] 

373 output = self._get_data(data, inp) 

374 return output 

375 

376 def _apply_diagonal(self, data, verbose=False, **kwargs): 

377 self._check_inputs_(1) 

378 inp = self.inputs[0] 

379 m = self._get_data(data, inp) 

380 if verbose: 

381 print( # pragma: no cover 

382 f"- {self.name}, shape={m.shape!r} diag={self.kwargs['diag']!r}") 

383 diag = self.kwargs['diag'] 

384 if len(diag) != 1: 

385 raise NotImplementedError( # pragma: no cover 

386 f"Not implemented with more than one duplicated indice {diag!r}.") 

387 diag0 = diag[0] 

388 output = numpy_diagonal(m, axis=diag0[0], axes=diag0[1]) 

389 return output 

390 

391 def _apply_expand_dims(self, data, verbose=False, **kwargs): 

392 self._check_inputs_(1) 

393 inp = self.inputs[0] 

394 m = self._get_data(data, inp) 

395 if verbose: 

396 print( 

397 f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}") 

398 output = m 

399 for axis in reversed(self.kwargs['axes']): 

400 output = numpy.expand_dims(output, axis[0]) 

401 return output 

402 

403 def _apply_transpose(self, data, verbose=False, **kwargs): 

404 self._check_inputs_(1, True) 

405 inp = self.inputs[0] 

406 m = self._get_data(data, inp) 

407 self._check_shape_(m) 

408 if verbose: 

409 print( 

410 f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}") 

411 output = numpy.transpose(m, self.kwargs['perm']) 

412 self._check_shape_(output) 

413 return output 

414 

415 def _apply_transpose_mm(self, data, verbose=False, **kwargs): 

416 self._check_inputs_(2, True) 

417 inp = self.inputs[0] 

418 m = self._get_data(data, inp) 

419 self._check_shape_(m) 

420 if verbose: 

421 print( # pragma: no cover 

422 f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}") 

423 output = numpy.transpose(m, self.kwargs['perm']) 

424 self._check_shape_(output) 

425 return output 

426 

427 def _apply_matmul(self, data, verbose=False, **kwargs): 

428 self._check_inputs_(2) 

429 inp1 = self.inputs[0] 

430 inp2 = self.inputs[1] 

431 m1 = self._get_data(data, inp1) 

432 m2 = self._get_data(data, inp2) 

433 self._check_shape_(m1) 

434 self._check_shape_(m2) 

435 axes = self.kwargs['axes'] 

436 left = self.kwargs['left'] 

437 right = self.kwargs['right'] 

438 

439 if verbose: 

440 print("- %s, shapes=%r @ %r axes=%r left=%r right=%r" % ( 

441 self.name, m1.shape, m2.shape, axes, left, right)) 

442 

443 impl = kwargs.get('matmul_impl', None) 

444 if impl == 'pyf': 

445 output = numpy_extended_dot_matrix(m1, m2, axes, left, right, 

446 verbose=verbose) 

447 elif impl == 'py': 

448 output = numpy_extended_dot_python(m1, m2, axes, left, right, 

449 verbose=verbose) 

450 elif impl is None: 

451 output = numpy_extended_dot(m1, m2, axes, left, right, 

452 verbose=verbose) 

453 else: 

454 raise ValueError( 

455 f"Unknown implementation of numpy_extended_dot ({impl}).") 

456 self._check_shape_(output) 

457 return output 

458 

459 def _apply_mul(self, data, verbose=False, **kwargs): 

460 self._check_inputs_(2) 

461 inp1 = self.inputs[0] 

462 inp2 = self.inputs[1] 

463 m1 = self._get_data(data, inp1) 

464 m2 = self._get_data(data, inp2) 

465 self._check_shape_(m1) 

466 self._check_shape_(m2) 

467 

468 if verbose: 

469 print( # pragma: no cover 

470 f"- {self.name}, shapes={m1.shape!r} @ {m2.shape!r}") 

471 

472 output = m1 * m2 

473 self._check_shape_(output) 

474 return output 

475 

    def _apply_batch_dot(self, data, verbose=False, **kwargs):
        # Extended batched dot product: both inputs are reshaped into 3D
        # tensors (batch, kept, summed), multiplied, then reshaped back
        # to *full_dim* dimensions.
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r batch_axes=%r keep_axes=%r "
                  "sum_axes=%r" % (
                      self.name, m1.shape, m2.shape, batch_axes, keep_axes, sum_axes))

        if len(m1.shape) != len(m2.shape):
            raise RuntimeError(  # pragma: no cover
                "batch_dot only work with two tensors with the same number "
                "of dimensions not %r @ %r." % (m1.shape, m2.shape))

        # flattened sizes of the batch, kept and summed groups of axes;
        # dimb is -1 (inferred by reshape) when no axes are kept
        dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        dimb = int(-1 if keep_axes is None else numpy.prod(
            [m1.shape[i] for i in keep_axes]))
        dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if verbose:
            print(f"- {self.name}, reshape={m1.shape!r} into {dim0, dimb, dim1!r}")
            print(f"- {self.name}, reshape={m2.shape!r} into {dim0b, dimb, dim2!r}")
        m1sh = m1.reshape((dim0, dimb, dim1))
        m2sh = m2.reshape((dim0b, dimb, dim2))

        batch_kind = self.get_dot_kind()
        # NOTE(review): 'N1' appears twice in this tuple — presumably one
        # occurrence was meant to be '1N'; confirm against get_dot_kind().
        if batch_kind in ('11', 'N1', 'N1'):
            # degenerate batch: a single 2D Gemm call (transB=True) is enough
            m1sh = m1sh.reshape((-1, m1sh.shape[-1]))
            m2sh = m2sh.reshape((-1, m2sh.shape[-1]))
            if verbose:
                print("- %s, use gemm with shape %r, %r" % (
                    self.name, m1sh.shape, m2sh.shape))
            dot = gemm_dot(m1sh, m2sh, False, True)
        else:
            # general case: batched matrix product against the transposed
            # second input
            dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))

        # new shape
        new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
                     [m1.shape[i] for i in left if i not in batch_axes] +
                     [m2.shape[i] for i in right if i not in batch_axes])
        # pad with single dimensions so the output keeps full_dim axes
        while len(new_shape) < len(m1.shape):
            new_shape.append(1)

        if verbose:
            taken = set(batch_axes) | set(sum_axes)
            ax = [i for i in range(len(m1.shape)) if i not in taken]
            print("- %s, shapes=%r @ %r -> %r" % (
                self.name, m1sh.shape, m2sh.shape, dot.shape))
            print("- %s, batch_axes=%r ax=%r new_shape=%r left=%r right=%r" % (
                self.name, batch_axes, ax, new_shape, left, right))

        output = dot.reshape(tuple(new_shape))
        self._check_shape_(output)
        return output

542 

543 def _apply_reduce_sum(self, data, verbose=False, **kwargs): 

544 self._check_inputs_(1) 

545 inp = self.inputs[0] 

546 m = self._get_data(data, inp) 

547 self._check_shape_(m) 

548 axes = self.kwargs['axes'] 

549 if verbose: 

550 print( 

551 f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}") 

552 output = numpy.sum(m, axis=axes, keepdims=True) 

553 self._check_shape_(output) 

554 return output 

555 

556 def _apply_reduce_sum_mm(self, data, verbose=False, **kwargs): 

557 self._check_inputs_(2, True) 

558 inp = self.inputs[0] 

559 m = self._get_data(data, inp) 

560 self._check_shape_(m) 

561 if verbose: 

562 print( 

563 f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}") 

564 output = numpy.sum(m, self.kwargs['axes']) 

565 self._check_shape_(output) 

566 return output 

567 

568 def _apply_squeeze(self, data, verbose=False, **kwargs): 

569 self._check_inputs_(1) 

570 inp = self.inputs[0] 

571 m = self._get_data(data, inp) 

572 axes = self.kwargs['axes'] 

573 if verbose: 

574 print( 

575 f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}") 

576 output = m 

577 for a in axes[::-1]: 

578 output = numpy.squeeze(output, axis=a) 

579 return output 

580 

581 def apply(self, data, verbose=False, **kwargs): 

582 """ 

583 Applies one operator on the data. 

584 

585 :param data: dictionary storing the results 

586 :param verbose: prints out intermediate results 

587 :param kwargs: additional parameters, see 

588 methods `_apply*` 

589 :return: output 

590 

591 Known additional paramaters: 

592 

593 * 'matmul_impl': if None calls :epkg:`numpy:einsum` through 

594 @see fn numpy_extended_dot (default) or 'py' to call 

595 @see fn numpy_extended_dot_python instead. 

596 """ 

597 if verbose: 

598 print() 

599 print("apply %r (%s)." % ( 

600 self.name, ", ".join(map(lambda s: str(id(s)), self.inputs)))) 

601 

602 method_name = f"_apply_{self.name}" 

603 meth = getattr(self, method_name, None) 

604 if meth is None: 

605 raise NotImplementedError( # pragma: no cover 

606 f"apply not implemented for {self.name!r}.") 

607 output = meth(data, verbose, **kwargs) 

608 

609 data[id(self)] = output 

610 if verbose: 

611 print("+ %s, shape=%r -- %d" % (self.name, output.shape, id(self))) 

612 return output 

613 

614 def _onnx_name(self): 

615 return 'einsum%d_%s' % (id(self), self.name[:2]) 

616 

617 def _check_onnx_opset_(self, opset, limit): 

618 if opset is not None and opset < limit: 

619 raise RuntimeError( # pragma: no cover 

620 f"Opset ({opset!r}) must be >= {limit!r} for operator {self.name!r}.") 

621 

622 def _to_onnx_id(self, names, opset, verbose=False, **kwargs): 

623 self._check_inputs_(1) 

624 inp = self.inputs[0] 

625 name = self._get_data(names, inp) 

626 yield helper.make_node('Identity', [name], [self._onnx_name()]) 

627 

628 def _to_onnx_expand_dims(self, names, opset, verbose=False, **kwargs): 

629 self._check_inputs_(1) 

630 self._check_onnx_opset_(opset, 11) 

631 inp = self.inputs[0] 

632 name = self._get_data(names, inp) 

633 axes = self.kwargs['axes'] 

634 name_axes = name + '_axes' 

635 yield numpy_helper.from_array( 

636 numpy.array([a[1] for a in axes], dtype=numpy.int64), name=name_axes) 

637 s_axes = "".join(map(str, [a[1] for a in axes])) 

638 yield helper.make_node( 

639 'Unsqueeze', [name, name_axes], [self._onnx_name()], 

640 name='Unsqueeze%s_%d' % (s_axes, id(self))) 

641 

642 def _to_onnx_squeeze(self, names, opset, verbose=False, **kwargs): 

643 self._check_inputs_(1) 

644 self._check_onnx_opset_(opset, 11) 

645 inp = self.inputs[0] 

646 name = self._get_data(names, inp) 

647 axes = self.kwargs['axes'] 

648 name_axes = name + '_axes' 

649 yield numpy_helper.from_array( 

650 numpy.array(axes, dtype=numpy.int64), name=name_axes) 

651 s_axes = "".join(map(str, axes)) 

652 yield helper.make_node( 

653 'Squeeze', [name, name_axes], [self._onnx_name()], 

654 name='Squeeze%s_%d' % (s_axes, id(self))) 

655 

656 def _to_onnx_transpose(self, names, opset, verbose=False, **kwargs): 

657 self._check_inputs_(1) 

658 inp = self.inputs[0] 

659 name = self._get_data(names, inp) 

660 perm = self.kwargs['perm'] 

661 s_perm = "".join(map(str, perm)) 

662 yield helper.make_node( 

663 'Transpose', [name], [self._onnx_name()], perm=perm, 

664 name='Transpose%s_%d' % (s_perm, id(self))) 

665 

666 def _to_onnx_reduce_sum(self, names, opset, verbose=False, **kwargs): 

667 self._check_inputs_(1) 

668 self._check_onnx_opset_(opset, 11) 

669 inp = self.inputs[0] 

670 name = self._get_data(names, inp) 

671 axes = self.kwargs['axes'] 

672 name_axes = self._onnx_name() + '_axes' 

673 yield numpy_helper.from_array( 

674 numpy.array(axes, dtype=numpy.int64), name=name_axes) 

675 s_axes = "".join(map(str, axes)) 

676 yield helper.make_node( 

677 'ReduceSum', [name, name_axes], [self._onnx_name()], keepdims=1, 

678 name='ReduceSum%s_%d' % (s_axes, id(self))) 

679 

680 def _to_onnx_mul(self, data, verbose=False, **kwargs): 

681 self._check_inputs_(2) 

682 inp1 = self.inputs[0] 

683 inp2 = self.inputs[1] 

684 m1 = self._get_data(data, inp1) 

685 m2 = self._get_data(data, inp2) 

686 yield helper.make_node('Mul', [m1, m2], [self._onnx_name()]) 

687 

688 def _to_onnx_batch_dot(self, names, opset, verbose=False, **kwargs): # pylint: disable=R0914 

689 self._check_inputs_(2) 

690 self._check_onnx_opset_(opset, 13) 

691 inp1, inp2 = self.inputs[:2] # pylint: disable=W0632 

692 name1 = self._get_data(names, inp1) 

693 name2 = self._get_data(names, inp2) 

694 

695 batch_axes = self.kwargs['batch_axes'] 

696 keep_axes = self.kwargs['keep_axes'] 

697 sum_axes = self.kwargs['sum_axes'] 

698 left = self.kwargs['left'] 

699 right = self.kwargs['right'] 

700 root = self._onnx_name() 

701 

702 def return_name_one(): 

703 name_one = root + "_1" 

704 return name_one, numpy_helper.from_array( 

705 numpy.array([1], dtype=numpy.int64), name=name_one) 

706 

707 name_one = None 

708 name_shape1 = root + "_shape1" 

709 name_shape2 = root + "_shape2" 

710 concat_left = [] 

711 concat_right = [] 

712 yield helper.make_node('Shape', [name1], [name_shape1]) 

713 yield helper.make_node('Shape', [name2], [name_shape2]) 

714 

715 if len(batch_axes) > 0: 

716 name_batch_axes = root + "_batch_axes" 

717 yield numpy_helper.from_array( 

718 numpy.array(batch_axes, dtype=numpy.int64), name=name_batch_axes) 

719 

720 if len(sum_axes) > 0: 

721 name_sum_axes = root + "_sum_axes" 

722 yield numpy_helper.from_array( 

723 numpy.array(sum_axes, dtype=numpy.int64), name=name_sum_axes) 

724 

725 # dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes])) 

726 # dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes])) 

727 if len(batch_axes) > 1: 

728 name_dim0 = root + "_dim0" 

729 name_dim0b = root + "_dim0b" 

730 name_dim0g = name_dim0 + 'g' 

731 name_dim0bg = name_dim0b + 'g' 

732 concat_left.append(name_dim0) 

733 concat_right.append(name_dim0b) 

734 yield helper.make_node( 

735 'Gather', [name_shape1, name_batch_axes], [name_dim0g]) 

736 yield helper.make_node( 

737 'Gather', [name_shape2, name_batch_axes], [name_dim0bg]) 

738 yield helper.make_node( 

739 'ReduceProd', [name_dim0g], [name_dim0], keepdims=1) 

740 yield helper.make_node( 

741 'ReduceProd', [name_dim0bg], [name_dim0b], keepdims=1) 

742 elif len(batch_axes) == 1: 

743 name_dim0g = root + "_dim0g" 

744 name_dim0bg = root + "_dim0bg" 

745 name_dim0 = name_dim0g 

746 name_dim0b = name_dim0bg 

747 concat_left.append(name_dim0) 

748 concat_right.append(name_dim0b) 

749 yield helper.make_node( 

750 'Gather', [name_shape1, name_batch_axes], [name_dim0g]) 

751 yield helper.make_node( 

752 'Gather', [name_shape2, name_batch_axes], [name_dim0bg]) 

753 else: 

754 if name_one is None: 

755 name_one, cst_init = return_name_one() 

756 yield cst_init 

757 name_dim0 = name_one 

758 name_dim0b = name_one 

759 concat_left.append(name_dim0) 

760 concat_right.append(name_dim0b) 

761 

762 # dimb = int(-1 if keep_axes is None else numpy.prod( 

763 # [m1.shape[i] for i in keep_axes])) 

764 if keep_axes in (-1, None) or len(keep_axes) == 0: 

765 name_dimb = root + "__1" 

766 concat_left.append(name_dimb) 

767 concat_right.append(name_dimb) 

768 yield numpy_helper.from_array( 

769 numpy.array([-1], dtype=numpy.int64), name=name_dimb) 

770 elif len(keep_axes) == 1: 

771 name_keep_axes = root + "_keep_axes" 

772 name_dimb = root + "_dimb" 

773 name_dimbg = name_dimb 

774 concat_left.append(name_dimb) 

775 concat_right.append(name_dimb) 

776 yield numpy_helper.from_array( 

777 numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes) 

778 yield helper.make_node( 

779 'Gather', [name_shape1, name_keep_axes], [name_dimbg]) 

780 else: 

781 name_keep_axes = root + "_keep_axes" 

782 name_dimb = root + "_dimb" 

783 name_dimbg = name_dimb + 'g' 

784 concat_left.append(name_dimb) 

785 concat_right.append(name_dimb) 

786 yield numpy_helper.from_array( 

787 numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes) 

788 yield helper.make_node( 

789 'Gather', [name_shape1, name_keep_axes], [name_dimbg]) 

790 yield helper.make_node( 

791 'ReduceProd', [name_dimbg], [name_dimb], keepdims=1) 

792 

793 # dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes])) 

794 # dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes])) 

795 

796 if len(sum_axes) == 0: 

797 if name_one is None: 

798 name_one, cst_init = return_name_one() 

799 yield cst_init 

800 name_dim1 = name_one 

801 name_dim2 = name_one 

802 concat_left.append(name_dim1) 

803 concat_right.append(name_dim2) 

804 elif len(sum_axes) == 1: 

805 name_dim1 = root + "_dim1" 

806 name_dim2 = root + "_dim2" 

807 name_dim1g = name_dim1 

808 name_dim2g = name_dim2 

809 concat_left.append(name_dim1) 

810 concat_right.append(name_dim2) 

811 yield helper.make_node( 

812 'Gather', [name_shape1, name_sum_axes], [name_dim1g]) 

813 yield helper.make_node( 

814 'Gather', [name_shape2, name_sum_axes], [name_dim2g]) 

815 else: 

816 name_dim1 = root + "_dim1" 

817 name_dim2 = root + "_dim2" 

818 name_dim1g = name_dim1 + 'g' 

819 name_dim2g = name_dim2 + 'g' 

820 concat_left.append(name_dim1) 

821 concat_right.append(name_dim2) 

822 yield helper.make_node( 

823 'Gather', [name_shape1, name_sum_axes], [name_dim1g]) 

824 yield helper.make_node( 

825 'Gather', [name_shape2, name_sum_axes], [name_dim2g]) 

826 yield helper.make_node( 

827 'ReduceProd', [name_dim1g], [name_dim1], keepdims=1) 

828 yield helper.make_node( 

829 'ReduceProd', [name_dim2g], [name_dim2], keepdims=1) 

830 

831 batch_kind = self.get_dot_kind() 

832 if batch_kind in ('11', 'N1', 'N1'): 

833 # *shape1, *shape2 

834 name_minus_one = root + "__01" 

835 yield numpy_helper.from_array( 

836 numpy.array([-1], dtype=numpy.int64), name=name_minus_one) 

837 name_agg_shape1_2 = root + f"_resh1_{batch_kind}" 

838 name_agg_shape2_2 = root + f"_resh2_{batch_kind}" 

839 yield helper.make_node( 

840 'Concat', [name_minus_one, name_dim1], [name_agg_shape1_2], axis=0) 

841 yield helper.make_node( 

842 'Concat', [name_minus_one, name_dim2], [name_agg_shape2_2], axis=0) 

843 

844 # m1sh = m1.reshape((-1, dim1)) 

845 # m2sh = m2.reshape((-1, dim2)) 

846 name_agg1_2 = root + "_aresh1" 

847 name_agg2_2 = root + "_aresh2" 

848 yield helper.make_node('Reshape', [name1, name_agg_shape1_2], [name_agg1_2]) 

849 yield helper.make_node('Reshape', [name2, name_agg_shape2_2], [name_agg2_2]) 

850 

851 # dot = gemm(m1sh, m2sh, False, True) 

852 name_dot = root + "_gemm" 

853 yield helper.make_node( 

854 'Gemm', [name_agg1_2, name_agg2_2], [name_dot], 

855 alpha=1., beta=0., transA=0, transB=1) 

856 else: 

857 # *shape1, *shape2 

858 name_agg_shape1 = root + "_resh1" 

859 name_agg_shape2 = root + "_resh2" 

860 yield helper.make_node( 

861 'Concat', concat_left, [name_agg_shape1], axis=0) 

862 yield helper.make_node( 

863 'Concat', concat_right, [name_agg_shape2], axis=0) 

864 

865 # m1sh = m1.reshape((dim0, dimb, dim1)) 

866 # m2sh = m2.reshape((dim0b, dimb, dim2)) 

867 name_agg1 = root + "_aresh1" 

868 name_agg2 = root + "_aresh2" 

869 yield helper.make_node('Reshape', [name1, name_agg_shape1], [name_agg1]) 

870 yield helper.make_node('Reshape', [name2, name_agg_shape2], [name_agg2]) 

871 

872 # dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1)) 

873 name_agg2_tr = root + "_aresh2_tr" 

874 yield helper.make_node( 

875 'Transpose', [name_agg2], [name_agg2_tr], perm=[0, 2, 1], 

876 name=f"Transpose021_{id(self)}") 

877 

878 name_dot = root + "_dot" 

879 yield helper.make_node( 

880 'MatMul', [name_agg1, name_agg2_tr], [name_dot]) 

881 

882 # new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] + 

883 # [m1.shape[i] for i in left if i not in batch_axes] + 

884 # [m2.shape[i] for i in right if i not in batch_axes]) 

885 concat_final = [] 

886 if len(batch_axes) > 0: 

887 name_max_dim = root + "_max_dim" 

888 concat_final.append(name_max_dim) 

889 yield helper.make_node( 

890 'Max', [name_dim0g, name_dim0bg], [name_max_dim]) 

891 

892 left_set = list(sorted(set(left) - (set(batch_axes) & set(left)))) 

893 if len(left_set) > 0: 

894 name_left_dim = root + "_left_dim" 

895 name_left_set = root + "_left_set" 

896 yield numpy_helper.from_array( 

897 numpy.array(left_set, dtype=numpy.int64), name=name_left_set) 

898 yield helper.make_node( 

899 'Gather', [name_shape1, name_left_set], [name_left_dim]) 

900 concat_final.append(name_left_dim) 

901 

902 right_set = list(sorted(set(right) - (set(batch_axes) & set(right)))) 

903 if len(right_set) > 0: 

904 name_right_dim = root + "_right_dim" 

905 name_right_set = root + "_right_set" 

906 yield numpy_helper.from_array( 

907 numpy.array(right_set, dtype=numpy.int64), name=name_right_set) 

908 yield helper.make_node( 

909 'Gather', [name_shape2, name_right_set], [name_right_dim]) 

910 concat_final.append(name_right_dim) 

911 

912 name_new_shape = root + '_new_shape' 

913 diff = ( 

914 self.full_dim - 

915 (len(batch_axes) + len(left_set) + len(right_set))) 

916 if diff > 0: 

917 names_ones = root + "_ones" 

918 yield numpy_helper.from_array( 

919 numpy.array([1 for i in range(diff)], dtype=numpy.int64), 

920 name=names_ones) 

921 concat_final.append(names_ones) 

922 

923 yield helper.make_node( 

924 'Concat', concat_final, [name_new_shape], axis=0) 

925 

926 name_final = root + '_final' 

927 yield helper.make_node( 

928 'Reshape', [name_dot, name_new_shape], [name_final]) 

929 

930 def to_onnx(self, names, opset=None, verbose=False, **kwargs): 

931 """ 

932 Converts this node into ONNX. Enumerates all ONNX node 

933 which participate to the conversion. The last one 

934 is the final output. 

935 

936 :param names: dictionary where to find already converted name 

937 :param opset: opset 

938 :param verbose: prints out intermediate results 

939 :param kwargs: additional parameter for the conversion 

940 :return: output 

941 """ 

942 if opset is None: 

943 opset = __max_supported_opset__ # pragma: no cover 

944 if verbose: 

945 print() 

946 print("to_onnx %r (%s) opset=%r." % ( 

947 self.name, 

948 ", ".join(map(lambda s: str(id(s)), self.inputs)), 

949 opset)) 

950 

951 method_name = f"_to_onnx_{self.name}" 

952 meth = getattr(self, method_name, None) 

953 if meth is None: 

954 if self.name.endswith("_mm"): 

955 raise NotImplementedError( 

956 "to_onnx not implemented for %r." 

957 "You should call method simplify_mm_nodes " 

958 "to remove it." % self.name) 

959 raise NotImplementedError( 

960 f"to_onnx not implemented for {self.name!r}.") 

961 for node in meth(names, verbose=verbose, opset=opset, **kwargs): 

962 if hasattr(node, 'output'): 

963 names[id(self)] = node.output[0] 

964 if verbose: 

965 print("+ OP %r -- (%s - %d)" % 

966 (node.output[0], self.name, id(self))) 

967 elif verbose: 

968 # Initializer 

969 print("+ CT %r -- (%s - %d)" % 

970 (node.name, self.name, id(self))) 

971 yield node 

972 

973 def get_dot_kind(self): 

974 """ 

975 Every matrix multiplication can be either: 

976 

977 * a simple multiplication (`M`) (undetected) 

978 * a 2D matrix multiplication (`11`) 

979 * a broadcasted matrix multiplication (`N1` or `1N`) 

980 * a batch matrix multiplication (`NN`) 

981 

982 This method returns which kind it is. 

983 """ 

984 batch_axes = self.kwargs['batch_axes'] 

985 # keep_axes = self.kwargs['keep_axes'] 

986 # sum_axes = self.kwargs['sum_axes'] 

987 # left = self.kwargs['left'] 

988 # right = self.kwargs['right'] 

989 info = self._info 

990 row_left = info['i_row'] 

991 row_right = info['i_row2'] 

992 

993 batch_left = [row_left[k] for k in batch_axes] 

994 batch_right = [row_right[k] for k in batch_axes] 

995 n_left = len(batch_left) > 0 and max(batch_left) == 2 

996 n_right = len(batch_right) > 0 and max(batch_right) == 2 

997 return f"{'N' if n_left else '1'}{'N' if n_right else '1'}" 

998 

999 

1000class GraphEinsumSubOp: 

1001 """ 

1002 Class gathering all nodes produced to explicit einsum 

1003 operators. 

1004 

1005 :param letters: list of distinct letters 

1006 :param mat: matrix, see @see fn analyse_einsum_equation 

1007 :param lengths: lengths of every input 

1008 :param duplicates: see @see fn analyse_einsum_equation 

1009 """ 

1010 

1011 def __init__(self, letters, mat, lengths, duplicates): 

1012 self._nodes = {} 

1013 self._mark = {} 

1014 self._ops = [] 

1015 self._inputs = {} 

1016 self.last_op = None 

1017 self.last_added_op = None 

1018 self.metadata = dict( 

1019 letters=letters, mat=mat, lengths=lengths, 

1020 mat0=mat.copy(), duplicates=duplicates) 

1021 

1022 def append(self, op): 

1023 """ 

1024 Adds one input or result. 

1025 

1026 :param op: integer (an input) or an instance of @see cl EinsumSubOp. 

1027 :return: op or None if op is an integer 

1028 """ 

1029 if isinstance(op, int): 

1030 if op in self._nodes: 

1031 raise RuntimeError( # pragma: no cover 

1032 "Key %d already added." % op) 

1033 self._nodes[op] = op 

1034 self.last_added_op = op 

1035 self._inputs[op] = op 

1036 return None 

1037 if isinstance(op, EinsumSubOp): 

1038 if op in self._nodes: 

1039 raise RuntimeError( # pragma: no cover 

1040 "Key %d already added, op=%r." % (id(op), op)) 

1041 self._nodes[id(op)] = op 

1042 self._ops.append(op) 

1043 self.last_added_op = op 

1044 return op 

1045 raise TypeError( # pragma: no cover 

1046 f"Unexpected type {type(op)!r}.") 

1047 

1048 def mark_last_node(self): 

1049 """ 

1050 Marks the last node as the final output. 

1051 """ 

1052 if self.last_added_op is None: 

1053 raise RuntimeError("last_added_op is None.") # pragma: no cover 

1054 self.mark(-1, self.last_added_op) 

1055 

1056 def mark(self, i, op): 

1057 """ 

1058 Marks one input or result as an intermediate result 

1059 after a full einsum step. 

1060 

1061 :param op: integer (an input) or an instance of @see cl EinsumSubOp. 

1062 """ 

1063 if not isinstance(i, int): 

1064 raise TypeError( # pragma: no cover 

1065 f"i must an integer not {type(i)!r}.") 

1066 if i != -1 and i not in self._inputs: 

1067 raise RuntimeError( # pragma: no cover 

1068 "Input %d was not registered in %r." % (i, self._inputs)) 

1069 if isinstance(op, EinsumSubOp): 

1070 if id(op) not in self._nodes: 

1071 raise RuntimeError( # pragma: no cover 

1072 "Key %d not found, op=%r." % (id(op), op)) 

1073 self._mark[i] = op 

1074 self._mark[id(op)] = i 

1075 self.last_op = op 

1076 else: 

1077 raise TypeError( # pragma: no cover 

1078 f"Unexpected type {type(i)!r}.") 

1079 

1080 def __iter__(self): 

1081 "Iterates on nodes." 

1082 for op in self._ops: 

1083 yield op 

1084 

    def to_dot(self, **kwargs):
        """
        Produces a graph in :epkg:`dot`.

        :param kwargs: additional graph option, overrides the defaults
            below (any key whose value contains ``[`` is emitted as a
            statement such as ``node [shape=record];``)
        :return: string
        """
        # Default graph attributes; user kwargs take precedence.
        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '5',
            'node': '[shape=record]',
        }
        options.update(kwargs)

        def d2s(d):
            # Renders a dictionary as "k1=v1 k2=v2 ..." (sorted by key).
            it = []
            for k, v in sorted(d.items()):
                it.append(f"{k}={v}")
            return " ".join(it)

        def d2sd(d):
            # Same as d2s but values are sequences; only entries with
            # more than one element (actual duplicates) are kept.
            it = []
            for k, v in sorted(d.items()):
                if len(v) > 1:
                    it.append(f"{k}={','.join(map(str, v))}")
            return " ".join(it)

        rows = ["digraph{"]
        for k, v in options.items():
            if isinstance(v, str) and "[" in v:
                # e.g. 'node [shape=record];' is a statement, not attr=value.
                rows.append(f"{k} {v};")
            else:
                rows.append(f"{k}={v};")
        for k, v in self._nodes.items():
            if isinstance(v, int):
                # Input node: label shows its letters (from the equation)
                # and its row in the original matrix, plus duplicates.
                let = [(r, self.metadata['letters'][i])
                       for i, r in enumerate(self.metadata['mat0'][v])
                       if r != -1]
                dup = self.metadata['duplicates'][v]
                if dup is None:
                    dup = ""
                else:
                    dup = f" - {d2sd(dup)}"
                let.sort()
                letters = "".join(_[1] for _ in let)
                lab = "input %d\\\\n%s\\\\n%s%s" % (
                    v, letters, str(self.metadata['mat0'][v]), dup)
                sk = v
                extended_lab = ""
            else:
                # Operator node: label shows the operator name and kwargs.
                lab = f"{v.name}\\\\n{d2s(v.kwargs)}"
                sk = id(v)
                extended_lab = v.dot_label()
                if extended_lab:
                    extended_lab = "\\\\n" + extended_lab

            if sk in self._mark and isinstance(self._mark[sk], int):
                # Marked nodes (intermediate/final results) are drawn
                # in red and carry their input index in the label.
                la = self._mark[sk]
                lab = lab.replace("\\\\n", " - I%d\\\\n" % la)
                s = ('%d [label="%s%s" style=filled '
                     'fillcolor=red];' % (k, lab, extended_lab))
            else:
                s = '%d [label="%s%s"];' % (k, lab, extended_lab)
            rows.append(s)
            if not hasattr(v, 'inputs'):
                # Plain input nodes have no predecessors.
                continue
            for i in v.inputs:
                # One edge per input: predecessor id -> this node's key.
                vid = i if isinstance(i, int) else id(i)
                s = "%d -> %d;" % (vid, k)
                rows.append(s)
        rows.append("}")
        return "\n".join(rows)

1161 

1162 def apply_sequence(self, *inputs, verbose=False, **kwargs): 

1163 """ 

1164 Applies a sequence of operations on a list of inputs. 

1165 

1166 :param inputs: inputs: 

1167 :param verbose: prints out intermediate results 

1168 :param kwargs: additional parameters, 

1169 see :meth:`apply 

1170 <mlprodict.testing.einsum.einsum_impl_classes.EinsumSubOp.apply>`. 

1171 :return: output 

1172 """ 

1173 if verbose: 

1174 print('######### apply_sequence') 

1175 data = {i: inp for i, inp in enumerate(inputs)} 

1176 last = None 

1177 for op in self: 

1178 last = op.apply(data, verbose=verbose, **kwargs) 

1179 if last is None: 

1180 raise RuntimeError( # pragma: no cover 

1181 "Sequence of operations is empty.") 

1182 return last 

1183 

1184 def clean_unused_nodes(self, verbose=False): 

1185 """ 

1186 Cleans nodes with unused outputs. 

1187 

1188 :param verbose: display intermediate information 

1189 """ 

1190 

1191 def iteration(it): 

1192 # Walks through all nodes. 

1193 is_used = {} 

1194 for node in self._ops: 

1195 if not isinstance(node, EinsumSubOp): 

1196 continue 

1197 if id(node) not in is_used: 

1198 is_used[id(node)] = [] 

1199 for inp in node.inputs: 

1200 if not isinstance(inp, EinsumSubOp): 

1201 continue 

1202 idn = id(inp) 

1203 if idn not in is_used: 

1204 is_used[idn] = [] 

1205 is_used[idn].append(id(node)) 

1206 

1207 # Remove unused nodes. 

1208 removed = [] 

1209 for k, v in is_used.items(): 

1210 if len(v) == 0: 

1211 removed.append(k) 

1212 removed = set(removed) 

1213 i_rem = [] 

1214 for i, op in enumerate(self._ops): 

1215 if not isinstance(op, EinsumSubOp): 

1216 continue 

1217 if id(op) in removed and id(op) not in self._mark: 

1218 i_rem.append((i, id(op))) 

1219 for i, idn in reversed(i_rem): 

1220 if verbose: 

1221 print("[GraphEinsumSubOp.clean_nodes] remove node " 

1222 "i=%d: %d - id=%d" % (it, i, idn)) 

1223 del self._ops[i] 

1224 del self._nodes[idn] 

1225 return len(i_rem) > 0 

1226 

1227 it = 1 

1228 while iteration(it): 

1229 it += 1 

1230 

1231 self.last_op = None 

1232 self.last_added_op = None 

1233 

1234 def simplify_mm_nodes(self, verbose=False): 

1235 """ 

1236 Node name suffixed by `mm` are an artifact to keep 

1237 the graph consistent while building it. They can 

1238 now be replaced by the equivalent node without suffix `mm`. 

1239 

1240 :param verbose: display intermediate information 

1241 """ 

1242 for op in self: 

1243 if not isinstance(op, EinsumSubOp): 

1244 continue 

1245 if op.name.endswith('_mm'): 

1246 if verbose: 

1247 print("[GraphEinsumSubOp.simplify_mm_nodes] node %r" 

1248 " - id=%d" % (op.name, id(op))) 

1249 if len(op.inputs) != 2: 

1250 raise RuntimeError( # pragma: no cover 

1251 "Expecting 2 inputs for node %r not %r id=%r." % ( 

1252 op.name, len(op.inputs), id(op))) 

1253 op.name = op.name[:-3] 

1254 op.inputs = op.inputs[:1] 

1255 

1256 def _get_forward_nodes(self): 

1257 """ 

1258 Returns the forward nodes. 

1259 """ 

1260 forward = {} 

1261 for op in self: 

1262 if isinstance(op, int): 

1263 continue 

1264 for inp in op.inputs: 

1265 key = inp if isinstance(inp, int) else id(inp) 

1266 if key in forward: 

1267 forward[key].append(op) 

1268 else: 

1269 forward[key] = [op] 

1270 return forward 

1271 

1272 def _pprint_forward(self): 

1273 rows = [] 

1274 for op in self: 

1275 line = "%r <- %s(%s)" % ( 

1276 id(op), op.name, 

1277 ", ".join(map(str, [id(_) for _ in op.inputs]))) 

1278 rows.append(line) 

1279 return "\n".join(rows) 

1280 

    def _replace_node_sequence(self, added, deleted):
        """
        Removes a sequence of nodes. The method does not check
        that the graph remains consistent.

        :param added: operator inserted in place of the sequence, or
            None to drop the sequence and reconnect consumers to the
            single input of the first deleted node
        :param deleted: operators to remove; the consumers of the last
            one are rewired to *added* (or to that input)
        """
        forward = self._get_forward_nodes()
        key = id(deleted[-1])
        if key not in forward:
            raise RuntimeError(  # pragma: no cover
                "Key {} missing in all forward nodes (other keys {}), "
                "all keys:\n{}".format(
                    key, [id(_) for _ in deleted],
                    self._pprint_forward()))

        # deletion
        mark_input = None
        for d in deleted:
            del self._nodes[id(d)]
            if id(d) in self._mark:
                # The node was marked as an intermediate output:
                # remove both directions of the mark mapping and
                # remember the input index so the replacement can be
                # marked the same way below.
                del self._mark[id(d)]
                dels = []
                for k, v in self._mark.items():
                    if id(v) == id(d):
                        mark_input = k
                        dels.append(k)
                if len(dels) != 1:
                    raise RuntimeError(  # pragma: no cover
                        "Input %d has more than one marked operator "
                        "(%r)." % (id(d), dels))
                del self._mark[dels[0]]

        dels = set(id(o) for o in deleted)
        rem = []
        for i, op in enumerate(self._ops):
            if id(op) in dels:
                rem.append(i)
        if len(rem) != len(deleted):
            raise RuntimeError(  # pragma: no cover
                f"Mismatched length {rem!r}, {dels!r}, len={len(deleted)!r}.")
        for i in reversed(rem):
            del self._ops[i]
        # NOTE(review): 'last_add_op' looks like a typo for
        # 'last_added_op' (the attribute set in __init__); this line
        # creates a new attribute instead of resetting the existing
        # one — confirm before fixing.
        self.last_add_op = None

        # insertion
        if added is not None:
            # Insert the replacement where the first deleted node was
            # and rewire every consumer of the deleted sequence to it.
            self._ops.insert(rem[0], added)
            self._nodes[id(added)] = added
            for op in forward[key]:
                new_inputs = list(op.inputs)
                for i in range(len(op.inputs)):  # pylint: disable=C0200
                    if id(op.inputs[i]) == key:
                        new_inputs[i] = added
                op.inputs = tuple(new_inputs)
            if mark_input is not None:
                self.mark(mark_input, added)
        else:
            # No replacement: bypass the sequence, its consumers read
            # the single input of the first deleted node directly.
            inps = deleted[0].inputs
            if len(inps) != 1:
                raise RuntimeError(  # pragma: no cover
                    "More than one input. Call another method.")
            inp = inps[0]
            for op in forward[key]:
                new_inputs = list(op.inputs)
                for i in range(len(op.inputs)):  # pylint: disable=C0200
                    if id(op.inputs[i]) == key:
                        new_inputs[i] = inp
                op.inputs = tuple(new_inputs)
            if mark_input is not None:
                self.mark(mark_input, inp)

    def remove_duplicate_transpose(self, verbose=False):
        """
        Removes consecutive transpose by merging them.

        :param verbose: display intermediate information
        """
        modif = 1
        while modif > 0:
            modif = 0
            candidates = []
            forward = self._get_forward_nodes()
            for op in self:
                if op.name == "transpose":
                    inp = op.inputs[0]
                    # Candidate: a transpose whose only consumer is
                    # another transpose (len(forward[...]) == 1 makes
                    # the merge safe: nobody else reads the inner one).
                    if (isinstance(inp, EinsumSubOp) and
                            inp.name == 'transpose' and
                            len(forward[id(inp)]) == 1):
                        candidates.append(op)

            if len(candidates) > 0:
                modif = 1
                # Not efficient to take the first one and to
                # start again but the graph should not be too big.
                cand = candidates[0]
                op2 = cand
                op1 = cand.inputs[0]
                perm1 = op1.kwargs['perm']
                perm2 = op2.kwargs['perm']
                if len(perm1) != len(perm2):
                    raise RuntimeError(  # pragma: no cover
                        "Transposition should have the same length "
                        "%r, %r." % (perm1, perm2))
                # Compose the permutations: op1 is applied first, then
                # op2, so axis i of the merged result comes from axis
                # perm1[perm2[i]] of the original input.
                perm = list(perm1)
                for i in range(len(perm)):  # pylint: disable=C0200
                    perm[i] = perm1[perm2[i]]
                if list(range(len(perm))) == perm:
                    # identity, everything needs to be removed
                    new_op = None
                else:
                    # A single transpose taking op1's input directly.
                    new_op = op2.__class__(
                        op2.full_dim, op2.name, op1.inputs[0],
                        perm=tuple(perm))
                self._replace_node_sequence(new_op, [op1, op2])
                if verbose:
                    print(  # pragma: no cover
                        "[GraphEinsumSubOp.remove_duplicate_transpose] remove nodes %r"
                        " - id=%d,%d + %d perm1=%r perm2=%r -> perm=%r" % (
                            op2.name, id(op1), id(op2),
                            id(new_op) if new_op is not None else -1,
                            perm1, perm2, perm))

    def to_onnx(self, output, *inputs, dtype=None, verbose=False,
                opset=None, **kwargs):
        """
        Converts the graph into ONNX.

        :param output: output name
        :param inputs: input names
        :param dtype: type used for all operators
        :param opset: desired opset, None for the last one
        :param verbose: display intermediate operators
        :param kwargs: additional parameter to use when building
            the ONNX graph, list of supported parameters:
            *name*, *ir_version*, *producer_name*,
            *producer_version*, *initializer*
        :return: ONNX graph

        Not all graphs can be converted into ONNX. Only graphs produced
        with `strategy='numpy'` can be converted otherwise the following
        error shows up:

        ::

            NotImplementedError: to_onnx not implemented for 'matmul'.
        """
        from ...onnx_tools.optim import onnx_remove_node_unused

        # inputs
        if opset is None:
            opset = __max_supported_opset__
        if verbose:
            print("[GraphEinsumSubOp.to_onnx] %r -> %s opset=%r "
                  "dtype=%r" % (inputs, output, opset, dtype))
        onx_inputs = []
        proto = guess_proto_dtype(
            numpy.float32 if dtype is None else dtype)
        lengths = self.metadata['lengths']
        names = {}
        for inp, le in zip(inputs, lengths):
            if isinstance(inp, tuple):
                # (name, typed variable): rank must match the equation.
                name, typ = inp
                if le != len(typ.shape):
                    raise ValueError(  # pragma: no cover
                        "Irreconcialable shapes for input %r: "
                        "%r != len(%r)." % (name, le, typ.shape))
                # NOTE(review): *proto* is overwritten here, so the
                # output below reuses the proto of the last tuple-typed
                # input — confirm this is intended with mixed dtypes.
                proto = guess_proto_dtype(guess_numpy_type(typ))
                onx_inputs.append(
                    helper.make_tensor_value_info(name, proto, typ.shape))
                names[len(names)] = name
            else:
                # Plain name: only the rank is known, dimensions stay free.
                onx_inputs.append(
                    helper.make_tensor_value_info(
                        inp, proto, [None for i in range(le)]))
                names[len(names)] = inp

        # output
        onx_output = helper.make_tensor_value_info(
            output, proto, [None for i in range(lengths[-1])])

        # nodes
        nodes = []
        inits = []
        if "initializer" in kwargs:
            inits.extend(kwargs['initializer'])
        for op in self:
            for onx_node in op.to_onnx(names, verbose=verbose, opset=opset):
                if hasattr(onx_node, 'output'):
                    nodes.append(onx_node)
                else:
                    # Objects without 'output' are initializers
                    # (TensorProto), collected separately.
                    inits.append(onx_node)

        # last node
        # Renames the last computed result into the requested output name.
        last_node = nodes[-1]
        nodes.append(helper.make_node(
            'Identity', [last_node.output[0]], [output]))

        # Builds the graph
        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid('', opset)],
            ir_version=kwargs.get('ir_version', get_ir_version(opset)),
            producer_name=kwargs.get('producer_name', 'mlprodict'),
            producer_version=kwargs.get('producer_version', "0.0.dev"),
            graph=helper.make_graph(
                name=kwargs.get('name', 'einsum'),
                inputs=onx_inputs, outputs=[onx_output],
                initializer=inits, nodes=nodes))

        # Drops nodes whose outputs are never consumed.
        return onnx_remove_node_unused(model)