Compares implementations of Einsum#

This example compares the performance of numpy.einsum(), torch.einsum() and its decomposition into standard vector operations for a couple of equations.

Available optimisation#

The code below prints which optimisation the custom implementation was compiled with (AVX or SSE) and the number of available processors, which is also the default number of threads used for parallelization.

import numpy
import pandas
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxEinsum
from cpyquickhelper.numbers import measure_time
from tqdm import tqdm
from mlprodict.testing.experimental_c_impl.experimental_c import (
    code_optimisation)
from mlprodict.testing.einsum.einsum_fct import _einsum
from mlprodict.plotting.plotting_onnx import plot_onnx
from deeponnxcustom.onnxtorch.tchrun import OnnxTorchRuntime
# Display the optimisation string returned by the experimental C extension.
print(code_optimisation())

Out:

AVX-omp=8

Einsum: common code#

The main function benchmarks a couple of implementations for a given equation.

# torch is optional: both imported names fall back to None so that the rest
# of the script can test for availability instead of hitting a NameError
# on ``from_numpy``.
try:
    from torch import einsum as torch_einsum, from_numpy
except ImportError:
    torch_einsum = None
    from_numpy = None


def build_ort_einsum(equation, op_version=14):  # opset=13, 14, ...
    """
    Builds an ONNX graph made of a single *Einsum* node and wraps it
    into a callable executed by onnxruntime.

    :param equation: einsum equation with two operands
    :param op_version: opset used for the node and as the target opset
    :return: function ``f(x, y)`` returning the output of
        ``InferenceSession.run`` for inputs ``'x'`` and ``'y'``
    """
    input_types = [('x', FloatTensorType()), ('y', FloatTensorType())]
    einsum_node = OnnxEinsum(
        'x', 'y', equation=equation, op_version=op_version,
        output_names=['z'])
    model = einsum_node.to_onnx(inputs=input_types, target_opset=op_version)
    session = InferenceSession(model.SerializeToString())

    def run(x, y):
        return session.run(None, {'x': x, 'y': y})

    return run


def build_ort_decomposed(equation, op_version=14):  # opset=13, 14, ...
    """
    Decomposes an einsum equation with ``_einsum`` and runs the resulting
    ONNX graph with onnxruntime.

    :param equation: einsum equation with two operands
    :param op_version: opset given to ``_einsum``
    :return: tuple ``(onnx graph, f)`` where ``f(x, y)`` feeds *x* and *y*
        to the graph inputs ``'X0'`` and ``'X1'``
    """
    decomposed = _einsum(
        equation, numpy.float32, opset=op_version,
        optimize=True, verbose=True, runtime="python")
    # The ONNX graph is only built on demand.
    if not hasattr(decomposed, 'onnx_'):
        decomposed.build()
    session = InferenceSession(decomposed.onnx_.SerializeToString())

    def run(x, y):
        return session.run(None, {'X0': x, 'X1': y})

    return decomposed.onnx_, run


def build_torch_decomposed(equation, op_version=14):  # opset=13, 14, ...
    """
    Decomposes an einsum equation with ``_einsum`` and runs the resulting
    ONNX graph with ``OnnxTorchRuntime``.

    :param equation: einsum equation with two operands
    :param op_version: opset given to ``_einsum``
    :return: tuple ``(onnx graph, f)`` where ``f(x, y)`` calls
        ``OnnxTorchRuntime.run(x, y)``
    """
    decomposed = _einsum(
        equation, numpy.float32, opset=op_version,
        optimize=True, verbose=True, runtime="python")
    # The ONNX graph is only built on demand.
    if not hasattr(decomposed, 'onnx_'):
        decomposed.build()
    runtime = OnnxTorchRuntime(decomposed.onnx_)

    def run(x, y):
        return runtime.run(x, y)

    return decomposed.onnx_, run


