Compares implementations of Einsum

This example compares the performance of numpy.einsum(), onnxruntime's Einsum operator, torch.einsum() and a decomposition of einsum into standard tensor operations (executed with onnxruntime and with torch) for a couple of equations.

Available optimisation

The code shows which optimisation (AVX or SSE) is used by the custom implementation, and the number of available processors, which is also the default number of threads used for parallelization.

import numpy
import pandas
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxEinsum
from cpyquickhelper.numbers import measure_time
from tqdm import tqdm
from mlprodict.testing.experimental_c_impl.experimental_c import (
    code_optimisation)
from mlprodict.testing.einsum.einsum_fct import _einsum
from mlprodict.plotting.plotting_onnx import plot_onnx
from deeponnxcustom.onnxtorch.tchrun import OnnxTorchRuntime



Einsum: common code

The main function, which benchmarks a couple of options.

    from torch import einsum as torch_einsum, from_numpy
except ImportError:
    torch_einsum = None

def build_ort_einsum(equation, op_version=14):  # opset=13, 14, ...
    """Return a callable that evaluates *equation* with onnxruntime's Einsum.

    A single-node ONNX graph ``z = Einsum(x, y)`` is built with skl2onnx for
    float32 inputs, then loaded into an :class:`InferenceSession`.

    :param equation: einsum equation with two operands
    :param op_version: opset used for the node and for the model
    :return: ``f(x, y)`` returning the output list of ``sess.run``
    """
    node = OnnxEinsum('x', 'y', equation=equation,
                      op_version=op_version,
                      output_names=['z'])
    onx = node.to_onnx(inputs=[('x', FloatTensorType()),
                               ('y', FloatTensorType())],
                       target_opset=op_version)
    # Pin the CPU provider: since onnxruntime 1.9, builds with several
    # execution providers refuse to create a session without ``providers``.
    sess = InferenceSession(onx.SerializeToString(),
                            providers=["CPUExecutionProvider"])
    return lambda x, y: sess.run(None, {'x': x, 'y': y})

def build_ort_decomposed(equation, op_version=14):  # opset=13, 14, ...
    """Return the decomposed ONNX graph of *equation* and an onnxruntime runner.

    ``_einsum`` rewrites the equation into standard ONNX operators for
    float32 inputs. The resulting graph (inputs ``X0`` and ``X1``) is loaded
    into an :class:`InferenceSession`.

    :param equation: einsum equation with two operands
    :param op_version: opset of the decomposed graph
    :return: ``(onnx_model, f)`` where ``f(x, y)`` returns the output list
        of ``sess.run``
    """
    cache = _einsum(equation, numpy.float32, opset=op_version,
                    optimize=True, verbose=True, runtime="python")
    if not hasattr(cache, 'onnx_'):
        # ``onnx_`` only exists once the graph has been built.
        cache.build()
    # Same provider pinning as for the plain Einsum session.
    sess = InferenceSession(cache.onnx_.SerializeToString(),
                            providers=["CPUExecutionProvider"])
    return cache.onnx_, lambda x, y: sess.run(None, {'X0': x, 'X1': y})

def build_torch_decomposed(equation, op_version=14):  # opset=13, 14, ...
    """Return the decomposed ONNX graph of *equation* and a torch runner.

    The graph produced by ``_einsum`` is executed by ``OnnxTorchRuntime``.
    Its inputs are passed positionally, presumably as torch tensors (see the
    callers).

    :param equation: einsum equation with two operands
    :param op_version: opset of the decomposed graph
    :return: ``(onnx_model, f)`` where ``f(x, y)`` returns ``sess.run(x, y)``
    """
    cache = _einsum(equation, numpy.float32, opset=op_version,
                    optimize=True, verbose=True, runtime="python")
    if not hasattr(cache, 'onnx_'):
        # ``onnx_`` only exists once the graph has been built.
        cache.build()
    sess = OnnxTorchRuntime(cache.onnx_)
    return cache.onnx_, lambda x, y: sess.run(x, y)

def loop_einsum_eq(fct, equation, xs, ys):
    """Call ``fct(equation, x, y)`` for every pair drawn from *xs* and *ys*.

    Pairs are built with :func:`zip`, so the loop stops at the shorter
    sequence. Results are discarded because only the calls are timed.
    """
    for pair in zip(xs, ys):
        fct(equation, *pair)

