Coverage for mlprodict/testing/einsum/einsum_fct.py: 97%

193 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Main functions decomposing einsum computation into 

4more simple functions. 

5""" 

6from itertools import permutations 

7import time 

8import math 

9import numpy 

10from onnx import helper 

11from ...onnx_tools.onnx2py_helper import guess_proto_dtype 

12from ...onnxrt.onnx_micro_runtime import OnnxMicroRuntime 

13from ... import __max_supported_opset__, get_ir_version 

14from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence 

15from .einsum_ml import predict_transposition_cost 

16 

17 

# Module-level cache mapping keys
# (equation, runtime, opset, optimize, dtype, decompose, strategy)
# to built CachedEinsum instances (filled by _einsum).
_einsum_cache = {}


def enumerate_cached_einsum():
    """
    Enumerates all cached einsum functions.

    :return: iterator on tuples `(key, CachedEinsum instance)`
    """
    # The function only reads the module-level dictionary; the former
    # ``global`` declaration (flagged by pylint W0602) was unnecessary
    # because the name is never rebound here.
    yield from _einsum_cache.items()

28 

29 

class CachedEinsum:
    """
    Stores all the necessary information to cache the preprocessing
    of an einsum equation.

    :param equation: numpy equation
    :param runtime: see :func:`einsum
        <mlprodict.testing.einsum.einsum_fct.einsum>`
    :param opset: ONNX opset
    :param optimize: finds the best letter permutation
    :param dtype: dtype
    :param decompose: to decompose Einsum operator or to keep it as is
    :param key: key used to cache this class
    :param strategy: optimization strategy
    :param verbose: displays progress information

    The class creates the following attributes:

    * `equation_` corresponding to the best equivalent equation
    * `graph_`: the corresponding graph returned by function
        :func:`decompose_einsum_equation
        <mlprodict.testing.einsum.einsum_impl.decompose_einsum_equation> `
    * `onnx_`: if a conversion to onnx is used, stores the onnx graph
    * `runtime_`: a function used by `__call__`, calls the runtime
    """

    def __init__(self, equation, runtime='batch_dot', opset=None,
                 optimize=False, dtype=numpy.float64, decompose=True,
                 strategy=None, verbose=None, key=None):
        # Plain attribute storage; all real work happens in build().
        self.equation = equation
        self.runtime = runtime
        self.opset = opset
        self.optimize = optimize
        self.dtype = dtype
        self.decompose = decompose
        self.strategy = strategy
        self.verbose = verbose
        self.key = key

    def __repr__(self):
        "usual"
        # ``verbose`` is intentionally not part of the representation.
        return "%s(%r, %r, %r, %r, %r, %r, %r, key=%r)" % (
            self.__class__.__name__, self.equation, self.runtime,
            self.opset, self.optimize, self.dtype, self.decompose,
            self.strategy, self.key)

    def default_inputs(self, N=None):
        """
        Returns default inputs (reshaped numpy.arange + 0.7i).

        :param N: dimension (all dimensions have the same size)

        If *N is None*, N is given a size depending on the number of letters
        to avoid spending too much time on optimization.
        """
        if N is None:
            # One permutation per letter is tried during optimization,
            # so shrink N as the number of letters (and permutations) grows.
            letters = set(c for c in self.equation
                          if "a" <= c <= "z" or "A" <= c <= "Z")
            nn = math.factorial(len(letters))
            N = max(int(2 ** 11 / nn), 4)
            N = min(N, 15)
        # One input per term on the left of '->', each of rank len(term)
        # with every dimension equal to N.
        inps = self.equation.split('->')[0].split(',')
        lens = [len(s) for s in inps]
        inputs = [numpy.arange(N ** d).reshape((N,) * d) for d in lens]
        # The 0.7 * ii shift makes every input distinct.
        inputs = [(i + 0.7 * ii).astype(self.dtype)
                  for ii, i in enumerate(inputs)]
        return inputs

    def build(self):
        """
        Preprocesses the equation builds whatever is necessary
        to compute the result of the einsum equation.

        Sets `equation_` (possibly an optimized permutation of the
        letters) then delegates to :meth:`build_runtime`.
        """
        if not self.optimize and not hasattr(self, 'equation_'):
            self.equation_ = self.equation
        elif self.strategy is None:
            # Optimization by measuring actual execution time.
            self.equation_ = self._build_optimize()
        elif self.strategy == 'ml':
            # Optimization by a machine-learned transposition cost model.
            self.equation_ = self._build_optimize_ml()
        else:
            raise ValueError(  # pragma error
                f"Unknown strategy {self.strategy!r}.")
        self.build_runtime()

    def _build_optimize(self):
        # loops over all permutations
        # Tries every permutation of the letters, times each candidate
        # equation and keeps the fastest one.
        if self.equation.lower() != self.equation:
            raise RuntimeError(  # pragma: no cover
                f"Only lower equation can be optimized, {self.equation!r} is not.")
        letters = list(
            sorted(set(c for c in self.equation if "a" <= c <= "z")))
        possible = list(permutations(letters))
        # The identity permutation is evaluated first as the baseline.
        possible.insert(0, letters)
        if self.verbose:
            from tqdm import tqdm  # pragma: no cover
            subset = tqdm(possible)  # pragma: no cover
        else:
            subset = possible
        best = []       # top 10 (time, equation) candidates, sorted
        confs = []      # every measured (time, equation)
        very_best = None
        inputs = None
        for perm in subset:
            # Rewrite the equation with the permuted letters; the
            # upper-case intermediate step avoids double replacement.
            replace = {d: c for c, d in zip(letters, perm)}
            eq = self.equation
            for k, v in replace.items():
                eq = eq.replace(k, v.upper())
            eq = eq.lower()
            inst = CachedEinsum(eq, runtime=self.runtime, opset=self.opset,
                                optimize=False, dtype=self.dtype,
                                decompose=self.decompose)
            inst.build()
            if inputs is None:
                inputs = inst.default_inputs()
            # Warm-up call before timing.
            inst(*inputs)
            ts = time.perf_counter()
            for _ in range(0, 10):
                inst(*inputs)
            delta = time.perf_counter() - ts
            confs.append((delta, eq))
            if len(best) < 10:
                best.append((delta, eq))
                best.sort()
            elif delta < best[-1][0]:
                best[-1] = (delta, eq)
                best.sort()
            if self.verbose and (
                    very_best is None or very_best != best[0][0]):
                # subset is a tqdm progress bar when verbose is set.
                very_best = best[0][0]
                subset.set_description("%1.2g rtbest=%r" % best[0])
        self.optimized_ = best
        self.timed_permutations_ = confs
        return best[0][1]

    def _build_optimize_ml(self):
        # loops over all permutations
        # Same search as _build_optimize but candidates are scored with
        # a predicted transposition cost instead of measured time.
        if self.equation.lower() != self.equation:
            raise RuntimeError(  # pragma: no cover
                f"Only lower equation can be optimized, {self.equation!r} is not.")
        letters = list(
            sorted(set(c for c in self.equation if "a" <= c <= "z")))
        possible = list(permutations(letters))
        possible.insert(0, letters)
        if self.verbose:
            from tqdm import tqdm  # pragma: no cover
            subset = tqdm(possible)  # pragma: no cover
        else:
            subset = possible
        best = []       # top 10 (cost, equation) candidates, sorted
        confs = []      # every scored (cost, equation)
        very_best = None
        inputs = None
        for perm in subset:
            replace = {d: c for c, d in zip(letters, perm)}
            eq = self.equation
            for k, v in replace.items():
                eq = eq.replace(k, v.upper())
            eq = eq.lower()
            inst = CachedEinsum(eq, runtime=self.runtime, opset=self.opset,
                                optimize=False, dtype=self.dtype,
                                decompose=self.decompose)
            inst.build()
            if inputs is None:
                inputs = inst.default_inputs()
            if hasattr(inst, 'onnx_'):
                onx = inst.onnx_
            else:
                # No ONNX graph was built by the runtime: convert the
                # decomposition graph on the fly.
                from skl2onnx.common.data_types import FloatTensorType  # delayed
                inits = [
                    ('X%d' % i, FloatTensorType(list(inputs[i].shape)))
                    for i in range(len(inputs))]
                onx = inst.graph_.to_onnx('Y', *inits, opset=self.opset)

            # Run once to obtain the intermediate shapes feeding the
            # Transpose nodes.
            rt = OnnxMicroRuntime(onx)
            dict_inputs = {'X%d' % i: inp for i, inp in enumerate(inputs)}
            out = rt.run(dict_inputs)

            transposes = []
            for node in onx.graph.node:  # pylint: disable=E1101
                if node.op_type == 'Transpose':
                    # Dimensions > 1 are scaled by 10 before feeding the
                    # cost model.
                    shape = [(d * 10 if d > 1 else d)
                             for d in out[node.input[0]].shape]
                    transposes.append(
                        [shape, list(node.attribute[0].ints)])

            # Total predicted cost of all transpositions (negative
            # predictions are clipped to 0).
            delta = sum(max(0, predict_transposition_cost(*v))
                        for v in transposes)

            confs.append((delta, eq))
            if len(best) < 10:
                best.append((delta, eq))
                best.sort()
            elif delta < best[-1][0]:
                best[-1] = (delta, eq)
                best.sort()
            if self.verbose and (
                    very_best is None or very_best != best[0][0]):
                very_best = best[0][0]
                subset.set_description("%1.2g mlbest=%r" % best[0])
        self.optimized_ = best
        self.timed_permutations_ = confs
        return best[0][1]

    def build_onnx_einsum(self, input_names):
        """
        Builds an ONNX graph with a single einsum operator.

        :param input_names: names of the graph inputs
        :return: ONNX ModelProto with one Einsum node producing output `Y`
        """
        opset = (self.opset if self.opset is not None
                 else __max_supported_opset__)
        ir_version = get_ir_version(opset)
        # dtype defaults to float32 when unspecified.
        proto_type = guess_proto_dtype(
            numpy.float32 if self.dtype is None else self.dtype)

        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid('', opset)],
            ir_version=ir_version,
            producer_name='mlprodict',
            producer_version='0.0.1',
            graph=helper.make_graph(
                name='einsum',
                inputs=[helper.make_tensor_value_info(n, proto_type, None)
                        for n in input_names],
                outputs=[helper.make_tensor_value_info("Y", proto_type, None)],
                nodes=[
                    helper.make_node(
                        'Einsum', input_names, ["Y"], equation=self.equation_)]))
        return model

    def build_runtime(self):
        """
        Builds the runtime associated to the
        equation `self.equation_`.

        Sets `runtime_` and, depending on `decompose` and `runtime`,
        `graph_`, `onnx_`, `onnx_names_`, `oinf_`.
        """
        if self.decompose:
            # Decompose Einsum into simpler ONNX operators.
            self.graph_ = decompose_einsum_equation(
                self.equation_, strategy='numpy', clean=True)
            if self.runtime == 'batch_dot':
                self.runtime_ = lambda *inputs: apply_einsum_sequence(
                    self.graph_, *inputs)
            elif self.runtime in ('python', 'onnxruntime1'):
                from ...onnxrt import OnnxInference
                # Number of graph inputs is stored in metadata.
                n_inputs = len(self.graph_.metadata['lengths']) - 1
                input_names = ['X%d' % i for i in range(n_inputs)]
                self.onnx_names_ = input_names
                onx = self.graph_.to_onnx(
                    'Y', *input_names, opset=self.opset, dtype=self.dtype)
                self.onnx_ = onx
                rt = ('python_compiled'
                      if self.runtime == 'python'
                      else self.runtime)
                self.oinf_ = OnnxInference(
                    self.onnx_, runtime=rt, runtime_options=dict(
                        log_severity_level=3))
                self.runtime_ = lambda *inputs: self.oinf_.run(
                    {i: v for i, v in zip(self.onnx_names_, inputs)})['Y']
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected runtime {self.runtime!r}.")
        else:
            # Keep a single Einsum operator ('batch_dot' is not available
            # in this mode).
            if self.runtime in ('python', 'onnxruntime1'):
                from ...onnxrt import OnnxInference
                n_inputs = len(self.equation.split('->')[0].split(','))
                input_names = ['X%d' % i for i in range(n_inputs)]
                self.onnx_ = self.build_onnx_einsum(input_names)
                self.onnx_names_ = input_names
                rt = ('python_compiled'
                      if self.runtime == 'python'
                      else self.runtime)
                self.oinf_ = OnnxInference(
                    self.onnx_, runtime=rt, runtime_options=dict(
                        log_severity_level=3))
                self.runtime_ = lambda *inputs: self.oinf_.run(
                    {i: v for i, v in zip(self.onnx_names_, inputs)})['Y']
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected runtime {self.runtime!r}.")

    def __call__(self, *inputs):
        """
        Calls the runtime `self.runtime_`.

        :param inputs: numpy arrays, one per term of the equation
        :return: einsum result
        :raises RuntimeError: if :meth:`build_runtime` was not called
        """
        if not hasattr(self, 'runtime_'):
            raise RuntimeError(  # pragma: no cover
                "Method build_runtime was not called.")
        return self.runtime_(*inputs)

    @staticmethod
    def build_einsum(equation, runtime, opset, optimize,
                     dtype, decompose=True, strategy=None,
                     verbose=None, key=None):
        """
        Creates an instance of *CachedEinsum*.

        Convenience constructor: instantiates the class and calls
        :meth:`build` before returning the ready-to-use instance.
        """
        inst = CachedEinsum(equation, runtime=runtime, opset=opset,
                            optimize=optimize, dtype=dtype,
                            decompose=decompose, strategy=strategy,
                            verbose=verbose, key=key)
        inst.build()
        return inst

329 

330 

def _einsum(equation, dtype, optimize=False, runtime="batch_dot",
            cache=True, opset=None, decompose=True, strategy=None,
            verbose=None):
    """
    Returns a built :class:`CachedEinsum` for *equation*, reusing the
    module-level cache when *cache* is True.

    :param equation: einsum equation
    :param dtype: dtype of the inputs
    :param optimize: look for the best permutation of letters
    :param runtime: runtime used to execute the graph
    :param cache: reuse and store instances in `_einsum_cache`
    :param opset: ONNX opset
    :param decompose: decompose the Einsum operator or keep it as is
    :param strategy: optimisation strategy
    :param verbose: display progress information
    :return: a :class:`CachedEinsum` instance
    """
    global _einsum_cache  # pylint: disable=W0603,W0602
    # The cache key mirrors every parameter影响ing the built object
    # except verbose; it stays None when caching is disabled.
    key = ((equation, runtime, opset, optimize, dtype, decompose, strategy)
           if cache else None)
    if cache:
        hit = _einsum_cache.get(key, None)
        if hit is not None:
            # Cache hit: nothing to build, nothing to store again.
            return hit
    built = CachedEinsum.build_einsum(
        equation, runtime, opset, optimize,
        dtype, decompose=decompose, strategy=strategy,
        verbose=verbose, key=key)
    if cache:
        _einsum_cache[key] = built
    return built

351 

352 

def optimize_decompose_einsum_equation(
        equation, dtype, optimize=False, runtime="batch_dot",
        cache=True, opset=None, decompose=True, strategy=None,
        verbose=None):
    """
    Proposes a new implementation of :epkg:`numpy:einsum`.
    It does not allow expression using `...` and expects
    a right member.

    :param equation: einsum equation
    :param dtype: dtype of the inputs used when optimizing
    :param optimize: permutes all letters to find the best
        permutation
    :param runtime: runtime used to compute the results once the
        computation graph is produced (see below)
    :param cache: if True, the function stores the preprocessing
        done for a specific equation, the second call with the same
        equation is much faster
    :param opset: ONNX opset to use for some runtimes
    :param decompose: by default, the function decomposes
        the equation into more simple operators but it can keep
        the original ONNX einsum operator.
    :param strategy: optimisation strategy (see below)
    :param verbose: display progress if optimize is True
    :return: einsum result

    The available runtimes are:

    * `batch_dot`: the runtime is @see fn apply_einsum_sequence,
    * `python`: one ONNX graph executed with a python runtime,
    * `onnxruntime1`: one ONNX graph executed with :epkg:`onnxruntime`.

    The optimisation strategy can be:

    * `None`: the same runtime is used to find the best permutation of letters
    * `'ml'`: a machine learned model is used to predict the
      best permutation of letters, this model comes from
      notebook :ref:`onnxoperatorcostrst`.

    The function works in two steps:

    * first step analyses the equation to produce a computation graph,
      this graph can also be converted into ONNX,
    * second step runs the graph whatever the graph is.

    The function returns an object of type @see cl CachedEinsum
    which has the following members after optimization:

    * `equation_` corresponding to the best equivalent equation
    * `graph_`: the corresponding graph returned by function
      :func:`decompose_einsum_equation
      <mlprodict.testing.einsum.einsum_impl.decompose_einsum_equation> `
    * `onnx_`: if a conversion to onnx is used, stores the onnx graph
    * `runtime_`: a function used by `__call__`, calls the runtime
    * `oinf_`: an object of type @see cl OnnxInference
    * `timed_permutations_`: memorizes the results of the optimization

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import optimize_decompose_einsum_equation

        seq_opt = optimize_decompose_einsum_equation(
            "bsnh,btnh->bnts", numpy.float64, strategy='ml', verbose=1,
            runtime="python", optimize=True)

        print("best equation:", seq_opt.equation_)

    """
    # Thin public wrapper: all the caching/building logic lives in _einsum.
    return _einsum(equation, dtype, optimize=optimize, runtime=runtime,
                   cache=cache, opset=opset, decompose=decompose,
                   strategy=strategy, verbose=verbose)

426 

427 

def einsum(equation, *inputs, optimize=False, runtime="batch_dot",
           cache=True, opset=None, decompose=True,
           strategy=None, verbose=None):
    """
    Proposes a new implementation of :epkg:`numpy:einsum`.
    It does not allow expression using `...` and expects
    a right member.

    :param equation: einsum equation
    :param inputs: inputs
    :param optimize: permutes all letters to find the best
        permutation
    :param runtime: runtime used to compute the results once the
        computation graph is produced (see below)
    :param cache: if True, the function stores the preprocessing
        done for a specific equation, the second call with the same
        equation is much faster
    :param opset: ONNX opset to use for some runtimes
    :param decompose: by default, the function decomposes
        the equation into more simple operators but it can keep
        the original ONNX einsum operator.
    :param strategy: optimisation strategy (see below)
    :param verbose: display progress if optimize is True
    :return: einsum result

    The available runtimes are:

    * `batch_dot`: the runtime is @see fn apply_einsum_sequence,
    * `python`: one ONNX graph executed with a python runtime,
    * `onnxruntime1`: one ONNX graph executed with :epkg:`onnxruntime`.

    The optimisation strategy can be:

    * `None`: the same runtime is used to find the best permutation of letters
    * `'ml'`: a machine learned model is used to predict the
      best permutation of letters, this model comes from
      notebook :ref:`onnxoperatorcostrst`.

    The function works in two steps:

    * first step analyses the equation to produce a computation graph,
      this graph can also be converted into ONNX,
    * second step runs the graph whatever the graph is.

    Further details are available in the documentation of function
    @see fn optimize_decompose_einsum_equation.
    The function works the same way as :epkg:`numpy:einsum`:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import einsum

        equation = "abc,cd->abd"

        m1 = numpy.random.randn(2, 2, 2)
        m2 = numpy.random.randn(2, 2)

        np = numpy.einsum(equation, m1, m2)
        print('numpy.einsum')
        print(np)

        print('mlprodict.testing.einsum')
        mp = einsum(equation, m1, m2)
        print(mp)

    In some case, the einsum implementation can be optimized by looping
    on possible permutation:

    .. runpython::
        :showcode:
        :process:

        import timeit
        import numpy
        from mlprodict.testing.einsum import einsum
        from mlprodict.testing.einsum.einsum_fct import enumerate_cached_einsum

        equation = "cab,cd->ad"

        m1 = numpy.random.randn(20, 20, 20)
        m2 = numpy.random.randn(20, 20)

        print('numpy.einsum',
              timeit.timeit('numpy.einsum(equation, m1, m2)',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2)
        print('einsum',
              timeit.timeit('einsum(equation, m1, m2)',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='python')
        print('einsum-python',
              timeit.timeit('einsum(equation, m1, m2, runtime="python")',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='onnxruntime1')
        print('einsum-onnxruntime1',
              timeit.timeit('einsum(equation, m1, m2, runtime="onnxruntime1")',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='onnxruntime1', optimize=True, verbose=1)
        print('einsum-onnxruntime1',
              timeit.timeit('einsum(equation, m1, m2, runtime="onnxruntime1", optimize=True)',
                            number=200,
                            globals=globals()))

        print("list of cached einsum equations")
        for k, v in enumerate_cached_einsum():
            print(k, v.equation, v.equation_)

    The last example shows the time taken by every function:

    .. runpython::
        :showcode:
        :process:

        import os
        from pyquickhelper.pycode.profiling import profile
        import numpy
        from mlprodict.testing.einsum import einsum
        from mlprodict.testing.einsum.einsum_fct import enumerate_cached_einsum
        from mlprodict import __file__ as path

        root = os.path.dirname(path)

        equation = "cab,cd->ad"

        m1 = numpy.random.randn(200, 20, 20)
        m2 = numpy.random.randn(200, 20)

        def clean(txt):
            txt = txt.replace(root, "mlprodict")
            return "\\n".join(txt.split("\\n")[:30])

        def fct1():
            for i in range(100):
                einsum(equation, m1, m2, cache=False)

        print("Profile cache with default runtime.")
        res = profile(fct1)
        print(root)
        print(clean(res[1]))

        def fct2():
            for i in range(100):
                einsum(equation, m1, m2, cache=False, runtime='python')

        print("Profile cache with runtime='python'.")
        res = profile(fct2)
        print(root)
        print(clean(res[1]))

        def fct3():
            for i in range(100):
                einsum(equation, m1, m2, cache=True)

        einsum(equation, m1, m2, cache=True)
        print("Profile execution with default runtime.")
        res = profile(fct3)
        print(root)
        print(clean(res[1]))

        def fct4():
            for i in range(100):
                einsum(equation, m1, m2, cache=True, runtime='python')

        einsum(equation, m1, m2, cache=True, runtime='python')
        print("Profile execution with runtime='python'.")
        res = profile(fct4)
        print(root)
        print(clean(res[1]))

        def fct5():
            for i in range(100):
                einsum(equation, m1, m2, cache=True, runtime='onnxruntime1')

        einsum(equation, m1, m2, cache=True, runtime='onnxruntime1')
        print("Profile execution with runtime='onnxruntime1'.")
        res = profile(fct5)
        print(root)
        print(clean(res[1]))
    """
    if not inputs:
        raise ValueError("No inputs found.")  # pragma: no cover
    # Every input must share one dtype: the preprocessing is cached per dtype.
    dtypes = {inp.dtype for inp in inputs}
    if len(dtypes) != 1:
        raise ValueError(  # pragma: no cover
            "All inputs do not have the same type (%r), "
            "all of them should be cast before called einsum."
            "" % dtypes)
    # Retrieve (or build) the cached runtime, then execute it.
    runner = optimize_decompose_einsum_equation(
        equation, inputs[0].dtype, optimize=optimize,
        runtime=runtime, cache=cache, opset=opset,
        decompose=decompose, strategy=strategy, verbose=verbose)
    return runner(*inputs)