def loop_einsum_eq(fct, equation, xs, ys):
    """
    Calls ``fct(equation, x, y)`` for every pair taken from *xs* and *ys*.

    :param fct: einsum-like function taking the equation first
    :param equation: einsum equation
    :param xs: left operands
    :param ys: right operands (paired with *xs* by ``zip``)
    """
    for operands in zip(xs, ys):
        fct(equation, *operands)


def loop_einsum_eq_th(fct, equation, xs, ys):
    """
    Calls ``fct(equation, x, y, nthread=-1)`` for every pair taken
    from *xs* and *ys*.

    :param fct: einsum-like function accepting a ``nthread`` keyword
    :param equation: einsum equation
    :param xs: left operands
    :param ys: right operands (paired with *xs* by ``zip``)
    """
    for operands in zip(xs, ys):
        fct(equation, *operands, nthread=-1)


def loop_einsum(fct, xs, ys):
    """
    Calls ``fct(x, y)`` for every pair taken from *xs* and *ys*.

    :param fct: function of two operands (the equation is already bound)
    :param xs: left operands
    :param ys: right operands (paired with *xs* by ``zip``)
    """
    for operands in zip(xs, ys):
        fct(*operands)


def benchmark_equation(equation, number=5, repeat=3):
    """
    Benchmarks several einsum implementations on one two-operand equation.

    The compared implementations are ``numpy.einsum``, the ONNX *Einsum*
    node run by onnxruntime, the decomposed ONNX graph run by onnxruntime
    and, when torch could be imported, the decomposed graph run by
    ``OnnxTorchRuntime`` and ``torch.einsum``. One dimension (the one
    labelled ``e`` if the equation contains that letter, ``k`` otherwise)
    takes the values of the benchmarked sizes; the others are fixed.

    :param equation: einsum equation such as ``"se,sc->sec"``, every index
        must be one of ``k, s, e, c, m``
    :param number: parameter *number* given to ``measure_time``
    :param repeat: parameter *repeat* given to ``measure_time``
    :return: tuple ``(df, rs, ax, einsum_onnx)``: the raw measures, the
        speedups relative to numpy, the matplotlib axes and the decomposed
        ONNX graph
    """
    # torch is optional; everything depending on it is skipped when missing.
    has_torch = torch_einsum is not None

    # equations
    ort_einsum = build_ort_einsum(equation)
    einsum_onnx, ort_einsum_decomposed = build_ort_decomposed(equation)
    if has_torch:
        _, ort_torch_decomposed = build_torch_decomposed(equation)

    K, S, M, E = 16, 1024, 768, 64
    C = S // E * 2
    SIZE_MAP = {'K': K, 'S': S, 'E': E, 'C': C, 'M': M}

    pos1 = equation.find(',')
    pos2 = equation.find('->')
    lhs_op = equation[0:pos1]
    rhs_op = equation[pos1 + 1:pos2]
    lhs_shape = [SIZE_MAP[c.upper()] for c in lhs_op]
    rhs_shape = [SIZE_MAP[c.upper()] for c in rhs_op]

    # Position of the dimension which varies in each operand
    # (-1 if the operand does not have it).
    terms = equation.split('->')[0].split(',')
    moving = 'e' if 'e' in equation else 'k'
    pos_left = terms[0].find(moving)
    pos_right = terms[1].find(moving)

    def left_dim(dim):
        # Returns the list unchanged when the operand has no moving
        # dimension; the plot titles rely on this representation.
        if pos_left == -1:
            return lhs_shape
        cp = list(lhs_shape)
        cp[pos_left] = dim
        return tuple(cp)

    def right_dim(dim):
        # Same logic as left_dim for the right operand.
        if pos_right == -1:
            return rhs_shape
        cp = list(rhs_shape)
        cp[pos_right] = dim
        return tuple(cp)

    sizes = [8, 16, 32, 64, 128, 256]
    if max(len(rhs_shape), len(lhs_shape)) >= 3:
        # Fewer sizes are tested when one operand has three dimensions
        # or more.
        sizes = sizes[:4]

    res = []

    def record(label, stmt, dim, ctx):
        # Measures *stmt* and stores the observation with its labels.
        obs = measure_time(
            stmt, div_by_number=True, context=ctx,
            repeat=repeat, number=number)
        obs['dim'] = dim
        obs['fct'] = label
        res.append(obs)

    for dim in tqdm(sizes):
        xs = [numpy.random.rand(*left_dim(dim)).astype(numpy.float32)
              for _ in range(5)]
        ys = [numpy.random.rand(*right_dim(dim)).astype(numpy.float32)
              for _ in range(5)]

        # numpy
        ctx = dict(equation=equation, xs=xs, ys=ys, einsum=numpy.einsum,
                   loop_einsum=loop_einsum, loop_einsum_eq=loop_einsum_eq,
                   loop_einsum_eq_th=loop_einsum_eq_th)
        record('numpy.einsum',
               lambda: loop_einsum_eq(numpy.einsum, equation, xs, ys),
               dim, ctx)

        # onnxruntime
        ctx['einsum'] = ort_einsum
        record('ort_einsum',
               lambda: loop_einsum(ort_einsum, xs, ys), dim, ctx)

        # onnxruntime decomposed
        ctx['einsum'] = ort_einsum_decomposed
        record('ort_dec',
               lambda: loop_einsum(ort_einsum_decomposed, xs, ys), dim, ctx)

        if has_torch:
            # Both torch benchmarks share the same converted tensors.
            ctx['xs'] = [from_numpy(x) for x in xs]
            ctx['ys'] = [from_numpy(y) for y in ys]

            # torch decomposed
            ctx['einsum'] = ort_torch_decomposed
            record('torch_dec',
                   lambda: loop_einsum(
                       ort_torch_decomposed, ctx['xs'], ctx['ys']),
                   dim, ctx)

            # torch
            ctx['einsum'] = torch_einsum
            record('torch_einsum',
                   lambda: loop_einsum_eq(
                       torch_einsum, equation, ctx['xs'], ctx['ys']),
                   dim, ctx)

    # Dataframes
    df = pandas.DataFrame(res)
    # Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
    piv = df.pivot(index='dim', columns='fct', values='average')

    # Speedup of every implementation compared to numpy.
    rs = piv.copy()
    rs['ort_einsum'] = rs['numpy.einsum'] / rs['ort_einsum']
    rs['ort_dec'] = rs['numpy.einsum'] / rs['ort_dec']
    if 'torch_einsum' in rs.columns:
        rs['torch_einsum'] = rs['numpy.einsum'] / rs['torch_einsum']
    if 'torch_dec' in rs.columns:
        rs['torch_dec'] = rs['numpy.einsum'] / rs['torch_dec']
    rs['numpy.einsum'] = 1.

    # Graphs.
    shapes = ("%s - %s" % (left_dim('N'), right_dim('N'))).replace("'", "")
    _, ax = plt.subplots(1, 2, figsize=(14, 5))
    piv.plot(logx=True, logy=True, ax=ax[0],
             title="Einsum benchmark\n%s -- %s"
                   " lower better" % (shapes, equation))
    ax[0].legend(prop={"size": 9})
    rs.plot(logx=True, logy=True, ax=ax[1],
            title="Einsum Speedup, baseline=numpy\n%s -- %s"
                  " higher better" % (shapes, equation))
    # Horizontal guides at half and twice the speed of numpy.
    ax[1].plot([min(rs.index), max(rs.index)], [0.5, 0.5], 'g--')
    ax[1].plot([min(rs.index), max(rs.index)], [2., 2.], 'g--')
    ax[1].legend(prop={"size": 9})

    return df, rs, ax, einsum_onnx