def loop_einsum_eq_th(fct, equation, xs, ys):
    """Call ``fct(equation, x, y, nthread=-1)`` for every pair of *xs*, *ys*.

    Pairs are built with :func:`zip`. The meaning of ``nthread=-1`` is up to
    *fct* (presumably "use every available thread").
    """
    thread_options = {'nthread': -1}
    for pair in zip(xs, ys):
        fct(equation, *pair, **thread_options)

def loop_einsum(fct, xs, ys):
    """Call ``fct(x, y)`` for every pair drawn from *xs* and *ys*.

    Used for runtimes whose equation is already baked into *fct*. Results
    are discarded.
    """
    for pair in zip(xs, ys):
        fct(*pair)

def _replace_dim(shape, pos, dim):
    """Return *shape* with the axis at *pos* replaced by *dim* (as a tuple).

    When *pos* is -1 (the varied letter does not appear in the operand),
    *shape* itself is returned unchanged.
    """
    if pos == -1:
        return shape
    new_shape = list(shape)
    new_shape[pos] = dim
    return tuple(new_shape)


def _measure(stmt, label, dim, ctx, repeat, number):
    """Time *stmt* with ``measure_time`` and tag the result with *dim*/*label*.

    :return: the dictionary returned by ``measure_time`` plus the keys
        ``'dim'`` and ``'fct'``
    """
    obs = measure_time(stmt, div_by_number=True, context=ctx,
                       repeat=repeat, number=number)
    obs['dim'] = dim
    obs['fct'] = label
    return obs


def benchmark_equation(equation, number=5, repeat=3):
    """Benchmark several implementations of a two-operand einsum *equation*.

    Implementations compared for every tested size:

    * ``numpy.einsum``: numpy.einsum
    * ``ort_einsum``: onnxruntime with the Einsum operator
    * ``ort_dec``: onnxruntime running the decomposed graph
    * ``torch_dec``: the decomposed graph run by OnnxTorchRuntime
    * ``torch_einsum``: torch.einsum

    The two torch entries are skipped when torch is not installed.

    Every letter of the equation is mapped to a fixed size (``SIZE_MAP``).
    Only the size of letter ``e`` varies, or of ``k`` when the equation has
    no ``e``. Five random float32 pairs are generated for every size.

    :param equation: einsum equation with exactly two operands, written with
        the letters ``k, s, e, c, m``
    :param number: ``number`` argument given to ``measure_time``
    :param repeat: ``repeat`` argument given to ``measure_time``
    :return: ``(df, rs, ax, einsum_onnx)``, where:

        * ``df`` holds the raw measurements, one row per implementation
          and size
        * ``rs`` holds the speedups relative to numpy (index: size,
          one column per implementation)
        * ``ax`` holds the two matplotlib axes
        * ``einsum_onnx`` is the ONNX model of the decomposed equation
    """
    # Implementations under test.
    ort_einsum = build_ort_einsum(equation)
    einsum_onnx, ort_einsum_decomposed = build_ort_decomposed(equation)
    ort_torch_decomposed = None
    if torch_einsum is not None:
        _, ort_torch_decomposed = build_torch_decomposed(equation)

    # Size associated with every letter of the equation.
    K, S, M, E = 16, 1024, 768, 64
    C = S // E * 2
    SIZE_MAP = {'K': K, 'S': S, 'E': E, 'C': C, 'M': M}

    terms = equation.split('->')[0].split(',')
    lhs_shape = [SIZE_MAP[c.upper()] for c in terms[0]]
    rhs_shape = [SIZE_MAP[c.upper()] for c in terms[1]]

    # Axis of each operand whose size varies during the benchmark,
    # -1 when the operand does not contain the varied letter.
    varied = 'e' if 'e' in equation else 'k'
    pos_left = terms[0].find(varied)
    pos_right = terms[1].find(varied)

    def left_dim(dim):
        return _replace_dim(lhs_shape, pos_left, dim)

    def right_dim(dim):
        return _replace_dim(rhs_shape, pos_right, dim)

    sizes = [8, 16, 32, 64, 128, 256]
    if max(len(rhs_shape), len(lhs_shape)) >= 3:
        # Operands with 3 or more dimensions: only the smaller sizes.
        sizes = sizes[:4]

    res = []
    for dim in tqdm(sizes):
        xs = [numpy.random.rand(*left_dim(dim)).astype(numpy.float32)
              for _ in range(5)]
        ys = [numpy.random.rand(*right_dim(dim)).astype(numpy.float32)
              for _ in range(5)]

        ctx = dict(equation=equation, xs=xs, ys=ys, einsum=numpy.einsum,
                   loop_einsum=loop_einsum, loop_einsum_eq=loop_einsum_eq,
                   loop_einsum_eq_th=loop_einsum_eq_th)

        # numpy
        res.append(_measure(
            lambda: loop_einsum_eq(numpy.einsum, equation, xs, ys),
            'numpy.einsum', dim, ctx, repeat, number))

        # onnxruntime, Einsum operator
        ctx['einsum'] = ort_einsum
        res.append(_measure(
            lambda: loop_einsum(ort_einsum, xs, ys),
            'ort_einsum', dim, ctx, repeat, number))

        # onnxruntime, decomposed graph
        ctx['einsum'] = ort_einsum_decomposed
        res.append(_measure(
            lambda: loop_einsum(ort_einsum_decomposed, xs, ys),
            'ort_dec', dim, ctx, repeat, number))

        if torch_einsum is None:
            # torch is not installed: from_numpy and the torch runtime
            # are unavailable.
            continue

        # The torch implementations receive tensors built from the same arrays.
        txs = [from_numpy(x) for x in xs]
        tys = [from_numpy(y) for y in ys]
        ctx['xs'] = txs
        ctx['ys'] = tys

        # torch, decomposed graph
        ctx['einsum'] = ort_torch_decomposed
        res.append(_measure(
            lambda: loop_einsum(ort_torch_decomposed, txs, tys),
            'torch_dec', dim, ctx, repeat, number))

        # torch.einsum
        ctx['einsum'] = torch_einsum
        res.append(_measure(
            lambda: loop_einsum_eq(torch_einsum, equation, txs, tys),
            'torch_einsum', dim, ctx, repeat, number))

    # Dataframes. Keyword arguments: DataFrame.pivot is keyword-only
    # since pandas 2.0.
    df = pandas.DataFrame(res)
    piv = df.pivot(index='dim', columns='fct', values='average')

    # Speedup relative to numpy (> 1 means faster than numpy).
    rs = piv.copy()
    for col in ('ort_einsum', 'ort_dec', 'torch_einsum', 'torch_dec'):
        if col in rs.columns:
            rs[col] = rs['numpy.einsum'] / rs[col]
    rs['numpy.einsum'] = 1.

    # Graphs.
    shapes = ("%s - %s" % (left_dim('N'), right_dim('N'))).replace("'", "")
    _, ax = plt.subplots(1, 2, figsize=(14, 5))
    piv.plot(logx=True, logy=True, ax=ax[0],
             title="Einsum benchmark\n%s -- %s"
                   " lower better" % (shapes, equation))
    ax[0].legend(prop={"size": 9})
    rs.plot(logx=True, logy=True, ax=ax[1],
            title="Einsum Speedup, baseline=numpy\n%s -- %s"
                  " higher better" % (shapes, equation))
    ax[1].plot([min(rs.index), max(rs.index)], [0.5, 0.5], 'g--')
    ax[1].plot([min(rs.index), max(rs.index)], [2., 2.], 'g--')
    ax[1].legend(prop={"size": 9})

    return df, rs, ax, einsum_onnx

