Coverage for src/pymlbenchmark/external/onnxruntime_perf.py: 99%
99 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
1"""
2@file
3@brief Implements a benchmark about performance for :epkg:`onnxruntime`
4"""
5import contextlib
6from collections import OrderedDict
7from io import BytesIO, StringIO
8import numpy
9from numpy.testing import assert_almost_equal
10import pandas
11from sklearn.ensemble._forest import BaseForest
12from sklearn.tree._classes import BaseDecisionTree
13from mlprodict.onnxrt import OnnxInference
14from mlprodict import __max_supported_opset__, get_ir_version
15from ..benchmark import BenchPerfTest
16from ..benchmark.sklearn_helper import get_nb_skl_base_estimators
class OnnxRuntimeBenchPerfTest(BenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.

    Subclasses must overload method *_get_random_dataset*.
    """

    def __init__(self, estimator, dim=None, N_fit=100000,
                 runtimes=('python_compiled', 'onnxruntime1'),
                 onnx_options=None, dtype=numpy.float32,
                 **opts):
        """
        Trains *estimator* on a random dataset, converts it into
        :epkg:`ONNX` and creates one :epkg:`OnnxInference` per runtime.

        @param      estimator       estimator class
        @param      dim             number of features (mandatory)
        @param      N_fit           number of observations to fit an estimator
        @param      runtimes        runtimes to test for class :epkg:`OnnxInference`
        @param      opts            training settings, all of them are sent to
                                    the base class constructor, only
                                    ``max_depth`` is given to the estimator
        @param      onnx_options    ONNX conversion options
        @param      dtype           dtype (float32 or float64)

        Raises *RuntimeError* if *dim* is None or if the training fails,
        *ValueError* if *dtype* is neither float32 nor float64.
        """
        # These libraries are optional.
        from skl2onnx import to_onnx  # pylint: disable=E0401,C0415
        from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType  # pylint: disable=E0401,C0415

        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")
        BenchPerfTest.__init__(self, **opts)

        # Only the whitelisted settings are forwarded to the estimator.
        allowed = {"max_depth"}
        opts = {k: v for k, v in opts.items() if k in allowed}
        self.dtype = dtype
        self.skl = estimator(**opts)
        X, y = self._get_random_dataset(N_fit, dim)
        try:
            self.skl.fit(X, y)
        except Exception as e:  # pragma: no cover
            raise RuntimeError("X.shape={}\nopts={}\nTraining failed for {}".format(
                X.shape, opts, self.skl)) from e

        # The ONNX input type follows the requested dtype,
        # the first dimension (batch) is left undefined.
        if dtype == numpy.float64:
            initial_types = [('X', DoubleTensorType([None, X.shape[1]]))]
        elif dtype == numpy.float32:
            initial_types = [('X', FloatTensorType([None, X.shape[1]]))]
        else:
            raise ValueError(  # pragma: no cover
                "Unable to convert the model into ONNX, unsupported dtype {}.".format(dtype))

        # Everything the converter prints is captured in self.logconvert.
        self.logconvert = StringIO()
        with contextlib.redirect_stdout(self.logconvert):
            with contextlib.redirect_stderr(self.logconvert):
                onx = to_onnx(self.skl, initial_types=initial_types,
                              options=onnx_options,
                              target_opset=__max_supported_opset__)
                onx.ir_version = get_ir_version(__max_supported_opset__)

        self._init(onx, runtimes)

    def _get_random_dataset(self, N, dim):
        """
        Returns a random dataset, it must be overloaded.
        The constructor unpacks the result as ``X, y``.

        @param      N       number of observations
        @param      dim     number of features
        """
        raise NotImplementedError(  # pragma: no cover
            "This method must be overloaded.")

    def _init(self, onx, runtimes):
        """
        Finalizes the init: keeps the ONNX model, creates one
        :epkg:`OnnxInference` per runtime and collects information
        about both models.

        @param      onx         ONNX model
        @param      runtimes    runtimes to instantiate
        """
        self.ort_onnx = onx
        # The serialized model is shared by every runtime.
        content = onx.SerializeToString()
        self.ort = OrderedDict()
        self.outputs = OrderedDict()
        for r in runtimes:
            self.ort[r] = OnnxInference(content, runtime=r)
            self.outputs[r] = self.ort[r].output_names
        self.extract_model_info_skl()
        self.extract_model_info_onnx(ort_size=len(content))

    def extract_model_info_skl(self, **kwargs):
        """
        Populates member ``self.skl_info`` with additional
        information on the model such as the number of nodes for
        a decision tree.

        @param      kwargs      additional values added to ``self.skl_info``
        """
        self.skl_info = dict(
            skl_nb_base_estimators=get_nb_skl_base_estimators(self.skl, fitted=True))
        self.skl_info.update(kwargs)
        if isinstance(self.skl, BaseDecisionTree):
            self.skl_info["skl_dt_nodes"] = self.skl.tree_.node_count
        elif isinstance(self.skl, BaseForest):
            # Total number of nodes over all the trees of the forest.
            self.skl_info["skl_rf_nodes"] = sum(
                est.tree_.node_count for est in self.skl.estimators_)

    def extract_model_info_onnx(self, **kwargs):
        """
        Populates member ``self.onnx_info`` with additional
        information on the :epkg:`ONNX` graph.

        @param      kwargs      additional values added to ``self.onnx_info``
        """
        self.onnx_info = {
            'onnx_nodes': len(self.ort_onnx.graph.node),  # pylint: disable=E1101
            'onnx_opset': __max_supported_opset__,
        }
        self.onnx_info.update(kwargs)

    def data(self, N=None, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Generates random features.

        @param      N       number of observations
        @param      dim     number of features
        @param      kwargs  ignored
        @return             the first element of the random dataset
                            wrapped in a one-element sequence

        Raises *RuntimeError* if *N* or *dim* is None.
        """
        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")
        if N is None:
            raise RuntimeError(  # pragma: no cover
                "N must be defined.")
        # Only the first element is kept (the constructor names it X).
        return self._get_random_dataset(N, dim)[:1]

    def model_info(self, model):
        """
        Returns additional informations about a model.

        @param      model       model to describe
        @return                 dictionary with additional descriptor,
                                it only contains the class name of the model
        """
        res = dict(type_name=model.__class__.__name__)
        return res

    def validate(self, results, **kwargs):
        """
        Checks that methods *predict* and *predict_proba* return
        the same results for both :epkg:`scikit-learn` and
        :epkg:`onnxruntime`.

        @param      results     iterable of tuples ``(idt, fct, vals)``,
                                *fct* is a dictionary with key ``'lib'``
                                and optionally ``'method'``, *vals* is
                                a list or an array of predictions
        @param      kwargs      additional values sent to *dump_error*
                                when a discrepancy is found

        Raises *RuntimeError* if *results* is empty or if no result
        comes from library ``'skl'``, *AssertionError* if the
        predictions differ from the baseline.
        """
        # Groups the predictions by (idt, method), then by library.
        res = {}
        baseline = None
        for idt, fct, vals in results:
            key = idt, fct.get('method', '')
            if key not in res:
                res[key] = {}
            if isinstance(vals, list):
                vals = pandas.DataFrame(vals).values
            lib = fct['lib']
            res[key][lib] = vals
            if lib == 'skl':
                baseline = lib

        if not res:
            raise RuntimeError(  # pragma: no cover
                "No results to compare.")
        if baseline is None:
            # Lists the libraries which were found, in order of appearance.
            libs = list(dict.fromkeys(
                lib for exp in res.values() for lib in exp))
            raise RuntimeError(  # pragma: no cover
                "Unable to guess the baseline in {}.".format(libs))

        for key, exp in res.items():
            vbase = exp[baseline]
            # Baselines with more than 10000 rows are not compared.
            if vbase.shape[0] <= 10000:
                for name, vals in exp.items():
                    if name == baseline:
                        continue
                    p1, p2 = vbase, vals
                    # A 2D array is flattened to be compared to a 1D baseline.
                    if len(p1.shape) == 1 and len(p2.shape) == 2:
                        p2 = p2.ravel()
                    try:
                        assert_almost_equal(p1, p2, decimal=4)
                    except AssertionError as e:
                        if p1.dtype == numpy.int64 and p2.dtype == numpy.int64:
                            delta = numpy.sum(numpy.abs(p1 - p2) != 0)
                            if delta <= 2:
                                # scikit-learn computes with doubles, not floats,
                                # a couple of different integer predictions
                                # is likely to happen and is tolerated.
                                continue
                        msg = "ERROR: Dim {}-{} ({}-{}) - discrepencies between '{}' and '{}' for '{}'.".format(
                            vbase.shape, vals.shape, getattr(
                                p1, 'dtype', None),
                            getattr(p2, 'dtype', None), baseline, name, key)
                        self.dump_error(msg, skl=self.skl, ort=self.ort,
                                        baseline=vbase, discrepencies=vals,
                                        onnx_bytes=self.ort_onnx.SerializeToString(),
                                        results=results, **kwargs)
                        raise AssertionError(msg) from e