Coverage for mlprodict/testing/einsum/einsum_bench.py: 100%

81 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Function to measure the performance of einsum decomposition. 

4""" 

5from itertools import permutations 

6import numpy 

7from onnx import helper, TensorProto 

8from cpyquickhelper.numbers import measure_time 

9from ... import __max_supported_opset__, get_ir_version 

10from ...tools.ort_wrapper import InferenceSession 

11from ...onnxrt import OnnxInference 

12from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence 

13 

14 

def _measure_time(stmt, *x, repeat=5, number=5, div_by_number=True,
                  first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable to measure, it receives the inputs *x* as arguments
    :param x: inputs
    :param repeat: average over *repeat* experiments
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameters *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    if first_run:
        try:
            stmt(*x)
        except RuntimeError as e:  # pragma: no cover
            raise RuntimeError(f"{type(x)}-{getattr(x, 'dtype', '?')}") from e

    def fct():
        stmt(*x)

    if first_run:
        fct()

    return measure_time(fct, context={}, repeat=repeat, number=number,
                        div_by_number=div_by_number, max_time=max_time)
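A minimal usage sketch of this helper; the equation and the two random matrices are chosen only for illustration. The statement is a callable, not a string, and the inputs are passed after it.

    import numpy
    x1 = numpy.random.randn(8, 8).astype(numpy.float32)
    x2 = numpy.random.randn(8, 8).astype(numpy.float32)
    # measures numpy.einsum on two 8x8 matrices
    res = _measure_time(lambda a, b: numpy.einsum("ab,bc->ac", a, b),
                        x1, x2, repeat=3, number=3)
    print(res)  # timing dictionary produced by cpyquickhelper's measure_time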

def _make_einsum_model(equation, opset=__max_supported_opset__):
    "Builds an ONNX graph with a single Einsum node computing *equation*."
    inputs = equation.split('->')[0].split(',')

    model = helper.make_model(
        opset_imports=[helper.make_operatorsetid('', opset)],
        ir_version=get_ir_version(opset),
        producer_name='mlprodict',
        producer_version='0.1',
        graph=helper.make_graph(
            name='einsum_test',
            inputs=[
                helper.make_tensor_value_info(
                    "X%d" % i, TensorProto.FLOAT, None)  # pylint: disable=E1101
                for i in range(len(inputs))],
            outputs=[
                helper.make_tensor_value_info(
                    "Y", TensorProto.FLOAT, None)],  # pylint: disable=E1101
            nodes=[
                helper.make_node(
                    "Einsum", ["X%d" % i for i in range(len(inputs))], ["Y"],
                    equation=equation)
            ]
        )
    )
    return model
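A small sketch of what this helper returns; the equation is only an example. The input names follow the `X%d` pattern used above.

    onx = _make_einsum_model("abc,cd->abd")
    print([i.name for i in onx.graph.input])  # ['X0', 'X1']
    print(onx.graph.node[0].op_type)          # 'Einsum'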

def _make_inputs(equation, shapes):
    "Creates random float32 inputs matching the left-hand side terms of *equation*."
    inputs = equation.split('->')[0].split(',')
    dims = [len(i) for i in inputs]

    if isinstance(shapes, int):
        N = shapes
        shapes = [(N, ) * le for le in dims]
    else:
        if len(shapes) != len(inputs):
            raise ValueError(  # pragma: no cover
                f"Unexpected number of shapes {shapes!r} with equation {equation!r}.")
    inputs = [numpy.random.randn(*sh) for sh in shapes]
    return [i.astype(numpy.float32) for i in inputs]
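A quick sketch of the two ways this helper can be called; the equation and sizes are only examples.

    # a single integer: every dimension of every input gets that size
    a, b = _make_inputs("abc,cd->abd", 8)
    print(a.shape, b.shape)  # (8, 8, 8) (8, 8)

    # explicit shapes: one tuple per input term
    a, b = _make_inputs("abc,cd->abd", [(2, 3, 4), (4, 5)])
    print(a.shape, b.shape)  # (2, 3, 4) (4, 5)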

def einsum_benchmark(equation="abc,cd->abd", shape=30, perm=False,
                     runtime='python', use_tqdm=False,
                     number=5, repeat=5, opset=__max_supported_opset__):
    """
    Investigates whether or not decomposing an einsum equation is faster.

    :param equation: einsum equation to test
    :param shape: an integer (all dimensions get the same size),
        a list of integers (one scenario per size) or
        a list of shapes (one tuple per input)
    :param perm: if True, tests every permutation of the equation letters
    :param runtime: 'numpy', 'python' or 'onnxruntime'
    :param use_tqdm: shows a progress bar
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :param opset: target opset
    :return: iterator on dictionaries
    """

    scenarios = []
    if (isinstance(shape, list) and
            all(map(lambda t: isinstance(t, int), shape))):
        shape_list = shape
    else:
        shape_list = [shape]

    if perm:
        if equation.lower() != equation:
            raise ValueError(
                "Only equations with lower letters are allowed but equation %r "
                "is not." % equation)
        letters = list(sorted(set(
            c for c in equation if "a" <= c <= "z" or "A" <= c <= "Z")))
        for p in permutations(letters):
            replace = {d: c for c, d in zip(letters, p)}
            eq = equation
            for k, v in replace.items():
                eq = eq.replace(k, v.upper())
            eq = eq.lower()
            for dec in ['einsum', 'dec']:
                for sh in shape_list:
                    scenarios.append((eq, runtime, dec, sh))
    else:
        for dec in ['einsum', 'dec']:
            for sh in shape_list:
                scenarios.append((equation, runtime, dec, sh))

    if use_tqdm:
        from tqdm import tqdm  # pragma: no cover
        loop = tqdm(scenarios)  # pragma: no cover
    else:
        loop = scenarios

    for eq, rt, dec, sh in loop:
        inputs = _make_inputs(equation, sh)

        # decomposes the equation once when benchmarking the decomposition
        if dec == 'dec':
            seq = decompose_einsum_equation(eq, strategy='numpy', clean=True)
        else:
            seq = None

        # builds the function to benchmark for the selected runtime
        if rt == 'numpy':
            if dec == 'einsum':
                fct = lambda *x, eq=eq: numpy.einsum(eq, *x, optimize=True)
            else:
                fct = lambda *x, seq=seq: apply_einsum_sequence(seq, *x)
        elif rt == 'onnxruntime':
            if dec == 'einsum':
                onx = _make_einsum_model(equation, opset=opset)
            else:
                onx = seq.to_onnx('Y', *["X%d" % i for i in range(len(inputs))],
                                  opset=opset)
            sess = InferenceSession(
                onx.SerializeToString(),
                providers=['CPUExecutionProvider'])  # pylint: disable=W0612
            fct = lambda *x, se=sess: se.run(
                None, {"X%d" % i: v for i, v in enumerate(x)})
        elif rt == 'python':
            if dec == 'einsum':
                onx = _make_einsum_model(equation, opset=opset)
            else:
                onx = seq.to_onnx('Y', *["X%d" % i for i in range(len(inputs))],
                                  opset=opset)
            oinf = OnnxInference(onx)  # pylint: disable=W0612
            fct = lambda *x, oi=oinf: oi.run(
                {"X%d" % i: v for i, v in enumerate(x)})
        else:
            raise ValueError(f"Unexpected runtime {rt!r}.")

        # measures the function and annotates the results
        res = _measure_time(fct, *inputs, repeat=repeat, number=number)
        res['rt'] = rt
        res['dec'] = dec
        res['eq'] = eq
        res['shapes'] = ";".join(
            map(str, [m.shape for m in inputs])).replace(' ', '')
        yield res
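A minimal usage sketch of the benchmark, with the numpy runtime and small sizes so the example stays cheap; the equation, shape and counts are only illustrative. Each yielded dictionary mixes the timing fields from measure_time with the keys added above (rt, dec, eq, shapes).

    results = list(einsum_benchmark(equation="abc,cd->abd", shape=8,
                                    runtime='numpy', number=3, repeat=3))
    for res in results:
        print(res['eq'], res['dec'], res['shapes'])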