A last function saves an equation's ONNX graph to disk and plots it.

def plot_onnx_einsum(equation, onx):
    """Save *onx* to an ``.onnx`` file and draw its graph.

    The file name is derived from *equation*: ``,`` becomes ``_`` and ``->``
    becomes ``__``. For example, ``s,se->se`` is saved as
    ``einsum_eq_s_se__se.onnx``.

    :param equation: einsum equation the model implements
    :param onx: ONNX model (must provide ``SerializeToString``)
    :return: the matplotlib axis returned by ``plot_onnx``
    """
    filename = "einsum_eq_%s.onnx" % (
        equation.replace(",", "_").replace("->", "__"))
    with open(filename, "wb") as f:
        f.write(onx.SerializeToString())
    _, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax = plot_onnx(onx, ax=ax)
    return ax

First equation: s,se->se

# ``dfs`` accumulates the raw measurements of every equation benchmarked
# below. It is concatenated and saved at the end of the script.
dfs = []
equation = "s,se->se"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark [1024] - (1024, N) -- s,se->se lower better, Einsum Speedup, baseline=numpy [1024] - (1024, N) -- s,se->se higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.000247 0.001273 0.002432 0.034922 0.000516
16 0.000275 0.001247 0.001946 0.032480 0.000643
32 0.000367 0.001906 0.002314 0.036192 0.000892
64 0.000572 0.004570 0.003370 0.035161 0.035208
128 0.001403 0.004497 0.005505 0.036485 0.033636
256 0.003661 0.006020 0.007247 0.037037 0.012131

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)



