.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_benchmark_graph_opt.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gyexamples_plot_benchmark_graph_opt.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_benchmark_graph_opt.py:


.. _benchmark-ort-onnx-graph-opt:

Benchmark onnxruntime optimization
==================================

:epkg:`onnxruntime` optimizes the ONNX graph before running the inference.
For example, it tries to fuse a matrix multiplication preceded or followed
by a transpose, choosing the most efficient path.

.. contents::
    :local:

One ONNX file
+++++++++++++

This section creates an ONNX graph if there is not one.

.. GENERATED FROM PYTHON SOURCE LINES 20-36

.. code-block:: default


    import os
    from collections import OrderedDict, Counter
    import numpy
    import onnx
    from cpyquickhelper.numbers.speed_measure import measure_time
    import pandas
    from onnxruntime import InferenceSession, SessionOptions, get_device
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
        SessionIOBinding, OrtDevice as C_OrtDevice, OrtValue as C_OrtValue,
        GraphOptimizationLevel)
    from sklearn.neighbors import RadiusNeighborsRegressor
    from skl2onnx import to_onnx
    from tqdm import tqdm
    from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation


.. GENERATED FROM PYTHON SOURCE LINES 37-38

Available optimisations on this machine.

.. GENERATED FROM PYTHON SOURCE LINES 38-42

.. code-block:: default


    print(code_optimisation())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    AVX-omp=8


.. GENERATED FROM PYTHON SOURCE LINES 43-45

Building the model
++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 45-61

.. code-block:: default


    filename = "onnx_to_profile.onnx"

    if not os.path.exists(filename):
        print(f"Generate a graph for {filename!r}.")
        X = numpy.random.randn(1000, 10).astype(numpy.float64)
        y = X.sum(axis=1).reshape((-1, 1))

        model = RadiusNeighborsRegressor()
        model.fit(X, y)
        onx = to_onnx(model, X, options={'optim': 'cdist'},
                      target_opset=17)

        with open(filename, "wb") as f:
            f.write(onx.SerializeToString())


.. GENERATED FROM PYTHON SOURCE LINES 62-66

Functions
+++++++++

We need to generate random inputs to test the graph.

.. GENERATED FROM PYTHON SOURCE LINES 66-103

.. code-block:: default


    def random_input(typ, shape, batch):
        if typ == 'tensor(double)':
            dtype = numpy.float64
        elif typ == 'tensor(float)':
            dtype = numpy.float32
        else:
            raise NotImplementedError(
                f"Unable to guess dtype from {typ!r}.")

        if len(shape) <= 1:
            new_shape = shape
        elif shape[0] is None:
            new_shape = tuple([batch] + list(shape[1:]))
        else:
            new_shape = shape
        return numpy.random.randn(*new_shape).astype(dtype)


    def random_feed(sess, batch=10):
        """
        Creates a dictionary of random inputs.

        :param batch: dimension to use as batch dimension if unknown
        :return: dictionary
        """
        inputs = sess.get_inputs()
        res = OrderedDict()
        for inp in inputs:
            name = inp.name
            typ = inp.type
            shape = inp.shape
            res[name] = random_input(typ, shape, batch)
        return res

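As a quick check, these helpers can be exercised on the model built in the
previous section. The snippet below is only an illustrative sketch added here
(it assumes ``onnx_to_profile.onnx`` already exists next to the script); it
prints the shape and dtype of every randomly generated input.

.. code-block:: python

    # Illustrative check only, not part of the generated example.
    sess_check = InferenceSession(filename,
                                  providers=["CPUExecutionProvider"])
    feed_check = random_feed(sess_check, batch=10)
    for name, value in feed_check.items():
        # unknown batch dimensions are replaced by 10
        print(name, value.shape, value.dtype)
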
.. GENERATED FROM PYTHON SOURCE LINES 104-105

A function which calls the API through IO bindings so that the same code
works on any device.

.. GENERATED FROM PYTHON SOURCE LINES 105-118

.. code-block:: default


    def run_with_iobinding(sess, bind, ort_device, feed_ort_value, outputs):
        for name, (value, dtype) in feed_ort_value.items():
            bind.bind_input(name, ort_device, dtype, value.shape(),
                            value.data_ptr())
        for out in outputs:
            bind.bind_output(out, ort_device)
        sess._sess.run_with_iobinding(bind, None)
        ortvalues = bind.get_outputs()
        return [o.numpy() for o in ortvalues]


.. GENERATED FROM PYTHON SOURCE LINES 119-124

Benchmark
+++++++++

Let's choose the device available on this machine.
The batch dimension is set to 200.

.. GENERATED FROM PYTHON SOURCE LINES 124-137

.. code-block:: default


    batch = 200

    if get_device().upper() == 'GPU':
        ort_device = C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
        provider = 'CUDAExecutionProvider'
    else:
        ort_device = C_OrtDevice(
            C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
        provider = 'CPUExecutionProvider'
    print(f"provider = {provider!r}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    provider = 'CPUExecutionProvider'


.. GENERATED FROM PYTHON SOURCE LINES 138-139

We load the graph.

.. GENERATED FROM PYTHON SOURCE LINES 139-143

.. code-block:: default


    with open(filename, 'rb') as f:
        onx = onnx.load(f)


.. GENERATED FROM PYTHON SOURCE LINES 144-145

Creation of the session.

.. GENERATED FROM PYTHON SOURCE LINES 145-194

.. code-block:: default


    data = []
    files = []
    legend = []
    for graph_opt, name_opt in tqdm([
            (GraphOptimizationLevel.ORT_DISABLE_ALL, "ORT_DISABLE_ALL"),
            (GraphOptimizationLevel.ORT_ENABLE_BASIC, "ORT_ENABLE_BASIC"),
            (GraphOptimizationLevel.ORT_ENABLE_EXTENDED, "ORT_ENABLE_EXTENDED"),
            (GraphOptimizationLevel.ORT_ENABLE_ALL, "ORT_ENABLE_ALL")]):

        so = SessionOptions()
        so.graph_optimization_level = graph_opt
        so.optimized_model_filepath = (
            os.path.split(filename)[-1] + f".optimized.{name_opt}.onnx")
        files.append(so.optimized_model_filepath)
        legend.append(name_opt)
        sess = InferenceSession(onx.SerializeToString(), so,
                                providers=[provider])
        bind = SessionIOBinding(sess._sess)

        #####################################
        # Creates random data
        feed = random_feed(sess, batch)

        #####################################
        # moving the data on CPU or GPU
        feed_ort_value = OrderedDict(
            (name, (C_OrtValue.ortvalue_from_numpy(v, ort_device), v.dtype))
            for name, v in feed.items())
        outputs = [o.name for o in sess.get_outputs()]

        #######################################
        # The profiling.
        obs = measure_time(
            lambda: run_with_iobinding(
                sess, bind, ort_device, feed_ort_value, outputs),
            context=dict(run_with_iobinding=run_with_iobinding,
                         feed_ort_value=feed_ort_value, outputs=outputs,
                         sess=sess, bind=bind, ort_device=ort_device),
            repeat=10, number=10, div_by_number=True)
        obs['name'] = name_opt
        data.append(obs)


    df = pandas.DataFrame(data)
    df


.. rst-class:: sphx-glr-script-out

.. code-block:: none

         average  deviation  min_exec  max_exec  repeat  number     ttime  context_size                 name
    0   0.022404   0.000044  0.022363  0.022495      10      10  0.224037           360      ORT_DISABLE_ALL
    1   0.021755   0.000049  0.021699  0.021852      10      10  0.217548           360     ORT_ENABLE_BASIC
    2   0.021758   0.000038  0.021700  0.021829      10      10  0.217584           360  ORT_ENABLE_EXTENDED
    3   0.021735   0.000020  0.021700  0.021767      10      10  0.217352           360       ORT_ENABLE_ALL


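A comparable measurement can be written with the public
``InferenceSession.run`` API instead of the low-level IO binding used above.
The following snippet is only a sketch: it reuses ``onx``, ``provider``,
``batch``, ``random_feed`` and ``measure_time`` defined earlier in this
example and keeps only two optimization levels.

.. code-block:: python

    # Illustrative sketch: same comparison through the public run() API.
    data_run = []
    for graph_opt, name_opt in [
            (GraphOptimizationLevel.ORT_DISABLE_ALL, "ORT_DISABLE_ALL"),
            (GraphOptimizationLevel.ORT_ENABLE_ALL, "ORT_ENABLE_ALL")]:
        so = SessionOptions()
        so.graph_optimization_level = graph_opt
        sess = InferenceSession(onx.SerializeToString(), so,
                                providers=[provider])
        feed = random_feed(sess, batch)
        obs = measure_time(lambda: sess.run(None, feed),
                           context=dict(sess=sess, feed=feed),
                           repeat=10, number=10, div_by_number=True)
        obs['name'] = name_opt
        data_run.append(obs)
    print(pandas.DataFrame(data_run))
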
.. GENERATED FROM PYTHON SOURCE LINES 195-197

Graph
+++++

.. GENERATED FROM PYTHON SOURCE LINES 197-205

.. code-block:: default


    df = df.set_index('name')
    dev = df[['deviation']].copy()
    dev.columns = ['average']
    ax = df[['average']].plot.bar(yerr=dev)
    ax.set_title(os.path.split(filename)[-1])
    ax.tick_params(axis='x', labelrotation=15)



.. image-sg:: /gyexamples/images/sphx_glr_plot_benchmark_graph_opt_001.png
   :alt: onnx_to_profile.onnx
   :srcset: /gyexamples/images/sphx_glr_plot_benchmark_graph_opt_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 206-207

The results are similar because the optimized models remain very similar.

.. GENERATED FROM PYTHON SOURCE LINES 207-219

.. code-block:: default


    data = []
    for name in files:
        with open(name, "rb") as f:
            onx = onnx.load(f)
        op_names = [op.op_type for op in onx.graph.node]
        data.append(Counter(op_names))

    df = pandas.DataFrame(data).T
    df.columns = legend
    df


.. rst-class:: sphx-glr-script-out

.. code-block:: none

                           ORT_DISABLE_ALL  ORT_ENABLE_BASIC  ORT_ENABLE_EXTENDED  ORT_ENABLE_ALL
    CDist                                1                 1                    1               1
    Less                                 1                 1                    1               1
    Shape                                2                 2                    2               2
    ConstantOfShape                      1                 1                    1               1
    Cast                                 3                 2                    2               2
    ReduceSum                            2                 2                    2               2
    CumSum                               1                 1                    1               1
    Neg                                  1                 1                    1               1
    Add                                  1                 1                    1               1
    Where                                1                 1                    1               1
    Flatten                              1                 1                    1               1
    ArrayFeatureExtractor                1                 1                    1               1
    Reshape                              3                 3                    3               3
    Mul                                  1                 1                    1               1
    Div                                  1                 1                    1               1


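Only the basic optimization level changes this graph: it removes a single
``Cast`` node, and the extended and full levels do not go further. The sketch
below shows one way to compute that difference directly from the saved
optimized files; it reuses the ``files`` list and the ``Counter`` import from
above, and ``count_ops`` is an illustrative helper, not part of the original
example.

.. code-block:: python

    # Illustrative sketch: node types removed between ORT_DISABLE_ALL
    # (first saved file) and ORT_ENABLE_ALL (last saved file).
    def count_ops(path):
        # counts node types in a saved ONNX graph
        with open(path, "rb") as f:
            model_proto = onnx.load(f)
        return Counter(op.op_type for op in model_proto.graph.node)

    before = count_ops(files[0])
    after = count_ops(files[-1])
    # Counter subtraction keeps only positive differences,
    # i.e. the node types removed by the optimizer.
    print("nodes removed by optimization:", dict(before - after))
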
.. GENERATED FROM PYTHON SOURCE LINES 220-221

Graph of the operator counts per optimization level.

.. GENERATED FROM PYTHON SOURCE LINES 221-227

.. code-block:: default


    ax = df.plot.barh()
    ax.set_title(os.path.split(filename)[-1])

    # import matplotlib.pyplot as plt
    # plt.show()


.. image-sg:: /gyexamples/images/sphx_glr_plot_benchmark_graph_opt_002.png
   :alt: onnx_to_profile.onnx
   :srcset: /gyexamples/images/sphx_glr_plot_benchmark_graph_opt_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Text(0.5, 1.0, 'onnx_to_profile.onnx')


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 9.926 seconds)


.. _sphx_glr_download_gyexamples_plot_benchmark_graph_opt.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_benchmark_graph_opt.py <plot_benchmark_graph_opt.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_benchmark_graph_opt.ipynb <plot_benchmark_graph_opt.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_