Coverage for src/pymlbenchmark/external/onnxruntime_perf.py: 99%

99 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-08 00:27 +0100

1""" 

2@file 

3@brief Implements a benchmark about performance for :epkg:`onnxruntime` 

4""" 

5import contextlib 

6from collections import OrderedDict 

7from io import BytesIO, StringIO 

8import numpy 

9from numpy.testing import assert_almost_equal 

10import pandas 

11from sklearn.ensemble._forest import BaseForest 

12from sklearn.tree._classes import BaseDecisionTree 

13from mlprodict.onnxrt import OnnxInference 

14from mlprodict import __max_supported_opset__, get_ir_version 

15from ..benchmark import BenchPerfTest 

16from ..benchmark.sklearn_helper import get_nb_skl_base_estimators 

17 

18 

class OnnxRuntimeBenchPerfTest(BenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.

    The constructor trains the estimator on a random dataset,
    converts it into :epkg:`ONNX` and creates one
    :epkg:`OnnxInference` per requested runtime.
    Subclasses must overload method ``_get_random_dataset``.
    """

    def __init__(self, estimator, dim=None, N_fit=100000,
                 runtimes=('python_compiled', 'onnxruntime1'),
                 onnx_options=None, dtype=numpy.float32,
                 **opts):
        """
        @param      estimator       estimator class, instantiated and trained here
        @param      dim             number of features
        @param      N_fit           number of observations to fit an estimator
        @param      runtimes        runtimes to test for class :epkg:`OnnxInference`
        @param      onnx_options    ONNX conversion options
        @param      dtype           dtype (float32 or float64)
        @param      opts            training settings, all of them are given to
                                    the parent constructor, only ``max_depth``
                                    is given to the estimator

        Raises RuntimeError if *dim* is None or if the training fails,
        ValueError if *dtype* is neither float32 nor float64.
        """
        # These libraries are optional.
        from skl2onnx import to_onnx  # pylint: disable=E0401,C0415
        from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType  # pylint: disable=E0401,C0415

        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")

        # Checks the dtype before training: fitting on N_fit observations
        # may be long and is useless if the conversion cannot happen.
        if dtype == numpy.float64:
            tensor_type = DoubleTensorType
        elif dtype == numpy.float32:
            tensor_type = FloatTensorType
        else:
            raise ValueError(  # pragma: no cover
                "Unable to convert the model into ONNX, unsupported dtype {}.".format(dtype))

        BenchPerfTest.__init__(self, **opts)

        # Only a subset of the options is forwarded to the estimator.
        allowed = {"max_depth"}
        opts = {k: v for k, v in opts.items() if k in allowed}
        self.dtype = dtype
        self.skl = estimator(**opts)
        X, y = self._get_random_dataset(N_fit, dim)
        try:
            self.skl.fit(X, y)
        except Exception as e:  # pragma: no cover
            raise RuntimeError("X.shape={}\nopts={}\nTraining failed for {}".format(
                X.shape, opts, self.skl)) from e

        initial_types = [('X', tensor_type([None, X.shape[1]]))]

        # Everything the converter writes on stdout or stderr
        # is captured into self.logconvert.
        self.logconvert = StringIO()
        with contextlib.redirect_stdout(self.logconvert), \
                contextlib.redirect_stderr(self.logconvert):
            onx = to_onnx(self.skl, initial_types=initial_types,
                          options=onnx_options,
                          target_opset=__max_supported_opset__)
        onx.ir_version = get_ir_version(__max_supported_opset__)

        self._init(onx, runtimes)

    def _get_random_dataset(self, N, dim):
        """
        Returns a random dataset, this method must be overloaded.

        @param      N       number of observations
        @param      dim     number of features
        @return             a pair ``(X, y)``, the constructor unpacks
                            it into features and targets
        """
        raise NotImplementedError(  # pragma: no cover
            "This method must be overloaded.")

    def _init(self, onx, runtimes):
        """
        Finalizes the init: serializes the ONNX graph, creates
        one :epkg:`OnnxInference` for every runtime in *runtimes*
        and collects information on both models.

        @param      onx         ONNX model
        @param      runtimes    runtime names given to :epkg:`OnnxInference`
        """
        f = BytesIO()
        f.write(onx.SerializeToString())
        self.ort_onnx = onx
        content = f.getvalue()
        # self.ort: runtime name -> OnnxInference instance
        # self.outputs: runtime name -> output names of that instance
        self.ort = OrderedDict()
        self.outputs = OrderedDict()
        for r in runtimes:
            self.ort[r] = OnnxInference(content, runtime=r)
            self.outputs[r] = self.ort[r].output_names
        self.extract_model_info_skl()
        self.extract_model_info_onnx(ort_size=len(content))

    def extract_model_info_skl(self, **kwargs):
        """
        Populates member ``self.skl_info`` with additional
        information on the model: the number of base estimators,
        the number of nodes for a decision tree (``skl_dt_nodes``)
        or the total number of nodes of a forest (``skl_rf_nodes``).

        @param      kwargs      additional information stored as is
        """
        self.skl_info = dict(
            skl_nb_base_estimators=get_nb_skl_base_estimators(self.skl, fitted=True))
        self.skl_info.update(kwargs)
        if isinstance(self.skl, BaseDecisionTree):
            self.skl_info["skl_dt_nodes"] = self.skl.tree_.node_count
        elif isinstance(self.skl, BaseForest):
            self.skl_info["skl_rf_nodes"] = sum(
                est.tree_.node_count for est in self.skl.estimators_)

    def extract_model_info_onnx(self, **kwargs):
        """
        Populates member ``self.onnx_info`` with additional
        information on the :epkg:`ONNX` graph: the number of nodes
        and the opset used for the conversion.

        @param      kwargs      additional information stored as is
                                (``_init`` adds ``ort_size``, the size
                                of the serialized model in bytes)
        """
        self.onnx_info = {
            'onnx_nodes': len(self.ort_onnx.graph.node),  # pylint: disable=E1101
            'onnx_opset': __max_supported_opset__,
        }
        self.onnx_info.update(kwargs)

    def data(self, N=None, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Generates random features.

        @param      N       number of observations
        @param      dim     number of features
        @return             a one-element sequence holding only the
                            features returned by ``_get_random_dataset``
        """
        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")
        if N is None:
            raise RuntimeError(  # pragma: no cover
                "N must be defined.")
        return self._get_random_dataset(N, dim)[:1]

    def model_info(self, model):
        """
        Returns additional information about a model.

        @param      model       model to describe
        @return                 dictionary with additional descriptors
                                (only the class name of the model)
        """
        res = dict(type_name=model.__class__.__name__)
        return res

    def validate(self, results, **kwargs):
        """
        Checks that methods *predict* and *predict_proba* return
        the same results for both :epkg:`scikit-learn` and
        :epkg:`onnxruntime`.

        @param      results     iterable of triplets ``(idt, fct, vals)``,
                                *fct* is a dictionary with key ``'lib'``
                                and optionally ``'method'``, *vals* holds
                                the predictions (a list is converted into
                                an array with :epkg:`pandas`)
        @param      kwargs      additional information given to
                                ``self.dump_error`` (presumably inherited
                                from the parent class) on discrepancies

        Predictions are grouped by ``(idt, method)`` and compared to
        the ones of library ``'skl'`` with 4 decimals. The comparison
        is skipped when the baseline has more than 10000 rows.
        Up to two different values are tolerated when both predictions
        are int64 arrays.

        Raises RuntimeError if *results* is empty or contains no
        ``'skl'`` result, AssertionError if predictions differ.
        """
        res = {}
        baseline = None
        for idt, fct, vals in results:
            key = idt, fct.get('method', '')
            if isinstance(vals, list):
                vals = pandas.DataFrame(vals).values
            lib = fct['lib']
            res.setdefault(key, {})[lib] = vals
            if lib == 'skl':
                baseline = lib

        if not res:
            raise RuntimeError(  # pragma: no cover
                "No results to compare.")
        if baseline is None:
            # Lists every library found (first-seen order) to help the user.
            libs = list(dict.fromkeys(
                lib for exp in res.values() for lib in exp))
            raise RuntimeError(  # pragma: no cover
                "Unable to guess the baseline in {}.".format(libs))

        for key, exp in res.items():
            vbase = exp[baseline]
            if vbase.shape[0] > 10000:
                # Too many rows, the comparison is skipped.
                continue
            for name, vals in exp.items():
                if name == baseline:
                    continue
                p1, p2 = vbase, vals
                if len(p1.shape) == 1 and len(p2.shape) == 2:
                    p2 = p2.ravel()
                try:
                    assert_almost_equal(p1, p2, decimal=4)
                except AssertionError as e:
                    dtype1 = getattr(p1, 'dtype', None)
                    dtype2 = getattr(p2, 'dtype', None)
                    if dtype1 == numpy.int64 and dtype2 == numpy.int64:
                        delta = numpy.sum(numpy.abs(p1 - p2) != 0)
                        if delta <= 2:
                            # scikit-learn computes with doubles and the
                            # ONNX model may use floats, a couple of
                            # predicted labels are likely to differ.
                            continue
                    msg = ("ERROR: Dim {}-{} ({}-{}) - discrepencies between "
                           "'{}' and '{}' for '{}'.").format(
                        vbase.shape, vals.shape, dtype1, dtype2,
                        baseline, name, key)
                    self.dump_error(msg, skl=self.skl, ort=self.ort,
                                    baseline=vbase, discrepencies=vals,
                                    onnx_bytes=self.ort_onnx.SerializeToString(),
                                    results=results, **kwargs)
                    raise AssertionError(msg) from e