.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_parallel_execution.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_parallel_execution.py:

.. _l-plot-parallel-execution:

===============================
Multithreading with onnxruntime
===============================

.. index:: thread, parallel, onnxruntime

Python supports multithreading, but the GIL (see :epkg:`Le GIL`) prevents
threads from running Python code truly in parallel. However, when most of
the parallelized code does not create Python objects, threads become more
attractive than spawning several processes that exchange data through
sockets. :epkg:`onnxruntime` falls into that category. For a big model
such as a deep learning model, this approach can be worthwhile.

:epkg:`onnxruntime` already parallelizes the computation of every operator
(Gemm, MatMul) using all the CPU cores it can get. For multithreading to
bring a significant gain, the work needs to be spread over different
processors (CPU, GPU) running in parallel. That's what this example shows.

.. contents::
    :local:

A model
=======

Let's retrieve a not so big model. The models are taken from the
`ONNX Model Zoo `_ or can even be custom.

.. GENERATED FROM PYTHON SOURCE LINES 30-89

.. code-block:: default

    import gc
    import multiprocessing
    import os
    import pickle
    from pprint import pprint
    import urllib.request
    import threading
    import time
    import sys
    import tqdm
    import numpy
    import pandas
    from onnxcustom.utils.benchmark import measure_time
    import torch.cuda
    from onnxruntime import InferenceSession, get_all_providers
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
        OrtValue as C_OrtValue)
    from onnxcustom.utils.onnxruntime_helper import get_ort_device_from_session


    def download_file(url, name, min_size):
        if not os.path.exists(name):
            print(f"download '{url}'")
            with urllib.request.urlopen(url) as u:
                content = u.read()
            if len(content) < min_size:
                raise RuntimeError(
                    f"Unable to download '{url}' due to\n{content}")
            print(f"downloaded {len(content)} bytes.")
            with open(name, "wb") as f:
                f.write(content)
        else:
            print(f"'{name}' already downloaded")


    small = "custom" if "custom" in sys.argv else "small"
    if small == "custom":
        model_name = "gpt2.onnx"
        url_name = None
        maxN, stepN, repN = 5, 1, 4
        big_model = True
    elif small == "small":
        model_name = "mobilenetv2-10.onnx"
        url_name = ("https://github.com/onnx/models/raw/main/vision/"
                    "classification/mobilenet/model")
        maxN, stepN, repN = 21, 2, 4
        big_model = False
    else:
        model_name = "resnet18-v1-7.onnx"
        url_name = ("https://github.com/onnx/models/raw/main/vision/"
                    "classification/resnet/model")
        maxN, stepN, repN = 21, 2, 4
        big_model = False

    if url_name is not None:
        url_name += "/" + model_name
        download_file(url_name, model_name, 100000)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    download 'https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-10.onnx'
    downloaded 13963115 bytes.

.. GENERATED FROM PYTHON SOURCE LINES 90-98

Measuring inference time when parallelizing on CPU
==================================================

Sequence
++++++++

Let's first dig into the model to retrieve the input and output names
as well as their shapes.

.. GENERATED FROM PYTHON SOURCE LINES 98-101

.. code-block:: default

    sess1 = InferenceSession(model_name, providers=["CPUExecutionProvider"])
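As a side note, the size of the thread pool :epkg:`onnxruntime` uses inside
a single session can be bounded through ``SessionOptions``. This is a
minimal sketch relying on onnxruntime's public API; it is not part of the
benchmark below:

.. code-block:: python

    from onnxruntime import InferenceSession, SessionOptions

    so = SessionOptions()
    # bound the number of threads used to run a single operator (intra-op)
    so.intra_op_num_threads = 2
    # a session created with these options will not grab every core
    sess_limited = InferenceSession(
        model_name, so, providers=["CPUExecutionProvider"])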
.. GENERATED FROM PYTHON SOURCE LINES 102-103

Inputs.

.. GENERATED FROM PYTHON SOURCE LINES 103-111

.. code-block:: default

    for i in sess1.get_inputs():
        print(f"input {i}, name={i.name!r}, type={i.type}, shape={i.shape}")
        input_name = i.name
        input_shape = list(i.shape)
        if input_shape[0] in [None, "batch_size", "N"]:
            input_shape[0] = 1

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    input NodeArg(name='input', type='tensor(float)', shape=['batch_size', 3, 224, 224]), name='input', type=tensor(float), shape=['batch_size', 3, 224, 224]

.. GENERATED FROM PYTHON SOURCE LINES 112-113

Outputs.

.. GENERATED FROM PYTHON SOURCE LINES 113-122

.. code-block:: default

    output_name = None
    for i in sess1.get_outputs():
        print(f"output {i}, name={i.name!r}, type={i.type}, shape={i.shape}")
        if output_name is None:
            output_name = i.name

    print(f"input_name={input_name!r}, output_name={output_name!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    output NodeArg(name='output', type='tensor(float)', shape=['batch_size', 1000]), name='output', type=tensor(float), shape=['batch_size', 1000]
    input_name='input', output_name='output'

.. GENERATED FROM PYTHON SOURCE LINES 123-124

Let's take some random inputs.

.. GENERATED FROM PYTHON SOURCE LINES 124-132

.. code-block:: default

    if model_name == "gpt2.onnx":
        with open("encoded_tensors-gpt2.pkl", "rb") as f:
            [encoded_tensors, labels] = pickle.load(f)
        rnd_img = encoded_tensors[0]["input_ids"].numpy()
    else:
        rnd_img = numpy.random.rand(*input_shape).astype(numpy.float32)

.. GENERATED FROM PYTHON SOURCE LINES 133-134

And measure the processing time.

.. GENERATED FROM PYTHON SOURCE LINES 134-142

.. code-block:: default

    results = sess1.run(None, {input_name: rnd_img})
    print(f"output: type={results[0].dtype}, shape={results[0].shape}")

    pprint(measure_time(lambda: sess1.run(None, {input_name: rnd_img}),
                        div_by_number=True, repeat=3, number=3))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    output: type=float32, shape=(1, 1000)
    {'average': 0.015177655667583978,
     'deviation': 0.00010602954223327644,
     'max_exec': 0.01531660669328024,
     'min_exec': 0.015059365658089519,
     'number': 3,
     'repeat': 3}

.. GENERATED FROM PYTHON SOURCE LINES 143-147

Parallelization
+++++++++++++++

We define a number of threads lower than the number of cores.

.. GENERATED FROM PYTHON SOURCE LINES 147-158

.. code-block:: default

    n_threads = min(4, multiprocessing.cpu_count() - 1)
    print(f"n_threads={n_threads}")

    if model_name == "gpt2.onnx":
        imgs = [x["input_ids"].numpy()
                for x in encoded_tensors[:maxN * n_threads]]
    else:
        imgs = [numpy.random.rand(*input_shape).astype(numpy.float32)
                for i in range(maxN * n_threads)]

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    n_threads=4

.. GENERATED FROM PYTHON SOURCE LINES 159-161

Let's create an `InferenceSession` object for every thread, assuming the
available memory can hold that many sessions.

.. GENERATED FROM PYTHON SOURCE LINES 161-165

.. code-block:: default

    sesss = [InferenceSession(model_name, providers=["CPUExecutionProvider"])
             for i in range(n_threads)]

.. GENERATED FROM PYTHON SOURCE LINES 166-167

Let's measure the time for a sequence of images.

.. GENERATED FROM PYTHON SOURCE LINES 167-180

.. code-block:: default

    def sequence(sess, imgs):
        # A simple function going through all images one by one.
        res = []
        for img in imgs:
            res.append(sess.run(None, {input_name: img})[0])
        return res


    pprint(measure_time(lambda: sequence(sesss[0], imgs[:n_threads]),
                        div_by_number=True, repeat=2, number=2))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'average': 0.06002907722722739,
     'deviation': 0.0001825352665035432,
     'max_exec': 0.06021161249373108,
     'min_exec': 0.0598465419607237,
     'number': 2,
     'repeat': 2}
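As a side note, the same distribution of images over threads could be
written with ``concurrent.futures.ThreadPoolExecutor`` from the standard
library instead of the handwritten ``threading.Thread`` subclass used
below. A minimal sketch reusing the ``sequence`` function defined above
(this variant is not the one benchmarked):

.. code-block:: python

    from concurrent.futures import ThreadPoolExecutor


    def parallel_pool(sesss, imgs):
        # one worker per session, images distributed round-robin
        with ThreadPoolExecutor(max_workers=len(sesss)) as pool:
            futures = [pool.submit(sequence, sess, imgs[i::len(sesss)])
                       for i, sess in enumerate(sesss)]
            res = []
            for f in futures:
                res.extend(f.result())
        return res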
.. GENERATED FROM PYTHON SOURCE LINES 181-182

And then with multithreading.

.. GENERATED FROM PYTHON SOURCE LINES 182-218

.. code-block:: default

    class MyThread(threading.Thread):

        def __init__(self, sess, imgs):
            threading.Thread.__init__(self)
            self.sess = sess
            self.imgs = imgs
            self.q = []

        def run(self):
            for img in self.imgs:
                r = self.sess.run(None, {input_name: img})[0]
                self.q.append(r)


    def parallel(sesss, imgs):
        # creation of the threads
        n_threads = len(sesss)
        threads = [MyThread(sess, imgs[i::n_threads])
                   for i, sess in enumerate(sesss)]
        # start the threads
        for t in threads:
            t.start()
        # wait for each of them and store the results
        res = []
        for t in threads:
            t.join()
            res.extend(t.q)
        return res


    pprint(measure_time(lambda: parallel(sesss, imgs[:n_threads]),
                        div_by_number=True, repeat=2, number=2))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'average': 0.2536951872461941,
     'deviation': 0.032998464273987296,
     'max_exec': 0.28669365152018145,
     'min_exec': 0.22069672297220677,
     'number': 2,
     'repeat': 2}

.. GENERATED FROM PYTHON SOURCE LINES 219-225

It is worse, as expected: this code parallelizes the execution of
onnxruntime, which already parallelizes the execution of every matrix
multiplication and every tensor operator. It is like using two conflicting
parallelization strategies at the same time. Let's check for a different
number of images to parallelize.

.. GENERATED FROM PYTHON SOURCE LINES 225-269

.. code-block:: default

    def benchmark(fcts, sesss, imgs, stepN=1, repN=4):
        data = []
        nth = len(sesss)
        Ns = [1] + list(range(nth, len(imgs), stepN * nth))
        for N in tqdm.tqdm(Ns):
            obs = {'n_imgs': len(imgs), 'maxN': maxN, 'stepN': stepN,
                   'repN': repN, 'batch_size': N, 'n_threads': len(sesss)}
            ns = []
            for name, fct, index in fcts:
                for i in range(repN):
                    if index is None:
                        r = fct(sesss, imgs[:N])
                    else:
                        r = fct(sesss[index], imgs[:N])
                    if i == 0:
                        # let's get rid of the first iteration, sometimes
                        # used to initialize internal objects
                        begin = time.perf_counter()
                end = (time.perf_counter() - begin) / (repN - 1)
                obs.update({f"n_imgs_{name}": len(r), f"time_{name}": end})
                ns.append(len(r))
            if len(set(ns)) != 1:
                raise RuntimeError(
                    f"Cannot compare experiments as they return different "
                    f"numbers of results ns={ns}, obs={obs}.")
            data.append(obs)

        return pandas.DataFrame(data)


    if not big_model:
        print(f"ORT // CPU, n_threads={len(sesss)}")
        df = benchmark(sesss=sesss, imgs=imgs, stepN=stepN, repN=repN,
                       fcts=[('sequence', sequence, 0),
                             ('parallel', parallel, None)])
        df.reset_index(drop=False).to_csv("ort_cpu.csv", index=False)
    else:
        print("ORT // CPU skipped for a big model.")
        df = None
    df

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ORT // CPU, n_threads=4
        n_imgs  maxN  stepN  repN  batch_size  n_threads  n_imgs_sequence  time_sequence  n_imgs_parallel  time_parallel
    0       84    21      2     4           1          4                1       0.022426                1       0.017216
    1       84    21      2     4           4          4                4       0.060762                4       0.311757
    2       84    21      2     4          12          4               12       0.181081               12       0.712939
    3       84    21      2     4          20          4               20       0.302038               20       1.249855
    4       84    21      2     4          28          4               28       0.422885               28       1.759554
    5       84    21      2     4          36          4               36       0.546499               36       2.179846
    6       84    21      2     4          44          4               44       0.677457               44       2.778476
    7       84    21      2     4          52          4               52       0.798181               52       3.075618
    8       84    21      2     4          60          4               60       0.904038               60       3.540634
    9       84    21      2     4          68          4               68       1.029824               68       4.222446
    10      84    21      2     4          76          4               76       1.167068               76       4.758600


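The overhead can be quantified directly from the benchmark output. A short
check on the dataframe produced above (``df`` is the dataframe returned by
``benchmark``):

.. code-block:: python

    if df is not None:
        # values above 1 mean the threaded version is slower
        ratio = df["time_parallel"] / df["time_sequence"]
        print(ratio.describe())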
.. GENERATED FROM PYTHON SOURCE LINES 270-272

Plots
+++++

.. GENERATED FROM PYTHON SOURCE LINES 272-295

.. code-block:: default

    def make_plot(df, title):
        if df is None:
            return None

        if "n_threads" in df.columns:
            n_threads = list(set(df.n_threads))
            if len(n_threads) != 1:
                raise RuntimeError(f"n_threads={n_threads} must be unique.")
            index = "batch_size"
        else:
            n_threads = "?"
            index = "n_imgs_seq_cpu"

        kwargs = dict(title=f"{title}\nn_threads={n_threads}", logy=True)
        columns = [index] + [c for c in df.columns if c.startswith("time")]
        ax = df[columns].set_index(columns[0]).plot(**kwargs)
        ax.set_xlabel("batch size")
        ax.set_ylabel("seconds")
        return ax


    make_plot(df, "Time per image / batch size")

.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_001.png
   :alt: Time per image / batch size n_threads=[4]
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 296-300

As expected, it does not improve. It is like parallelizing with two
strategies, per kernel and per image, both trying to access all the
process cores at the same time. The time spent synchronizing is
significant.

.. GENERATED FROM PYTHON SOURCE LINES 302-306

Same with another API based on OrtValue
+++++++++++++++++++++++++++++++++++++++

See :epkg:`l-ortvalue-doc`.

.. GENERATED FROM PYTHON SOURCE LINES 306-363

.. code-block:: default

    def sequence_ort_value(sess, imgs):
        ort_device = get_ort_device_from_session(sess)
        res = []
        for img in imgs:
            ov = C_OrtValue.ortvalue_from_numpy(img, ort_device)
            out = sess._sess.run_with_ort_values(
                {input_name: ov}, [output_name], None)[0]
            res.append(out.numpy())
        return res


    class MyThreadOrtValue(threading.Thread):

        def __init__(self, sess, imgs):
            threading.Thread.__init__(self)
            self.sess = sess
            self.imgs = imgs
            self.q = []
            self.ort_device = get_ort_device_from_session(self.sess)

        def run(self):
            ort_device = self.ort_device
            sess = self.sess._sess
            q = self.q
            for img in self.imgs:
                ov = C_OrtValue.ortvalue_from_numpy(img, ort_device)
                out = sess.run_with_ort_values(
                    {input_name: ov}, [output_name], None)[0]
                q.append(out.numpy())


    def parallel_ort_value(sesss, imgs):
        n_threads = len(sesss)
        threads = [MyThreadOrtValue(sess, imgs[i::n_threads])
                   for i, sess in enumerate(sesss)]
        for t in threads:
            t.start()
        res = []
        for t in threads:
            t.join()
            res.extend(t.q)
        return res


    if not big_model:
        print(f"ORT // CPU (OrtValue), n_threads={len(sesss)}")
        df = benchmark(sesss=sesss, imgs=imgs, stepN=stepN, repN=repN,
                       fcts=[('sequence', sequence_ort_value, 0),
                             ('parallel', parallel_ort_value, None)])
        df.reset_index(drop=False).to_csv("ort_cpu_ortvalue.csv", index=False)
    else:
        print("ORT // CPU (OrtValue) skipped for a big model.")
        df = None
    df

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ORT // CPU (OrtValue), n_threads=4
        n_imgs  maxN  stepN  repN  batch_size  n_threads  n_imgs_sequence  time_sequence  n_imgs_parallel  time_parallel
    0       84    21      2     4           1          4                1       0.015207                1       0.016510
    1       84    21      2     4           4          4                4       0.060909                4       0.254680
    2       84    21      2     4          12          4               12       0.181104               12       0.724804
    3       84    21      2     4          20          4               20       0.303054               20       1.239013
    4       84    21      2     4          28          4               28       0.423197               28       1.731472
    5       84    21      2     4          36          4               36       0.543858               36       2.162558
    6       84    21      2     4          44          4               44       0.668846               44       2.937713
    7       84    21      2     4          52          4               52       0.803038               52       3.233586
    8       84    21      2     4          60          4               60       0.910624               60       3.910219
    9       84    21      2     4          68          4               68       1.030855               68       4.354784
    10      84    21      2     4          76          4               76       1.152432               76       5.026625


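Another way to reduce the conversions between numpy and onnxruntime is the
IO binding API, which binds inputs and outputs once before running. This is
a minimal sketch with the session ``sess1`` and the image ``rnd_img``
created at the beginning of the example; the benchmark itself does not use
it:

.. code-block:: python

    # bind the input and the output once, run, then copy the result back
    io = sess1.io_binding()
    io.bind_cpu_input(input_name, rnd_img)
    io.bind_output(output_name)
    sess1.run_with_iobinding(io)
    result = io.copy_outputs_to_cpu()[0]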
.. GENERATED FROM PYTHON SOURCE LINES 364-365

Plots.

.. GENERATED FROM PYTHON SOURCE LINES 365-368

.. code-block:: default

    make_plot(df, "Time per image / batch size\nrun_with_ort_values")

.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_002.png
   :alt: Time per image / batch size run_with_ort_values n_threads=[4]
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 369-373

It leads to the same conclusion. There is no point parallelizing on CPU
with threads since onnxruntime already parallelizes every kernel.
Let's free the memory to make some space for other experiments.

.. GENERATED FROM PYTHON SOURCE LINES 373-377

.. code-block:: default

    del sesss[:]
    gc.collect()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    9648

.. GENERATED FROM PYTHON SOURCE LINES 378-382

GPU
===

Let's first check whether a GPU is available.

.. GENERATED FROM PYTHON SOURCE LINES 382-396

.. code-block:: default

    has_cuda = "CUDAExecutionProvider" in get_all_providers()
    if not has_cuda:
        print(f"No CUDA provider was detected in {get_all_providers()}.")

    n_gpus = torch.cuda.device_count() if has_cuda else 0
    if n_gpus == 0:
        print("No GPU was detected.")
    elif n_gpus == 1:
        print("1 GPU was detected.")
    else:
        print(f"{n_gpus} GPUs were detected.")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No GPU was detected.

.. GENERATED FROM PYTHON SOURCE LINES 397-399

Parallelization GPU + CPU
+++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 399-427

.. code-block:: default

    if has_cuda and n_gpus > 0:
        print("ORT // CPU + GPU")
        repN = 4
        sesss = [
            InferenceSession(model_name, providers=["CPUExecutionProvider"]),
            InferenceSession(model_name, providers=["CUDAExecutionProvider",
                                                    "CPUExecutionProvider"])]
        if model_name == "gpt2.onnx":
            imgs = [x["input_ids"].numpy()
                    for x in encoded_tensors[:maxN * len(sesss)]]
        else:
            imgs = [numpy.random.rand(*input_shape).astype(numpy.float32)
                    for i in range(maxN * len(sesss))]

        df = benchmark(sesss=sesss, imgs=imgs, stepN=stepN, repN=repN,
                       fcts=[('seq_cpu', sequence_ort_value, 0),
                             ('seq_gpu', sequence_ort_value, 1),
                             ('parallel', parallel_ort_value, None)])
        df.reset_index(drop=False).to_csv("ort_cpu_gpu.csv", index=False)
        del sesss[:]
        gc.collect()
    else:
        print("No GPU is available but data should be like the following.")
        df = pandas.read_csv("data/ort_cpu_gpu.csv")

    df

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No GPU is available but data should be like the following.
        index  n_imgs  maxN  stepN  repN  batch_size  n_threads  n_imgs_seq_cpu  time_seq_cpu  n_imgs_seq_gpu  time_seq_gpu  n_imgs_parallel  time_parallel
    0       0      42    21      2     4           1          2               1      0.002749               1      0.001918                1       0.003333
    1       1      42    21      2     4           2          2               2      0.010106               2      0.003670                2       0.006719
    2       2      42    21      2     4           6          2               6      0.019608               6      0.010759                6       0.014681
    3       3      42    21      2     4          10          2              10      0.033824              10      0.018202               10       0.031529
    4       4      42    21      2     4          14          2              14      0.039911              14      0.031874               14       0.032424
    5       5      42    21      2     4          18          2              18      0.063273              18      0.041411               18       0.040567
    6       6      42    21      2     4          22          2              22      0.059806              22      0.049939               22       0.049083
    7       7      42    21      2     4          26          2              26      0.075591              26      0.065178               26       0.060896
    8       8      42    21      2     4          30          2              30      0.083746              30      0.070518               30       0.070927
    9       9      42    21      2     4          34          2              34      0.097336              34      0.078067               34       0.085081
    10     10      42    21      2     4          38          2              38      0.124420              38      0.072007               38       0.077388


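The round-robin split gives both devices the same number of images even
though the GPU is faster. A possible refinement, not benchmarked here, is
to split the images proportionally to the measured sequential speed.
``split_by_speed`` is a hypothetical helper written for illustration:

.. code-block:: python

    def split_by_speed(imgs, t_cpu, t_gpu):
        # the faster device gets the bigger share of the images
        share_gpu = t_cpu / (t_cpu + t_gpu)
        n_gpu = int(len(imgs) * share_gpu)
        return imgs[:n_gpu], imgs[n_gpu:]


    # with the last row of the table above: 0.124420 (CPU), 0.072007 (GPU)
    gpu_imgs, cpu_imgs = split_by_speed(list(range(38)), 0.124420, 0.072007)
    print(len(gpu_imgs), len(cpu_imgs))  # 24 images on GPU, 14 on CPU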
.. GENERATED FROM PYTHON SOURCE LINES 428-429

Plots.

.. GENERATED FROM PYTHON SOURCE LINES 429-433

.. code-block:: default

    ax = make_plot(df, "Time per image / batch size\nCPU + GPU")
    ax

.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_003.png
   :alt: Time per image / batch size CPU + GPU n_threads=[2]
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 434-436

The parallelization over CPU and GPU together works: it is faster than the
CPU alone, but still slower than a single GPU in this case.

.. GENERATED FROM PYTHON SOURCE LINES 438-442

Parallelization on multiple GPUs
++++++++++++++++++++++++++++++++

This is the only case for which parallelization should pay off, as every
GPU is independent.

.. GENERATED FROM PYTHON SOURCE LINES 442-473

.. code-block:: default

    if n_gpus > 1:
        print("ORT // GPUs")
        sesss = []
        for i in range(n_gpus):
            print(f"Initialize device {i}")
            sesss.append(
                InferenceSession(model_name,
                                 providers=["CUDAExecutionProvider",
                                            "CPUExecutionProvider"],
                                 provider_options=[{"device_id": i}, {}]))
        if model_name == "gpt2.onnx":
            imgs = [x["input_ids"].numpy()
                    for x in encoded_tensors[:maxN * len(sesss)]]
        else:
            imgs = [numpy.random.rand(*input_shape).astype(numpy.float32)
                    for i in range(maxN * len(sesss))]

        df = benchmark(sesss=sesss, imgs=imgs, stepN=stepN, repN=repN,
                       fcts=[('sequence', sequence_ort_value, 0),
                             ('parallel', parallel_ort_value, None)])
        df.reset_index(drop=False).to_csv("ort_gpus.csv", index=False)
        del sesss[:]
        gc.collect()
    else:
        print("No GPU is available but data should be like the following.")
        df = pandas.read_csv("data/ort_gpus.csv")

    df

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No GPU is available but data should be like the following.
        index  n_imgs  maxN  stepN  repN  batch_size  n_threads  n_imgs_sequence  time_sequence  n_imgs_parallel  time_parallel
    0       0      84    21      2     4           1          4                1       0.002126                1       0.002219
    1       1      84    21      2     4           4          4                4       0.007492                4       0.003118
    2       2      84    21      2     4          12          4               12       0.023209               12       0.008247
    3       3      84    21      2     4          20          4               20       0.040625               20       0.012640
    4       4      84    21      2     4          28          4               28       0.054286               28       0.017552
    5       5      84    21      2     4          36          4               36       0.069855               36       0.022314
    6       6      84    21      2     4          44          4               44       0.085468               44       0.027169
    7       7      84    21      2     4          52          4               52       0.100239               52       0.028660
    8       8      84    21      2     4          60          4               60       0.114419               60       0.035442
    9       9      84    21      2     4          68          4               68       0.131548               68       0.037772
    10     10      84    21      2     4          76          4               76       0.149430               76       0.047436


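With several GPUs, the inputs could also be materialized on each device
beforehand with the public ``OrtValue`` API, so that the host-to-device
copy is not part of the measured loop. A sketch under the assumption that
session ``i`` runs on ``device_id=i``:

.. code-block:: python

    from onnxruntime import OrtValue

    if n_gpus > 1:
        # copy each image once onto the device that will consume it
        ort_imgs = [OrtValue.ortvalue_from_numpy(img, "cuda", i % n_gpus)
                    for i, img in enumerate(imgs)]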
.. GENERATED FROM PYTHON SOURCE LINES 474-475

Plots.

.. GENERATED FROM PYTHON SOURCE LINES 475-479

.. code-block:: default

    ax = make_plot(df, "Time per image / batch size\n4 GPUs")
    ax

.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_004.png
   :alt: Time per image / batch size 4 GPUs n_threads=[4]
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 480-483

The parallelization on multiple GPUs did work. With a model such as
`GPT2 `_, it would give the following results.

.. GENERATED FROM PYTHON SOURCE LINES 483-492

.. code-block:: default

    data = pandas.read_csv("data/ort_gpus_gpt2.csv")
    df = pandas.DataFrame(data)
    ax = make_plot(df, "Time per image / batch size\n4 GPUs - GPT2")
    ax

    # import matplotlib.pyplot as plt
    # plt.show()

.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_005.png
   :alt: Time per image / batch size 4 GPUs - GPT2 n_threads=[4]
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 4 minutes 18.247 seconds)

.. _sphx_glr_download_gyexamples_plot_parallel_execution.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_parallel_execution.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_parallel_execution.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_