Benchmark onnxruntime API: run or run_with_ort_values#

This short script compares different ways of calling the onnxruntime API.

  • run

  • run_with_ort_values

  • run_with_iobinding

You may profile this code:

py-spy record -o plot_benchmark_ort_api.svg -r 10
--native -- python plot_benchmark_ort_api.py

Linear Regression#

import numpy
import pandas
from onnxruntime import InferenceSession
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    SessionIOBinding, OrtDevice as C_OrtDevice,
    OrtMemType, OrtValue as C_OrtValue, RunOptions)
from sklearn import config_context
from sklearn.linear_model import LinearRegression
from skl2onnx import to_onnx
from cpyquickhelper.numbers.speed_measure import measure_time
from mlprodict.onnxrt import OnnxInference
from mlprodict.plotting.plotting import onnx_simple_text_plot
from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation

Optimisations available on this machine.

# Print the value returned by mlprodict's code_optimisation().
print(code_optimisation())
# Shared benchmark settings: forwarded as repeat= and number= to every
# measure_time call in this script.
repeat = 250
number = 250
AVX-omp=8

Building the model#

# 1000 rows of 10 float32 features drawn from a standard normal distribution.
X = numpy.random.randn(1000, 10).astype(numpy.float32)
# The regression target is the sum of the features of each row.
y = X.sum(axis=1)

# Fit a linear regression on (X, y); kept as a bare expression statement.
model = LinearRegression()
model.fit(X, y)
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Conversion to ONNX#

# Convert the fitted model to ONNX, passing X as the sample input.
# NOTE(review): black_op={'LinearRegressor'} presumably forbids the converter
# from emitting that operator, which would explain why the printed graph is
# made of MatMul/Add/Reshape -- confirm against the skl2onnx documentation.
onx = to_onnx(model, X, black_op={'LinearRegressor'})
# Print a text rendering of the converted graph.
print(onnx_simple_text_plot(onx))
opset: domain='' version=13
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='coef' type=dtype('float32') shape=(10, 1)
init: name='intercept' type=dtype('float32') shape=(1,) -- array([-5.9604645e-08], dtype=float32)
init: name='shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
MatMul(X, coef) -> multiplied
  Add(multiplied, intercept) -> resh
    Reshape(resh, shape_tensor) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]

Benchmarks#

# Accumulates one measure_time result per benchmarked method; each entry is
# tagged with a 'name' key before being appended.
data = []

scikit-learn

print('scikit-learn')

# The whole measurement runs with assume_finite=True.
# NOTE(review): presumably this disables scikit-learn's finiteness checks so
# that only the prediction itself is timed -- confirm.
with config_context(assume_finite=True):
    # Time model.predict on the full matrix X.
    obs = measure_time(lambda: model.predict(X),
                       context=dict(model=model, X=X),
                       repeat=repeat, number=number)
    # Label under which this measurement appears in the final table.
    obs['name'] = 'skl'
    data.append(obs)
scikit-learn

numpy runtime

print('numpy')
# mlprodict's OnnxInference using its "python_compiled" runtime; this is the
# entry labelled 'numpy' in the results.
oinf = OnnxInference(onx, runtime="python_compiled")
# Time oinf.run with X fed under the input name 'X'.
obs = measure_time(
    lambda: oinf.run({'X': X}), context=dict(oinf=oinf, X=X),
    repeat=repeat, number=number)
obs['name'] = 'numpy'
data.append(obs)
numpy

onnxruntime: run

print('ort')
# New session on the serialized model, restricted to the CPU execution provider.
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
# Time the public Python entry point, InferenceSession.run, fed with the
# numpy array X under the input name 'X'.
obs = measure_time(lambda: sess.run(None, {'X': X}),
                   context=dict(sess=sess, X=X),
                   repeat=repeat, number=number)
obs['name'] = 'ort'
data.append(obs)
ort

onnxruntime: run from C API

print('ort-c')
# New session on the same model, CPU execution provider only.
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
# RunOptions and the output names are built once, outside the timed call.
ro = RunOptions()
output_names = [o.name for o in sess.get_outputs()]
# Time run() called on sess._sess (a private attribute of InferenceSession)
# instead of on the InferenceSession itself.
obs = measure_time(
    lambda: sess._sess.run(output_names, {'X': X}, ro),
    context=dict(sess=sess, X=X),
    repeat=repeat, number=number)
obs['name'] = 'ort-c'
data.append(obs)
ort-c

onnxruntime: run_with_ort_values from C API

print('ort-ov-c')
# CPU device descriptor used to create the OrtValue below.
device = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0)

# X is wrapped into a C_OrtValue once, outside the timed lambda, so this
# conversion is not part of the measurement.
Xov = C_OrtValue.ortvalue_from_numpy(X, device)

sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
ro = RunOptions()
output_names = [o.name for o in sess.get_outputs()]
# Time run_with_ort_values on sess._sess, fed with the prebuilt OrtValue.
obs = measure_time(
    lambda: sess._sess.run_with_ort_values(
        {'X': Xov}, output_names, ro),
    context=dict(sess=sess),
    repeat=repeat, number=number)
obs['name'] = 'ort-ov-c'
data.append(obs)
ort-ov-c

onnxruntime: run_with_iobinding from C API

print('ort-bind')
# New session on the same model, CPU execution provider only.
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
# Binding object built on sess._sess; the timed function below re-binds
# input and output on every call.
bind = SessionIOBinding(sess._sess)
# CPU device descriptor handed to bind_input / bind_output.
ort_device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)


def run_with_iobinding(sess, X, bind, ort_device):
    """Bind input and output, run the session, return the first output.

    The binding is redone on every call: ``X`` is bound under the input name
    ``'X'`` through its dtype, shape and raw data address, and the output
    ``'variable'`` is bound on ``ort_device``.

    :param sess: object whose ``_sess`` attribute exposes
        ``run_with_iobinding``
    :param X: numpy array to feed; rejected when its
        ``__array_interface__['strides']`` entry is not ``None``
    :param bind: binding object exposing ``bind_input``, ``bind_output``
        and ``get_outputs``
    :param ort_device: device passed to both ``bind_input`` and
        ``bind_output``
    :return: result of ``numpy()`` called on the first bound output
    :raises RuntimeError: if *X* reports explicit strides
    """
    interface = X.__array_interface__
    if interface['strides'] is not None:
        raise RuntimeError("onnxruntime only supports contiguous arrays.")

    # The input is described by its raw buffer address rather than by object.
    address = interface['data'][0]
    bind.bind_input('X', ort_device, X.dtype, X.shape, address)
    bind.bind_output('variable', ort_device)

    sess._sess.run_with_iobinding(bind, None)
    return bind.get_outputs()[0].numpy()


# Time the full bind + run + fetch sequence implemented by run_with_iobinding.
obs = measure_time(lambda: run_with_iobinding(sess, X, bind, ort_device),
                   context=dict(run_with_iobinding=run_with_iobinding, X=X,
                                sess=sess, bind=bind, ort_device=ort_device),
                   repeat=repeat, number=number)

obs['name'] = 'ort-bind'
data.append(obs)
ort-bind

This fourth implementation is very similar to the previous one, but it binds the arrays only once and reuses the memory without changing the binding. It assumes that the input and output sizes never change. It copies the data into the fixed buffer and returns the same array, modified in place.

print('ort-bind-inplace')
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
bind = SessionIOBinding(sess._sess)
ort_device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)

# Run once through the regular API to obtain an output array; its copy
# provides the shape and the buffer bound as output below.
Y = sess.run(None, {'X': X})[0]
# Fixed buffers: bX receives the input data, bY is bound as the output.
bX = X.copy()
bY = Y.copy()

# Bind both buffers a single time, by raw data address, as float32.
# NOTE(review): only the addresses are handed over, so bX and bY presumably
# have to stay alive and unreallocated while the binding is used -- confirm.
bind.bind_input('X', ort_device, numpy.float32, bX.shape,
                bX.__array_interface__['data'][0])
bind.bind_output('variable', ort_device, numpy.float32, bY.shape,
                 bY.__array_interface__['data'][0])
# Fetched before any bound run; passed to the timed function, which does not
# read it.
ortvalues = bind.get_outputs()


def run_with_iobinding(sess, bX, bY, X, bind, ortvalues):
    """Copy ``X`` into the bound input buffer, run, return the output buffer.

    No binding is performed here: ``bind`` is passed to the session as it
    is, and the caller is expected to have bound ``bX`` and ``bY`` already.

    :param sess: object whose ``_sess`` attribute exposes
        ``run_with_iobinding``
    :param bX: 2D array overwritten with the content of *X*
    :param bY: array returned as the result (the same object, not a copy)
    :param X: 2D numpy array holding the new input; rejected when its
        ``__array_interface__['strides']`` entry is not ``None``
    :param bind: binding object forwarded to the session
    :param ortvalues: accepted but not used by this function
    :return: *bY*
    :raises RuntimeError: if *X* reports explicit strides
    """
    strides = X.__array_interface__['strides']
    if strides is not None:
        raise RuntimeError("onnxruntime only supports contiguous arrays.")

    # Refresh the fixed input buffer in place.
    bX[:, :] = X[:, :]
    sess._sess.run_with_iobinding(bind, None)
    return bY


# Time the copy-into-buffer + run sequence that reuses the one-time binding.
obs = measure_time(
    lambda: run_with_iobinding(
        sess, bX, bY, X, bind, ortvalues),
    context=dict(run_with_iobinding=run_with_iobinding, X=X,
                 sess=sess, bind=bind, ortvalues=ortvalues, bX=bX, bY=bY),
    repeat=repeat, number=number)

obs['name'] = 'ort-bind-inplace'
data.append(obs)
ort-bind-inplace