Second equation: se,sc->sec

# Keep accumulating into ``dfs``: resetting it here would drop the
# measurements of the previous equations from the final summary.
equation = "se,sc->sec"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N) - [1024, 32] -- se,sc->sec lower better, Einsum Speedup, baseline=numpy (1024, N) - [1024, 32] -- se,sc->sec higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.002855 0.005619 0.009149 0.035531 0.013773
16 0.005955 0.008737 0.013332 0.035361 0.036852
32 0.010595 0.016909 0.018995 0.037060 0.036418
64 0.021972 0.031209 0.034387 0.038537 0.042747
128 0.042581 0.056879 0.066103 0.043034 0.051790
256 0.123119 0.168853 0.184204 0.080065 0.089179

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)



Third equation: se,se->s

# Keep accumulating into ``dfs``: resetting it here would drop the
# measurements of the previous equations from the final summary.
equation = "se,se->s"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N) - (1024, N) -- se,se->s lower better, Einsum Speedup, baseline=numpy (1024, N) - (1024, N) -- se,se->s higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.000249 0.002874 0.006535 0.036733 0.034544
16 0.000279 0.004695 0.011164 0.035590 0.034840
32 0.000351 0.008317 0.020396 0.031704 0.034645
64 0.000580 0.016016 0.039005 0.033084 0.035106
128 0.001326 0.030395 0.076134 0.035560 0.035343
256 0.002155 0.059731 0.149860 0.036302 0.036010

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)



Fourth equation: ks,ksm->sm

# Keep accumulating into ``dfs``: resetting it here would drop the
# measurements of the previous equations from the final summary.
equation = "ks,ksm->sm"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (N, 1024) - (N, 1024, 768) -- ks,ksm->sm lower better, Einsum Speedup, baseline=numpy (N, 1024) - (N, 1024, 768) -- ks,ksm->sm higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.046381 0.075154 0.116060 0.150669 0.141059
16 0.087337 0.175303 0.213492 0.218604 0.208969
32 0.169449 1.702189 0.408497 8.416296 2.774563
64 0.350684 3.124925 0.826111 5.511207 11.133901

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)



Fifth equation: sec,sm->ecm

# Keep accumulating into ``dfs``: resetting it here would drop the
# measurements of the previous equations from the final summary.
equation = "sec,sm->ecm"
# A single run per measurement: each one can take seconds here.
df, piv, ax, onx = benchmark_equation(equation, number=1, repeat=1)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N, 32) - [1024, 768] -- sec,sm->ecm lower better, Einsum Speedup, baseline=numpy (1024, N, 32) - [1024, 768] -- sec,sm->ecm higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.491584 0.187608 0.100328 0.167776 0.121127
16 1.758361 0.294181 0.309701 0.209646 0.179711
32 4.110958 0.516717 0.401818 0.288493 0.259152
64 8.159549 0.880178 0.801697 0.500670 0.467781

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)



Sixth equation: sec,ecm->sm

# Keep accumulating into ``dfs``: resetting it here would drop the
# measurements of the previous equations from the final summary.
equation = "sec,ecm->sm"
# A single run per measurement: each one can take seconds here.
df, piv, ax, onx = benchmark_equation(equation, number=1, repeat=1)
dfs.append(df)
# Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N, 32) - (N, 32, 768) -- sec,ecm->sm lower better, Einsum Speedup, baseline=numpy (1024, N, 32) - (N, 32, 768) -- sec,ecm->sm higher better


fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
8 0.473223 0.144362 0.097199 0.189313 0.180613
16 1.468553 0.229524 0.180112 0.217261 0.234583
32 3.082025 0.427721 0.350729 0.284155 0.316785
64 6.216536 1.275591 0.712646 0.495009 0.505629

The onnx decomposition.

# Draw the ONNX graph of the decomposed equation.
plot_onnx_einsum(equation=equation, onx=onx)




PyTorch seems quite efficient on these examples.

# Gather the measurements of every benchmarked equation into one table and
# save it (CSV and Excel), together with the current figure.
# pandas.concat raises ValueError on an empty list, so fall back to an
# empty table when nothing was collected.
merged = pandas.concat(dfs, ignore_index=True) if dfs else pandas.DataFrame()
name = "einsum"
merged.to_csv("plot_%s.csv" % name, index=False)
merged.to_excel("plot_%s.xlsx" % name, index=False)
plt.savefig("plot_%s.png" % name)

plot op einsum

Total running time of the script: (10 minutes 55.111 seconds)