A last function to plot the ONNX graphs.

def plot_onnx_einsum(equation, onx):
    """
    Saves the ONNX graph on disk and draws it with ``plot_onnx``.

    :param equation: einsum equation, used for the file name and the title
    :param onx: ONNX graph exposing ``SerializeToString``
    :return: axis returned by ``plot_onnx``
    """
    suffix = equation.replace(",", "_").replace("->", "__")
    filename = "einsum_eq_%s.onnx" % suffix
    with open(filename, "wb") as stream:
        stream.write(onx.SerializeToString())
    _, canvas = plt.subplots(1, 1, figsize=(6, 6))
    drawn = plot_onnx(onx, ax=canvas)
    drawn.set_title(equation)
    return drawn

First equation: s,se->se#

# Results of every equation are accumulated in ``dfs``.
dfs = []
equation = "s,se->se"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark [1024] - (1024, N) -- s,se->se lower better, Einsum Speedup, baseline=numpy [1024] - (1024, N) -- s,se->se higher better

Out:

  0%|          | 0/3 [00:00<?, ?it/s]
0.0069 rtbest='s,se->se':   0%|          | 0/3 [00:00<?, ?it/s]
0.0069 rtbest='s,se->se':  33%|###3      | 1/3 [00:00<00:01,  1.32it/s]
0.0069 rtbest='s,se->se': 100%|##########| 3/3 [00:00<00:00,  3.75it/s]
somewhere/workspace/deeponnxcustom/deeponnxcustom_UT_39_std/_doc/sphinxdoc/source/deeponnxcustom/onnxtorch/tchrun.py:173: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:178.)
  res[init.name] = torch.from_numpy(  # pylint: disable=E1101

  0%|          | 0/6 [00:00<?, ?it/s]
 17%|#6        | 1/6 [00:00<00:02,  1.68it/s]
 33%|###3      | 2/6 [00:01<00:02,  1.75it/s]
 50%|#####     | 3/6 [00:01<00:01,  1.67it/s]
 67%|######6   | 4/6 [00:02<00:01,  1.20it/s]
 83%|########3 | 5/6 [00:04<00:00,  1.02it/s]
100%|##########| 6/6 [00:05<00:00,  1.00it/s]
100%|##########| 6/6 [00:05<00:00,  1.14it/s]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.000247 0.001273 0.002432 0.034922 0.000516
16 0.000275 0.001247 0.001946 0.032480 0.000643
32 0.000367 0.001906 0.002314 0.036192 0.000892
64 0.000572 0.004570 0.003370 0.035161 0.035208
128 0.001403 0.004497 0.005505 0.036485 0.033636
256 0.003661 0.006020 0.007247 0.037037 0.012131


The onnx decomposition.

plot_onnx_einsum(equation, onx)
s,se->se

Out:

<AxesSubplot:title={'center':'s,se->se'}>

Second equation: se,sc->sec#

# ``dfs`` is not reset here so that the final export keeps the results
# of the previous equations.
equation = "se,sc->sec"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N) - [1024, 32] -- se,sc->sec lower better, Einsum Speedup, baseline=numpy (1024, N) - [1024, 32] -- se,sc->sec higher better