The fifth implementation is equivalent to the previous one, but it does not copy the input into the bound buffer before running.

print('ort-run-inplace')
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
bind = SessionIOBinding(sess._sess)
ort_device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)

# Run once through the regular API to obtain an output array; its copy
# provides the shape and the buffer bound as output below.
Y = sess.run(None, {'X': X})[0]
# bX is filled with X once, here; the timed function never refreshes it.
bX = X.copy()
bY = Y.copy()

# Bind both buffers a single time, by raw data address, as float32.
# NOTE(review): only the addresses are handed over, so bX and bY presumably
# have to stay alive and unreallocated while the binding is used -- confirm.
bind.bind_input('X', ort_device, numpy.float32, bX.shape,
                bX.__array_interface__['data'][0])
bind.bind_output('variable', ort_device, numpy.float32, bY.shape,
                 bY.__array_interface__['data'][0])
# Fetched before any bound run; passed to the timed function, which does not
# read it.
ortvalues = bind.get_outputs()


def run_with_iobinding_no_copy(sess, bX, bY, X, bind, ortvalues):
    """Run the session on the existing binding and return the output buffer.

    Unlike the copying variant, the content of *X* is never written into
    *bX*; the function only validates *X* and triggers the run.

    :param sess: object whose ``_sess`` attribute exposes
        ``run_with_iobinding``
    :param bX: accepted but left untouched
    :param bY: array returned as the result (the same object, not a copy)
    :param X: numpy array, only checked: rejected when its
        ``__array_interface__['strides']`` entry is not ``None``
    :param bind: binding object forwarded to the session
    :param ortvalues: accepted but not used by this function
    :return: *bY*
    :raises RuntimeError: if *X* reports explicit strides
    """
    strides = X.__array_interface__['strides']
    if strides is not None:
        raise RuntimeError("onnxruntime only supports contiguous arrays.")

    # Deliberately no copy of X into bX here: this variant measures the run
    # alone.
    sess._sess.run_with_iobinding(bind, None)
    return bY


# Time the bound run alone: no re-binding and no copy of X into bX.
obs = measure_time(
    lambda: run_with_iobinding_no_copy(
        sess, bX, bY, X, bind, ortvalues),
    context=dict(run_with_iobinding_no_copy=run_with_iobinding_no_copy, X=X,
                 sess=sess, bind=bind, ortvalues=ortvalues, bX=bX, bY=bY),
    repeat=repeat, number=number)

obs['name'] = 'ort-run-inplace'
data.append(obs)
ort-run-inplace

Final#

# One row per benchmarked method, one column per key of the collected results.
df = pandas.DataFrame(data)
# Compact view restricted to the main columns.
print(df[['name', 'average', 'number', 'repeat', 'deviation']])
# Bare expression kept on purpose.
# NOTE(review): presumably it is what makes the gallery render the full
# table -- confirm.
df
               name   average  number  repeat  deviation
0               skl  0.040577     250     250   0.000395
1             numpy  0.030783     250     250   0.000398
2               ort  0.025867     250     250   0.000093
3             ort-c  0.023828     250     250   0.000028
4          ort-ov-c  0.022516     250     250   0.000025
5          ort-bind  0.034486     250     250   0.000078
6  ort-bind-inplace  0.026029     250     250   0.000061
7   ort-run-inplace  0.023210     250     250   0.000033
average deviation min_exec max_exec repeat number ttime context_size name
0 0.040577 0.000395 0.040214 0.043239 250 250 10.144357 232 skl
1 0.030783 0.000398 0.030489 0.034370 250 250 7.695718 232 numpy
2 0.025867 0.000093 0.025730 0.026189 250 250 6.466838 232 ort
3 0.023828 0.000028 0.023764 0.024018 250 250 5.956967 232 ort-c
4 0.022516 0.000025 0.022464 0.022726 250 250 5.629089 232 ort-ov-c
5 0.034486 0.000078 0.034319 0.034789 250 250 8.621502 232 ort-bind
6 0.026029 0.000061 0.025903 0.026547 250 250 6.507348 360 ort-bind-inplace
7 0.023210 0.000033 0.023147 0.023406 250 250 5.802480 360 ort-run-inplace


Graph#

# Bar chart of the 'average' column, one bar per method name.
ax = df.set_index('name')[['average']].plot.bar()
ax.set_title("Average inference time\nThe lower the better")
# Rotate the x tick labels by 15 degrees.
ax.tick_params(axis='x', labelrotation=15)
Average inference time The lower the better

Conclusion#

Profiling (with onnxruntime compiled with debug information), including calls to native C++ functions, shows that referencing inputs by name takes a significant amount of time when the graph is very small, as it is here. The logic in method run_with_iobinding is much longer than the one implemented in run.

# import matplotlib.pyplot as plt
# plt.show()

Total running time of the script: ( 0 minutes 57.416 seconds)

Gallery generated by Sphinx-Gallery