Coverage for mlprodict/cli/einsum.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Command line to check einsum scenarios. 

4""" 

5import os 

6 

7 

def _parse_shape(shape):
    """
    Converts the *shape* command line argument into the value
    given to `einsum_benchmark`.

    :param shape: a string such as `'5'`, `'5,6'`, `'(5,5,5);(5,5)'`
        or `'(5,5,5),(5,5)'`; any non-string value is considered
        as already parsed and returned unchanged
    :return: an integer, a list of integers or a list of tuples
    :raises ValueError: if a dimension cannot be converted into an integer
    """
    if not isinstance(shape, str):
        # The function was called from Python with an already parsed shape.
        return shape
    if "(" not in shape:
        if "," in shape:
            return list(map(int, shape.split(",")))
        return int(shape)
    import re  # pylint: disable=C0415
    # The documented example `(5,5,5),(5,5)` separates tuples with a comma
    # whereas the parser splits on `;`: normalize `),(` into `);(`.
    shape = re.sub(r"\)\s*,\s*\(", ");(", shape)
    shapes = shape.replace('(', '').replace(')', '').split(";")
    return [tuple(map(int, sh.split(','))) for sh in shapes]


def einsum_test(equation="abc,cd->abd", shape="30", perm=False,
                runtime='python', verbose=1, fLOG=print,
                output=None, number=5, repeat=5):
    """
    Investigates whether or not the decomposing einsum is faster.

    :param equation: einsum equation to test
    :param shape: an integer (all dimensions get the same size),
        a list of shapes in a string separated with `;` or `,`,
        or a list of integers to try out multiple shapes,
        example: `5`, `(5,5,5);(5,5)`, `(5,5,5),(5,5)`, `5,6`
    :param perm: check all letter permutations of the equation,
        `True`, `1`, `'True'` or `'1'` enable it
    :param runtime: `'numpy'`, `'python'`, `'onnxruntime'`
    :param verbose: verbosity, a progress bar is displayed if > 0
    :param fLOG: logging function
    :param output: output file (a csv file or an excel file
        with extension `.xlsx`), it requires pandas; if empty,
        results are sent to *fLOG*
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :raises ValueError: if the extension of *output* is not supported
        or if *shape* cannot be parsed

    .. cmdref::
        :title: Investigates whether or not the decomposing einsum is faster.
        :cmd: -m mlprodict einsum_test --help
        :lid: l-cmd-einsum_test

        The command checks whether or not decomposing an einsum function
        is faster than einsum implementation.

        Example::

            python -m mlprodict einsum_test --equation="abc,cd->abd" --output=res.csv
    """
    perm = perm in ('True', '1', 1, True)
    shape = _parse_shape(shape)
    verbose = int(verbose)
    number = int(number)
    repeat = int(repeat)

    # Validates the output file name before running the benchmark,
    # which may take a long time.
    out_ext = None
    if output not in ('', None):
        ext = os.path.splitext(output)[-1]
        out_ext = ext.lower()
        if out_ext not in ('.csv', '.xlsx'):
            raise ValueError(  # pragma: no cover
                f"Unknown extension {ext!r} in file {output!r}.")

    from ..testing.einsum.einsum_bench import einsum_benchmark  # pylint: disable=E0402

    res = einsum_benchmark(equation=equation, shape=shape, perm=perm,
                           runtime=runtime, use_tqdm=verbose > 0,
                           number=number, repeat=repeat)
    if out_ext is None:
        for r in res:  # pragma: no cover
            fLOG(r)
        return

    import pandas
    df = pandas.DataFrame(res)
    if out_ext == '.csv':
        df.to_csv(output, index=False)
    else:
        df.to_excel(output, index=False)
    fLOG(f'[einsum_test] wrote file {output!r}.')