Out:

  0%|          | 0/7 [00:00<?, ?it/s]
0.0083 rtbest='se,sc->sec':   0%|          | 0/7 [00:00<?, ?it/s]
0.0082 rtbest='es,ec->esc':   0%|          | 0/7 [00:00<?, ?it/s]
0.0082 rtbest='es,ec->esc':  57%|#####7    | 4/7 [00:00<00:00, 37.36it/s]
0.0081 rtbest='ec,es->ecs':  57%|#####7    | 4/7 [00:00<00:00, 37.36it/s]
0.0079 rtbest='cs,ce->cse':  57%|#####7    | 4/7 [00:00<00:00, 37.36it/s]
0.0079 rtbest='cs,ce->cse': 100%|##########| 7/7 [00:00<00:00, 38.92it/s]

  0%|          | 0/6 [00:00<?, ?it/s]
 17%|#6        | 1/6 [00:01<00:05,  1.01s/it]
 33%|###3      | 2/6 [00:02<00:05,  1.31s/it]
 50%|#####     | 3/6 [00:04<00:04,  1.54s/it]
 67%|######6   | 4/6 [00:06<00:03,  1.94s/it]
 83%|########3 | 5/6 [00:10<00:02,  2.66s/it]
100%|##########| 6/6 [00:20<00:00,  5.06s/it]
100%|##########| 6/6 [00:20<00:00,  3.42s/it]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.002855 0.005619 0.009149 0.035531 0.013773
16 0.005955 0.008737 0.013332 0.035361 0.036852
32 0.010595 0.016909 0.018995 0.037060 0.036418
64 0.021972 0.031209 0.034387 0.038537 0.042747
128 0.042581 0.056879 0.066103 0.043034 0.051790
256 0.123119 0.168853 0.184204 0.080065 0.089179


