Coverage for mlprodict/testing/einsum/einsum_bench.py: 100%
81 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Function to measure the performance of einsum decomposition.
4"""
5from itertools import permutations
6import numpy
7from onnx import helper, TensorProto
8from cpyquickhelper.numbers import measure_time
9from ... import __max_supported_opset__, get_ir_version
10from ...tools.ort_wrapper import InferenceSession
11from ...onnxrt import OnnxInference
12from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence
def _measure_time(stmt, *x, repeat=5, number=5, div_by_number=True,
                  first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable to measure, it is called as ``stmt(*x)``
    :param *x: inputs
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary
    :raises RuntimeError: if the warm-up run fails, the message lists
        the type and dtype of every input, the original error is chained

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    def fct():
        stmt(*x)

    if first_run:
        # A single warm-up run; a failure is reported with a description
        # of each input rather than of the argument tuple itself.
        try:
            fct()
        except RuntimeError as e:  # pragma: no cover
            desc = ", ".join(
                f"{type(v).__name__}-{getattr(v, 'dtype', '?')}" for v in x)
            raise RuntimeError(
                f"Unable to run the statement on inputs ({desc}).") from e

    return measure_time(fct, context={}, repeat=repeat, number=number,
                        div_by_number=div_by_number, max_time=max_time)
def _make_einsum_model(equation, opset=__max_supported_opset__):
    """
    Builds an ONNX model made of a single Einsum node.

    Each comma separated term before ``->`` gives one float input,
    named ``X0``, ``X1``, ..., and the node output is named ``Y``.

    :param equation: einsum equation stored as the node attribute
    :param opset: opset declared for the default domain
    :return: the model built by `onnx.helper.make_model`
    """
    left_side = equation.split('->')[0]
    input_names = ["X%d" % k for k in range(len(left_side.split(',')))]

    graph_inputs = [
        helper.make_tensor_value_info(
            name, TensorProto.FLOAT, None)  # pylint: disable=E1101
        for name in input_names]
    graph_outputs = [
        helper.make_tensor_value_info(
            "Y", TensorProto.FLOAT, None)]  # pylint: disable=E1101
    einsum_node = helper.make_node(
        "Einsum", input_names, ["Y"], equation=equation)

    graph = helper.make_graph(
        name='einsum_test', inputs=graph_inputs, outputs=graph_outputs,
        nodes=[einsum_node])
    return helper.make_model(
        opset_imports=[helper.make_operatorsetid('', opset)],
        ir_version=get_ir_version(opset),
        producer_name='mlprodict',
        producer_version='0.1',
        graph=graph)
def _make_inputs(equation, shapes):
    """
    Creates random float32 inputs for an einsum equation.

    :param equation: einsum equation, only the part before ``->`` is used,
        every term separated by ``,`` is one input
    :param shapes: an integer, every input then has as many dimensions
        as its term has letters and every dimension has this size,
        or a sequence of shapes, one per input
    :return: list of float32 arrays drawn with `numpy.random.randn`
    :raises ValueError: if the number of shapes does not match the number
        of inputs, or if a shape rank does not match its term
    """
    # Surrounding spaces are not part of a term and must not count
    # as a dimension.
    terms = [term.strip() for term in equation.split('->')[0].split(',')]

    if isinstance(shapes, (int, numpy.integer)):
        shapes = [(shapes, ) * len(term) for term in terms]
    else:
        if len(shapes) != len(terms):
            raise ValueError(  # pragma: no cover
                f"Unexpected number of shapes {shapes!r} with equation {equation!r}.")
        for term, sh in zip(terms, shapes):
            # A term with an ellipsis accepts any rank, it is not checked.
            if '.' not in term and len(sh) != len(term):
                raise ValueError(
                    f"Shape {sh!r} does not match term {term!r} "
                    f"in equation {equation!r}.")

    return [numpy.random.randn(*sh).astype(numpy.float32) for sh in shapes]
def _einsum_scenarios(equation, shape_list, runtime, perm):
    """
    Lists every ``(equation, runtime, decomposition, shape)`` to measure.

    :param equation: einsum equation
    :param shape_list: list of shape specifications, each one is given
        to `_make_inputs`
    :param runtime: runtime name, copied into every scenario
    :param perm: if True, measures one equation per permutation
        of the letters of *equation*
    :return: list of tuples ``(eq, runtime, dec, shape)``,
        *dec* is `'einsum'` or `'dec'`
    :raises ValueError: if *perm* is True and *equation* is not lower case
    """
    if perm:
        if equation.lower() != equation:
            raise ValueError(
                "Only equations with lower letters are allowed but equation %r "
                "is not." % equation)
        # Inclusive bound: 'z' is a letter to permute as well.
        letters = sorted(set(c for c in equation if "a" <= c <= "z"))
        equations = []
        for p in permutations(letters):
            eq = equation
            # Replacing with upper case letters prevents a letter already
            # renamed from being renamed again, lower() restores the case.
            for src, dst in zip(p, letters):
                eq = eq.replace(src, dst.upper())
            equations.append(eq.lower())
    else:
        equations = [equation]

    return [(eq, runtime, dec, sh)
            for eq in equations
            for dec in ['einsum', 'dec']
            for sh in shape_list]


def _einsum_callable(eq, rt, dec, n_inputs, opset):
    """
    Builds the function measured for one scenario.

    :param eq: einsum equation
    :param rt: runtime, `'numpy'`, `'python'` or `'onnxruntime'`
    :param dec: `'einsum'` to run the equation directly,
        `'dec'` to run its decomposition
    :param n_inputs: number of inputs of the equation
    :param opset: target opset for the ONNX based runtimes
    :return: function taking the inputs as positional arguments
    :raises ValueError: if *rt* is not a known runtime
    """
    seq = (decompose_einsum_equation(eq, strategy='numpy', clean=True)
           if dec == 'dec' else None)

    if rt == 'numpy':
        if dec == 'einsum':
            return lambda *x: numpy.einsum(eq, *x, optimize=True)
        return lambda *x: apply_einsum_sequence(seq, *x)

    if rt not in ('onnxruntime', 'python'):
        raise ValueError(f"Unexpected runtime {rt!r}.")

    # Input names are computed once, outside the measured function.
    names = ["X%d" % i for i in range(n_inputs)]
    if dec == 'einsum':
        # Built from eq, the permuted equation when permutations are
        # measured, so that both variants run the same equation.
        onx = _make_einsum_model(eq, opset=opset)
    else:
        onx = seq.to_onnx('Y', *names, opset=opset)

    if rt == 'onnxruntime':
        sess = InferenceSession(
            onx.SerializeToString(), providers=['CPUExecutionProvider'])
        return lambda *x: sess.run(None, dict(zip(names, x)))
    oinf = OnnxInference(onx)
    return lambda *x: oinf.run(dict(zip(names, x)))


def einsum_benchmark(equation="abc,cd->abd", shape=30, perm=False,
                     runtime='python', use_tqdm=False,
                     number=5, repeat=5, opset=__max_supported_opset__):
    """
    Investigates whether or not decomposing einsum is faster.

    :param equation: einsum equation to test
    :param shape: an integer (every dimension of every input gets this
        size), a list of integers (one scenario per integer), or a list
        of shapes, one per input of the equation
    :param perm: if True, also measures every permutation of the letters
        of the equation (lower case equations only)
    :param runtime: numpy, python, onnxruntime
    :param use_tqdm: show progress
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :param opset: target opset
    :return: iterator on dictionaries, each one is the result of
        `_measure_time` extended with keys `rt`, `dec`, `eq`, `shapes`
    :raises ValueError: if the runtime is unknown or if *perm* is True
        and the equation is not lower case (raised during the iteration)
    """
    if isinstance(shape, list) and all(isinstance(t, int) for t in shape):
        # A list of integers means one scenario per integer.
        shape_list = shape
    else:
        shape_list = [shape]

    scenarios = _einsum_scenarios(equation, shape_list, runtime, perm)

    if use_tqdm:
        from tqdm import tqdm  # pragma: no cover
        loop = tqdm(scenarios)  # pragma: no cover
    else:
        loop = scenarios

    for eq, rt, dec, sh in loop:
        inputs = _make_inputs(eq, sh)
        fct = _einsum_callable(eq, rt, dec, len(inputs), opset)
        res = _measure_time(fct, *inputs, repeat=repeat, number=number)
        res['rt'] = rt
        res['dec'] = dec
        res['eq'] = eq
        res['shapes'] = ";".join(
            map(str, [m.shape for m in inputs])).replace(' ', '')
        yield res