Coverage for mlprodict/onnxrt/validate/validate_benchmark_replay.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Measures time processing for ONNX models. 

4""" 

5import pickle 

6import os 

7import sklearn 

8from ...tools.ort_wrapper import InferenceSession 

9from .. import OnnxInference 

10from .validate_helper import default_time_kwargs, measure_time, _multiply_time_kwargs 

11from .validate_benchmark import make_n_rows 

12 

13 

class SimplifiedOnnxInference:
    """
    Thin adapter around InferenceSession exposing the small part of the
    @see cl OnnxInference interface used by the benchmark replay:
    property *input_names* and method *run*.
    It is documented as only enabling *CPUExecutionProvider*
    (that choice is made by the wrapped session, not by this class).

    :param ort: model, forwarded unchanged as the first argument of
        :class:`InferenceSession
        <mlprodict.tools.ort_wrapper.InferenceSession>`
    :param runtime: see :class:`InferenceSession
        <mlprodict.tools.ort_wrapper.InferenceSession>`
    """

    def __init__(self, ort, runtime='onnxruntime'):
        session = InferenceSession(ort, runtime=runtime)
        self.sess = session

    @property
    def input_names(self):
        "Returns InferenceSession input names."
        names = []
        for node in self.sess.get_inputs():
            names.append(node.name)
        return names

    def run(self, input):  # pylint: disable=W0622
        "Calls InferenceSession.run."
        # None as first argument: no explicit list of outputs is requested.
        outputs = self.sess.run(None, input)
        return outputs

34 

35 

def enumerate_benchmark_replay(folder, runtime='python', time_kwargs=None,
                               skip_long_test=True, time_kwargs_fact=None,
                               time_limit=4, verbose=1, fLOG=None):
    """
    Replays a benchmark stored with function
    :func:`enumerate_validated_operator_opsets
    <mlprodict.onnxrt.validate.validate.enumerate_validated_operator_opsets>`
    or command line :ref:`validate_runtime <l-cmd-validate_runtime>`.
    Enumerates the results.

    @param      folder              folder where to find pickled files, all files must have
                                    *pkl* or *pickle* extension, files whose name
                                    contains ``ERROR`` are skipped
    @param      runtime             runtime or runtimes, a list or a
                                    comma-separated string
    @param      time_kwargs         to define a more precise way to measure a model,
                                    *None* or an empty string falls back on
                                    ``default_time_kwargs()``
    @param      skip_long_test      skips tests for high values of N if they seem too long
                                    (accepted for compatibility, this function
                                    does not read it)
    @param      time_kwargs_fact    see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param      time_limit          to skip the rest of the test after this limit (in second)
                                    (accepted for compatibility, this function
                                    does not read it)
    @param      verbose             if >= 1, uses :epkg:`tqdm`, higher values
                                    send more messages to *fLOG*
    @param      fLOG                logging function
    @return                         iterator on results, one dictionary per
                                    replayed file

    The function raises *FileNotFoundError* if the folder contains no
    pickled file and *NotImplementedError* if a model has more than
    one input.

    Files are read with :epkg:`pickle`: unpickling can execute arbitrary
    code, only replay folders coming from a trusted source.
    """
    from onnxruntime.capi._pybind_state import Fail as OrtFail  # pylint: disable=E0611

    # Both extensions announced in the documentation are accepted.
    # The previous filter only matched names ending with "_.pickle".
    files = [_ for _ in os.listdir(folder)
             if _.endswith((".pkl", ".pickle"))]
    if not files:
        raise FileNotFoundError(
            f"Unable to find any file in folder '{folder}'.")

    if time_kwargs in (None, ''):
        time_kwargs = default_time_kwargs()

    if isinstance(runtime, str):
        runtime = runtime.split(",")

    # The progress bar is optional: without tqdm, files are simply iterated.
    loop = files
    if verbose >= 1:
        try:
            from tqdm import tqdm
        except ImportError:  # pragma: no cover
            pass
        else:
            loop = tqdm(files)

    for pkl in loop:
        if "ERROR" in pkl:
            # An error.
            if verbose >= 2 and fLOG is not None:  # pragma: no cover
                fLOG(  # pragma: no cover
                    f"[enumerate_benchmark_replay] skip '{pkl}'.")
            continue  # pragma: no cover
        if verbose >= 2 and fLOG is not None:
            fLOG(f"[enumerate_benchmark_replay] process '{pkl}'.")

        # The file is only needed to load the pickle, it is closed
        # before the measures start.
        # NOTE: pickle.load runs arbitrary code from the file.
        with open(os.path.join(folder, pkl), 'rb') as f:
            obj = pickle.load(f)

        X_test = obj['X_test']
        ort_test = obj['Xort_test']
        onx = obj['onnx_bytes']
        model = obj['skl_model']
        tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model)

        row = {}
        row['folder'] = folder
        row['filename'] = pkl
        # Overwritten below by obj['obs_op']['n_features'].
        row['n_features'] = X_test.shape[1]

        for key in ['assume_finite', 'conv_options',
                    'init_types', 'idtype', 'method_name', 'n_features',
                    'name', 'optim', 'opset', 'predict_kwargs',
                    'output_index', 'problem', 'scenario']:
            row[key] = obj['obs_op'][key]

        # 'bench-batch',
        # 'bench-skl',

        # One inference object per runtime, None if the model cannot
        # be loaded (the message is stored in row['ERROR']).
        oinfs = {}
        for rt in runtime:
            try:
                if rt == 'onnxruntime':
                    oinfs[rt] = SimplifiedOnnxInference(onx)
                else:
                    oinfs[rt] = OnnxInference(
                        onx, runtime=rt, runtime_options=dict(
                            log_severity_level=3))
            except (OrtFail, RuntimeError) as e:  # pragma: no cover
                row['ERROR'] = str(e)
                oinfs[rt] = None

        # The scikit-learn method does not depend on the batch size.
        meth = getattr(model, row['method_name'])

        for k, v in sorted(tkw.items()):
            if verbose >= 3 and fLOG is not None:
                fLOG(  # pragma: no cover
                    f"[enumerate_benchmark_replay] process n_rows={k} - {v}")
            xt = make_n_rows(X_test, k)
            number = v['number']
            repeat = v['repeat']

            # scikit-learn baseline.
            with sklearn.config_context(assume_finite=row['assume_finite']):
                skl = measure_time(lambda x: meth(x), xt,  # pylint: disable=W0108
                                   number=number, repeat=repeat,
                                   div_by_number=True)
            if verbose >= 4 and fLOG is not None:
                fLOG(  # pragma: no cover
                    f"[enumerate_benchmark_replay] skl={skl}")
            row['%d-skl-details' % k] = skl
            row['%d-skl' % k] = skl['average']

            # Same measure for every runtime which could load the model.
            xto = make_n_rows(ort_test, k)
            for rt in runtime:
                oinf = oinfs[rt]
                if oinf is None:
                    continue  # pragma: no cover
                if len(oinf.input_names) != 1:
                    raise NotImplementedError(  # pragma: no cover
                        "This function only allows one input not {}".format(
                            len(oinf.input_names)))
                name = oinf.input_names[0]
                # The lambda is consumed by measure_time before
                # oinf and name are rebound by the next iteration.
                ort = measure_time(lambda x: oinf.run({name: x}), xto,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
                if verbose >= 4 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        f"[enumerate_benchmark_replay] {rt}={ort}")
                row['%d-%s-detail' % (k, rt)] = ort
                row['%d-%s' % (k, rt)] = ort['average']
        yield row