The onnx decomposition.

plot_onnx_einsum(equation, onx)
se,sc->sec

Out:

<AxesSubplot:title={'center':'se,sc->sec'}>

Third equation: se,se->s#

# ``dfs`` is not reset here so that the final export keeps the results
# of the previous equations.
equation = "se,se->s"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N) - (1024, N) -- se,se->s lower better, Einsum Speedup, baseline=numpy (1024, N) - (1024, N) -- se,se->s higher better

Out:

  0%|          | 0/3 [00:00<?, ?it/s]
0.0076 rtbest='se,se->s':   0%|          | 0/3 [00:00<?, ?it/s]
0.0076 rtbest='se,se->s': 100%|##########| 3/3 [00:00<00:00, 39.19it/s]

  0%|          | 0/6 [00:00<?, ?it/s]
 17%|#6        | 1/6 [00:01<00:06,  1.22s/it]
 33%|###3      | 2/6 [00:02<00:05,  1.27s/it]
 50%|#####     | 3/6 [00:03<00:04,  1.35s/it]
 67%|######6   | 4/6 [00:05<00:03,  1.56s/it]
 83%|########3 | 5/6 [00:08<00:01,  1.98s/it]
100%|##########| 6/6 [00:12<00:00,  2.78s/it]
100%|##########| 6/6 [00:12<00:00,  2.15s/it]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.000249 0.002874 0.006535 0.036733 0.034544
16 0.000279 0.004695 0.011164 0.035590 0.034840
32 0.000351 0.008317 0.020396 0.031704 0.034645
64 0.000580 0.016016 0.039005 0.033084 0.035106
128 0.001326 0.030395 0.076134 0.035560 0.035343
256 0.002155 0.059731 0.149860 0.036302 0.036010


The onnx decomposition.

plot_onnx_einsum(equation, onx)
se,se->s

Out:

<AxesSubplot:title={'center':'se,se->s'}>

Fourth equation: ks,ksm->sm#

# ``dfs`` is not reset here so that the final export keeps the results
# of the previous equations.
equation = "ks,ksm->sm"
df, piv, ax, onx = benchmark_equation(equation)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (N, 1024) - (N, 1024, 768) -- ks,ksm->sm lower better, Einsum Speedup, baseline=numpy (N, 1024) - (N, 1024, 768) -- ks,ksm->sm higher better

Out:

  0%|          | 0/7 [00:00<?, ?it/s]
0.009 rtbest='ks,ksm->sm':   0%|          | 0/7 [00:00<?, ?it/s]
0.009 rtbest='km,kms->ms':   0%|          | 0/7 [00:00<?, ?it/s]
0.009 rtbest='km,kms->ms':  57%|#####7    | 4/7 [00:00<00:00, 35.30it/s]
0.0089 rtbest='sm,smk->mk':  57%|#####7    | 4/7 [00:00<00:00, 35.30it/s]
0.0089 rtbest='sm,smk->mk': 100%|##########| 7/7 [00:00<00:00, 36.02it/s]

  0%|          | 0/4 [00:00<?, ?it/s]
 25%|##5       | 1/4 [00:08<00:26,  8.80s/it]
 50%|#####     | 2/4 [00:24<00:25, 12.62s/it]
 75%|#######5  | 3/4 [03:49<01:40, 100.71s/it]
