Profile onnxruntime execution#

The following example converts a model into ONNX and runs it with onnxruntime. The session is then used to profile the execution by looking at the time spent in each operator. This analysis gives some hints on how to optimize the processing time by looking at the nodes consuming most of the resources.

Nearest Neighbours#

import json
import numpy
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_area_auto_adjustable
import pandas
from onnxruntime import InferenceSession, SessionOptions, get_device
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    SessionIOBinding, OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
from sklearn.neighbors import RadiusNeighborsRegressor
from skl2onnx import to_onnx
from tqdm import tqdm
from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation
from mlprodict.plotting.plotting import onnx_simple_text_plot, plot_onnx
from mlprodict.onnxrt.ops_whole.session import OnnxWholeSession

Available optimisations on this machine.

# Optimisation flags reported by mlprodict's experimental C implementation.
optims = code_optimisation()
print(optims)
AVX-omp=8

Building the model#

# Random dataset: the target is the sum of the features, as a column vector.
n_samples, n_features = 1000, 10
X = numpy.random.randn(n_samples, n_features).astype(numpy.float64)
y = numpy.sum(X, axis=1).reshape((-1, 1))

model = RadiusNeighborsRegressor()
model.fit(X, y)
RadiusNeighborsRegressor()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Conversion to ONNX#

# 'optim': 'cdist' makes the converter emit the CDist operator
# (domain com.microsoft) to compute the pairwise distances.
conv_options = {'optim': 'cdist'}
onx = to_onnx(model, X, options=conv_options, target_opset=17)

print(onnx_simple_text_plot(onx))
opset: domain='' version=17
opset: domain='ai.onnx.ml' version=1
opset: domain='com.microsoft' version=1
input: name='X' type=dtype('float64') shape=[None, 10]
init: name='knny_ArrayFeatureExtractorcst' type=dtype('float64') shape=(1000,)
init: name='cond_CDistcst' type=dtype('float64') shape=(1000, 10)
init: name='cond_Lesscst' type=dtype('float64') shape=(1,) -- array([1.])
init: name='arange_CumSumcst' type=dtype('int64') shape=(1,) -- array([1])
init: name='knny_Reshapecst' type=dtype('int64') shape=(2,) -- array([  -1, 1000])
init: name='Re_Reshapecst' type=dtype('int64') shape=(2,) -- array([-1,  1])
CDist[com.microsoft](X, cond_CDistcst, metric=b'euclidean') -> cond_dist
  Less(cond_dist, cond_Lesscst) -> cond_C0
    Cast(cond_C0, to=11) -> nnbin_output0
      ReduceSum(nnbin_output0, arange_CumSumcst, keepdims=0) -> norm_reduced0
  Shape(cond_dist) -> arange_shape0
    ConstantOfShape(arange_shape0, value=[-1.0]) -> arange_output01
      Cast(arange_output01, to=7) -> arange_output0
        CumSum(arange_output0, arange_CumSumcst) -> arange_y0
          Neg(arange_y0) -> arange_Y0
        Add(arange_Y0, arange_output0) -> arange_C0
    Where(cond_C0, arange_C0, arange_output0) -> nnind_output0
      Flatten(nnind_output0) -> knny_output0
        ArrayFeatureExtractor(knny_ArrayFeatureExtractorcst, knny_output0) -> knny_Z0
          Reshape(knny_Z0, knny_Reshapecst, allowzero=0) -> knny_reshaped0
            Cast(knny_reshaped0, to=11) -> final_output0
      Mul(final_output0, nnbin_output0) -> final_C0
        ReduceSum(final_C0, arange_CumSumcst, keepdims=0) -> final_reduced0
          Shape(final_reduced0) -> normr_shape0
        Reshape(norm_reduced0, normr_shape0, allowzero=0) -> normr_reshaped0
          Div(final_reduced0, normr_reshaped0) -> Di_C0
            Reshape(Di_C0, Re_Reshapecst, allowzero=0) -> variable
output: name='variable' type=dtype('float64') shape=[None, 1]

The ONNX graph looks like the following.

# A tall single-axes figure to draw the whole ONNX graph.
_, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 15))
plot_onnx(onx, ax=ax)
plot profile ort
<AxesSubplot: >

Profiling#

Profiling is enabled by setting the attribute enable_profiling in SessionOptions. Method end_profiling collects all the results and stores them on disk in JSON format.

# Turn on the profiler for this session; the runs below are recorded
# until end_profiling is called.
so = SessionOptions()
so.enable_profiling = True
sess = InferenceSession(
    onx.SerializeToString(), so, providers=['CPUExecutionProvider'])
# only the first 100 rows are sent to the model
feeds = {'X': X[:100]}

for _ in tqdm(range(10)):
    sess.run(None, feeds)

# end_profiling returns the name of the JSON file it wrote
prof = sess.end_profiling()
print(prof)
  0%|          | 0/10 [00:00<?, ?it/s]
 90%|######### | 9/10 [00:00<00:00, 86.09it/s]
100%|##########| 10/10 [00:00<00:00, 85.23it/s]
onnxruntime_profile__2023-01-17_01-56-07.json

Better rendering#

# Read the JSON trace and turn it into a table; nested event arguments end
# up in the args_* columns.
with open(prof, "r") as fh:
    js = json.load(fh)
rows = OnnxWholeSession.process_profiling(js)
df = pandas.DataFrame(rows)
df
cat pid tid dur ts ph name args_op_name args_thread_scheduling_stats args_input_type_shape args_activation_size args_parameter_size args_graph_index args_output_size args_provider args_output_type_shape args_exec_plan_index
0 Session 32082 32082 5242 6 X model_loading_array NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 Session 32082 32082 7317 5340 X session_initialization NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 Node 32082 32082 2 14974 X cond_CDist_fence_before CDist NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 Node 32082 32082 2881 14988 X cond_CDist_kernel_time CDist {'main_thread': {'thread_pool_name': 'session-... [{'double': [100, 10]}, {'double': [1000, 10]}] 8000 80000 0 800000 CPUExecutionProvider [{'double': [100, 1000]}] 0
4 Node 32082 32082 0 17895 X cond_CDist_fence_after CDist NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
617 Node 32082 32082 0 131850 X Re_Reshape_fence_before Reshape NaN NaN NaN NaN NaN NaN NaN NaN NaN
618 Node 32082 32082 40 131853 X Re_Reshape_kernel_time Reshape {'main_thread': {'thread_pool_name': 'session-... [{'double': [100]}, {'int64': [2]}] 800 16 20 800 CPUExecutionProvider [{'double': [100, 1]}] 20
619 Node 32082 32082 0 131903 X Re_Reshape_fence_after Reshape NaN NaN NaN NaN NaN NaN NaN NaN NaN
620 Session 32082 32082 11483 120426 X SequentialExecutor::Execute NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
621 Session 32082 32082 11509 120411 X model_run NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

622 rows × 17 columns



Graphs#

First graph is by operator type.

# Every node execution produces three events in the trace:
# "<node>_fence_before", "<node>_kernel_time" and "<node>_fence_after".
# Only the kernel events are kept, otherwise each execution would be counted
# three times and the fence durations would be added to the operator time.
kernels = df[df['name'].str.endswith('_kernel_time')]
gr_dur = kernels[['dur', "args_op_name"]].groupby(
    "args_op_name").sum().sort_values('dur')
# same operator order as gr_dur so that both bar charts line up
gr_n = kernels[['dur', "args_op_name"]].groupby(
    "args_op_name").count().loc[gr_dur.index, :]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurences")
fig.suptitle(model.__class__.__name__)
RadiusNeighborsRegressor, duration, n occurences
Text(0.5, 0.98, 'RadiusNeighborsRegressor')

Second graph is by operator name.

# Per-node durations. Only the "<node>_kernel_time" events are kept: the
# fence events surrounding each kernel last (almost) no time and would only
# fill the table and the following graph with empty entries.
kernels = df[df['name'].str.endswith('_kernel_time')]
gr_dur = kernels[['dur', "args_op_name", "name"]].groupby(
    ["args_op_name", "name"]).sum().sort_values('dur')
gr_dur.head(n=5)
dur
args_op_name name
Flatten knny_Flatten_fence_after 0
Less cond_Less_fence_after 0
cond_Less_fence_before 0
Mul final_Mul_fence_after 0
final_Mul_fence_before 0


And the graph.

# One horizontal bar per row of gr_dur; the figure grows with the row count.
height = gr_dur.shape[0] // 2
_, ax = plt.subplots(1, 1, figsize=(8, height))
gr_dur.plot.barh(ax=ax)
ax.set_title("duration per node")
tick_labels = [*ax.get_xticklabels(), *ax.get_yticklabels()]
for lbl in tick_labels:
    lbl.set_fontsize(7)
make_axes_area_auto_adjustable(ax)
duration per node

The model spends most of its time in the CumSum operator. The Shape operator gets called the highest number of times.

# plt.show()

GPU or CPU#

# Pick the device holding the data and the matching execution providers.
# onnxruntime reads `providers` in decreasing order of precedence: each node
# goes to the first provider able to run it. CUDA must therefore come first,
# with CPU as the fallback. Requesting CUDA on a CPU-only build only produces
# a "not in available provider names" warning, so it is not requested there.
if get_device().upper() == 'GPU':
    ort_device = C_OrtDevice(
        C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
    ort_device = C_OrtDevice(
        C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    providers = ['CPUExecutionProvider']

# session (reuses `so`, so profiling is enabled for this session too)
sess = InferenceSession(onx.SerializeToString(), so, providers=providers)
bind = SessionIOBinding(sess._sess)

# moving the data on CPU or GPU; note this binds the whole of X (1000 rows),
# not the first 100 rows used in the previous profiling
ort_value = C_OrtValue.ortvalue_from_numpy(X, ort_device)
somewhere/workspace/onnxcustom/onnxcustom_UT_39_std/_venv/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:54: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'CPUExecutionProvider'
  warnings.warn(

A function that runs the model through the IO binding API, whatever the device (CPU or GPU).

def run_with_iobinding(sess, bind, ort_device, ort_value, dtype,
                       input_name='X', output_name='variable'):
    """Run *sess* once through an IO binding and return its first output.

    The input is bound to the buffer address of *ort_value*
    (``ort_value.data_ptr()``), which lives on *ort_device*. The output is
    bound to the same device.

    :param sess: InferenceSession to run (its ``_sess`` is used)
    :param bind: SessionIOBinding created from ``sess._sess``
    :param ort_device: C_OrtDevice holding the input and receiving the output
    :param ort_value: C_OrtValue holding the input data
    :param dtype: element type of the input, e.g. ``X.dtype``
    :param input_name: name of the model input to bind (default ``'X'``)
    :param output_name: name of the model output to bind
        (default ``'variable'``)
    :return: the first bound output converted with ``.numpy()``
    """
    bind.bind_input(input_name, ort_device, dtype, ort_value.shape(),
                    ort_value.data_ptr())
    bind.bind_output(output_name, ort_device)
    sess._sess.run_with_iobinding(bind, None)
    ortvalues = bind.get_outputs()
    return ortvalues[0].numpy()

The profiling.

# Ten runs through the IO binding; the session profiler records all of them.
for _ in tqdm(range(10)):
    run_with_iobinding(sess, bind, ort_device, ort_value, X.dtype)

# stop profiling, then load the JSON trace into a table
prof = sess.end_profiling()
with open(prof, "r") as fh:
    js = json.load(fh)
rows = OnnxWholeSession.process_profiling(js)
df = pandas.DataFrame(rows)
df
  0%|          | 0/10 [00:00<?, ?it/s]
 10%|#         | 1/10 [00:00<00:00,  9.62it/s]
 30%|###       | 3/10 [00:00<00:00, 10.19it/s]
 50%|#####     | 5/10 [00:00<00:00, 10.26it/s]
 70%|#######   | 7/10 [00:00<00:00, 10.31it/s]
 90%|######### | 9/10 [00:00<00:00, 10.34it/s]
100%|##########| 10/10 [00:00<00:00, 10.29it/s]
cat pid tid dur ts ph name args_op_name args_thread_scheduling_stats args_input_type_shape args_activation_size args_parameter_size args_graph_index args_output_size args_provider args_output_type_shape args_exec_plan_index
0 Session 32082 32082 812 6 X model_loading_array NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 Session 32082 32082 7226 897 X session_initialization NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 Node 32082 32082 1 14358 X cond_CDist_fence_before CDist NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 Node 32082 32082 19961 14369 X cond_CDist_kernel_time CDist {'main_thread': {'thread_pool_name': 'session-... [{'double': [1000, 10]}, {'double': [1000, 10]}] 80000 80000 0 8000000 CPUExecutionProvider [{'double': [1000, 1000]}] 0
4 Node 32082 32082 0 34356 X cond_CDist_fence_after CDist NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
617 Node 32082 32082 0 985706 X Re_Reshape_fence_before Reshape NaN NaN NaN NaN NaN NaN NaN NaN NaN
618 Node 32082 32082 41 985709 X Re_Reshape_kernel_time Reshape {'main_thread': {'thread_pool_name': 'session-... [{'double': [1000]}, {'int64': [2]}] 8000 16 20 8000 CPUExecutionProvider [{'double': [1000, 1]}] 20
619 Node 32082 32082 0 985761 X Re_Reshape_fence_after Reshape NaN NaN NaN NaN NaN NaN NaN NaN NaN
620 Session 32082 32082 95520 890247 X SequentialExecutor::Execute NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
621 Session 32082 32082 95546 890231 X model_run NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

622 rows × 17 columns



First graph is by operator type.

# Every node execution produces three events in the trace:
# "<node>_fence_before", "<node>_kernel_time" and "<node>_fence_after".
# Only the kernel events are kept, otherwise each execution would be counted
# three times and the fence durations would be added to the operator time.
kernels = df[df['name'].str.endswith('_kernel_time')]
gr_dur = kernels[['dur', "args_op_name"]].groupby(
    "args_op_name").sum().sort_values('dur')
# same operator order as gr_dur so that both bar charts line up
gr_n = kernels[['dur', "args_op_name"]].groupby(
    "args_op_name").count().loc[gr_dur.index, :]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurences")
fig.suptitle(model.__class__.__name__)
RadiusNeighborsRegressor, duration, n occurences
Text(0.5, 0.98, 'RadiusNeighborsRegressor')

Second graph is by operator name.

# Per-node durations. Only the "<node>_kernel_time" events are kept: the
# fence events surrounding each kernel last (almost) no time and would only
# fill the table and the following graph with empty entries.
kernels = df[df['name'].str.endswith('_kernel_time')]
gr_dur = kernels[['dur', "args_op_name", "name"]].groupby(
    ["args_op_name", "name"]).sum().sort_values('dur')
gr_dur.head(n=5)
dur
args_op_name name
Flatten knny_Flatten_fence_after 0
Reshape knny_Reshape_fence_after 0
Re_Reshape_fence_before 0
Re_Reshape_fence_after 0
ReduceSum norm_ReduceSum_fence_before 0


And the graph.

# One horizontal bar per row of gr_dur; the figure grows with the row count.
height = gr_dur.shape[0] // 2
_, ax = plt.subplots(1, 1, figsize=(8, height))
gr_dur.plot.barh(ax=ax)
ax.set_title("duration per node")
tick_labels = [*ax.get_xticklabels(), *ax.get_yticklabels()]
for lbl in tick_labels:
    lbl.set_fontsize(7)
make_axes_area_auto_adjustable(ax)
duration per node

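It shows the same results. Durations are longer because this session processes the whole of X (1000 rows) instead of its first 100 rows.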

# plt.show()

Total running time of the script: ( 0 minutes 11.444 seconds)

Gallery generated by Sphinx-Gallery