100%|##########| 4/4 [09:12<00:00, 188.50s/it]
100%|##########| 4/4 [09:12<00:00, 138.17s/it]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.046381 0.075154 0.116060 0.150669 0.141059
16 0.087337 0.175303 0.213492 0.218604 0.208969
32 0.169449 1.702189 0.408497 8.416296 2.774563
64 0.350684 3.124925 0.826111 5.511207 11.133901


The onnx decomposition.

plot_onnx_einsum(equation, onx)
ks,ksm->sm

Out:

<AxesSubplot:title={'center':'ks,ksm->sm'}>

Fifth equation: sec,sm->ecm#

# ``dfs`` is not reset here so that the final export keeps the results
# of the previous equations.
equation = "sec,sm->ecm"
df, piv, ax, onx = benchmark_equation(equation, number=1, repeat=1)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N, 32) - [1024, 768] -- sec,sm->ecm lower better, Einsum Speedup, baseline=numpy (1024, N, 32) - [1024, 768] -- sec,sm->ecm higher better

Out:

  0%|          | 0/25 [00:00<?, ?it/s]
0.012 rtbest='sec,sm->ecm':   0%|          | 0/25 [00:00<?, ?it/s]
0.0091 rtbest='sec,sm->ecm':   0%|          | 0/25 [00:00<?, ?it/s]
0.0091 rtbest='sec,sm->ecm':  16%|#6        | 4/25 [00:00<00:00, 31.96it/s]
0.009 rtbest='esc,em->scm':  16%|#6        | 4/25 [00:00<00:00, 31.96it/s]
0.0087 rtbest='sce,sm->cem':  16%|#6        | 4/25 [00:00<00:00, 31.96it/s]
0.0087 rtbest='sce,sm->cem':  32%|###2      | 8/25 [00:00<00:00, 34.15it/s]
0.0087 rtbest='sce,sm->cem':  48%|####8     | 12/25 [00:00<00:00, 35.39it/s]
0.0087 rtbest='sem,sc->emc':  48%|####8     | 12/25 [00:00<00:00, 35.39it/s]
0.0087 rtbest='sem,sc->emc':  64%|######4   | 16/25 [00:00<00:00, 34.86it/s]
0.0087 rtbest='sem,sc->emc':  80%|########  | 20/25 [00:00<00:00, 35.38it/s]
0.0087 rtbest='sem,sc->emc':  96%|#########6| 24/25 [00:00<00:00, 35.92it/s]
0.0087 rtbest='sem,sc->emc': 100%|##########| 25/25 [00:00<00:00, 35.29it/s]

  0%|          | 0/4 [00:00<?, ?it/s]
 25%|##5       | 1/4 [00:01<00:03,  1.20s/it]
 50%|#####     | 2/4 [00:04<00:04,  2.21s/it]
 75%|#######5  | 3/4 [00:09<00:03,  3.85s/it]
100%|##########| 4/4 [00:21<00:00,  6.75s/it]
100%|##########| 4/4 [00:21<00:00,  5.28s/it]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.491584 0.187608 0.100328 0.167776 0.121127
16 1.758361 0.294181 0.309701 0.209646 0.179711
32 4.110958 0.516717 0.401818 0.288493 0.259152
64 8.159549 0.880178 0.801697 0.500670 0.467781


The onnx decomposition.

plot_onnx_einsum(equation, onx)
sec,sm->ecm

Out:

<AxesSubplot:title={'center':'sec,sm->ecm'}>

Sixth equation: sec,ecm->sm#

# ``dfs`` is not reset here so that the final export keeps the results
# of the previous equations.
equation = "sec,ecm->sm"
df, piv, ax, onx = benchmark_equation(equation, number=1, repeat=1)
dfs.append(df)
# Keyword arguments: pandas >= 2.0 rejects positional ones in pivot.
df.pivot(index="dim", columns="fct", values="average")
Einsum benchmark (1024, N, 32) - (N, 32, 768) -- sec,ecm->sm lower better, Einsum Speedup, baseline=numpy (1024, N, 32) - (N, 32, 768) -- sec,ecm->sm higher better

Out:

  0%|          | 0/25 [00:00<?, ?it/s]
0.0098 rtbest='sec,ecm->sm':   0%|          | 0/25 [00:00<?, ?it/s]
0.0098 rtbest='sec,ecm->sm':  16%|#6        | 4/25 [00:00<00:00, 31.93it/s]
0.0097 rtbest='sce,cem->sm':  16%|#6        | 4/25 [00:00<00:00, 31.93it/s]
0.0097 rtbest='sce,cem->sm':  32%|###2      | 8/25 [00:00<00:00, 32.63it/s]
0.0097 rtbest='mcs,cse->me':  32%|###2      | 8/25 [00:00<00:00, 32.63it/s]
0.0097 rtbest='mcs,cse->me':  48%|####8     | 12/25 [00:00<00:00, 33.26it/s]
0.0096 rtbest='ecs,csm->em':  48%|####8     | 12/25 [00:00<00:00, 33.26it/s]
0.0095 rtbest='sem,emc->sc':  48%|####8     | 12/25 [00:00<00:00, 33.26it/s]
0.0095 rtbest='sem,emc->sc':  64%|######4   | 16/25 [00:00<00:00, 32.73it/s]
0.0095 rtbest='mes,esc->mc':  64%|######4   | 16/25 [00:00<00:00, 32.73it/s]
0.0093 rtbest='ems,msc->ec':  64%|######4   | 16/25 [00:00<00:00, 32.73it/s]
0.0093 rtbest='ems,msc->ec':  80%|########  | 20/25 [00:00<00:00, 33.41it/s]
0.0093 rtbest='ems,msc->ec':  96%|#########6| 24/25 [00:00<00:00, 33.96it/s]
0.0093 rtbest='ems,msc->ec': 100%|##########| 25/25 [00:00<00:00, 33.53it/s]

  0%|          | 0/4 [00:00<?, ?it/s]
 25%|##5       | 1/4 [00:01<00:03,  1.15s/it]
 50%|#####     | 2/4 [00:03<00:03,  1.91s/it]
 75%|#######5  | 3/4 [00:08<00:03,  3.18s/it]
100%|##########| 4/4 [00:17<00:00,  5.74s/it]
100%|##########| 4/4 [00:17<00:00,  4.49s/it]
fct numpy.einsum ort_dec ort_einsum torch_dec torch_einsum
dim
8 0.473223 0.144362 0.097199 0.189313 0.180613
16 1.468553 0.229524 0.180112 0.217261 0.234583
32 3.082025 0.427721 0.350729 0.284155 0.316785
64 6.216536 1.275591 0.712646 0.495009 0.505629


The onnx decomposition.

plot_onnx_einsum(equation, onx)
sec,ecm->sm

Out:

<AxesSubplot:title={'center':'sec,ecm->sm'}>

Conclusion#

pytorch seems quite efficient on these examples.

# Concatenate the frames collected in ``dfs`` and export them as CSV and Excel.
merged = pandas.concat(dfs)
name = "einsum"
merged.to_csv("plot_%s.csv" % name, index=False)
merged.to_excel("plot_%s.xlsx" % name, index=False)
# Save the current matplotlib figure next to the data files.
plt.savefig("plot_%s.png" % name)

# plt.show()
plot op einsum

Total running time of the script: ( 10 minutes 55.111 seconds)

Gallery generated by Sphinx-Gallery