Benchmark operator LeakyRelu

The operator LeakyRelu is equivalent to the function:
$$LeakyRelu(x) = \begin{cases} x & \text{if } x > 0 \\ \alpha x & \text{otherwise} \end{cases}$$
But it can be rewritten into the following decomposition:
$$LeakyRelu(x) = x \left(\mathbb{1}_{x > 0} + \alpha (1 - \mathbb{1}_{x > 0})\right) = x \left((1 - \alpha) \mathbb{1}_{x > 0} + \alpha\right)$$
where $\mathbb{1}_{x > 0}$ is the indicator function: 1 when $x > 0$, 0 otherwise. Let's compare these implementations with onnxruntime.

The ONNX graphs for the three implementations of LeakyRelu

import numpy
from numpy.testing import assert_almost_equal
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnx import TensorProto
from onnxruntime import InferenceSession, get_device
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import (
    OnnxLeakyRelu, OnnxSign, OnnxMul, OnnxAdd, OnnxDiv,
    OnnxGreater, OnnxCast)
from cpyquickhelper.numbers.speed_measure import measure_time
from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation
from mlprodict.plotting.plotting import onnx_simple_text_plot
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from tqdm import tqdm


# Print the optimisation flags reported by mlprodict's experimental C code
# (presumably its build options, e.g. AVX / OpenMP thread count) and the
# device onnxruntime runs on.
print([code_optimisation(), get_device()])
['AVX-omp=8', 'CPU']

First implementation: the operator LeakyRelu.

def build_leaky_relu(alpha=0.5, target_opset=15):
    """Return an ONNX model computing ``Y = LeakyRelu(X)`` with a single node.

    :param alpha: slope applied to the non-positive part of ``X``
    :param target_opset: ONNX opset used for the node and the model
    :return: ONNX model with one float input ``X`` and one float output ``Y``
    """
    node = OnnxLeakyRelu(
        'X', alpha=alpha, op_version=target_opset, output_names=['Y'])
    inputs = {'X': FloatTensorType()}
    outputs = {'Y': FloatTensorType()}
    return node.to_onnx(inputs, outputs=outputs, target_opset=target_opset)


# Single-operator model (alpha=0.5, opset 15), printed as text.
onx_leaky = build_leaky_relu(alpha=0.5, target_opset=15)
print(onnx_simple_text_plot(onx_leaky))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=None
LeakyRelu(X, alpha=0.50) -> Y
output: name='Y' type=dtype('float32') shape=None

Second option: the formula introduced above must be adapted because the ONNX operator Sign returns -1, not 0, when x is negative. The indicator is therefore computed as (Sign(x) + 1) / 2.

def build_leaky_relu_decomposed(alpha=0.5, target_opset=15):
    """Return an ONNX model computing LeakyRelu from Sign, Add, Div and Mul.

    The model evaluates ``Y = X * ((1 - alpha) * ind + alpha)`` where
    ``ind = (Sign(X) + 1) / 2`` stands for the indicator of ``X > 0``.

    :param alpha: slope applied to the non-positive part of ``X``
    :param target_opset: ONNX opset used for every node and the model
    :return: ONNX model with one float input ``X`` and one float output ``Y``
    """
    def cst(value):
        # One-element float32 constant, broadcast against X by ONNX.
        return numpy.array([value], dtype=numpy.float32)

    opv = target_opset
    # Sign(X) is -1, 0 or 1, so (Sign(X) + 1) / 2 is 0, 0.5 or 1.  The 0.5
    # obtained for X == 0 is harmless: the final product with X is 0.
    indicator = OnnxDiv(
        OnnxAdd(OnnxSign('X', op_version=opv), cst(1), op_version=opv),
        cst(2), op_version=opv)
    scaled = OnnxMul(indicator, cst(1 - alpha), op_version=opv)
    factor = OnnxAdd(scaled, cst(alpha), op_version=opv)
    model = OnnxMul('X', factor, op_version=opv, output_names=['Y'])
    return model.to_onnx({'X': FloatTensorType()},
                         outputs={'Y': FloatTensorType()},
                         target_opset=opv)


# Sign-based decomposition (alpha=0.5, opset 15), printed as text.
onx_leaky_dec = build_leaky_relu_decomposed(alpha=0.5, target_opset=15)
print(onnx_simple_text_plot(onx_leaky_dec))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=None
init: name='Ad_Addcst' type=dtype('float32') shape=(1,) -- array([1.], dtype=float32)
init: name='Di_Divcst' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.5], dtype=float32)
Identity(Mu_Mulcst) -> Ad_Addcst1
Sign(X) -> Si_output0
  Add(Si_output0, Ad_Addcst) -> Ad_C01
    Div(Ad_C01, Di_Divcst) -> Di_C0
      Mul(Di_C0, Mu_Mulcst) -> Mu_C0
  Add(Mu_C0, Ad_Addcst1) -> Ad_C0
    Mul(X, Ad_C0) -> Y
output: name='Y' type=dtype('float32') shape=None

Third option: use the operator Greater, cast to float, to compute the indicator.

def build_leaky_relu_decomposed_greater(alpha=0.5, target_opset=15):
    """Return an ONNX model computing LeakyRelu from Greater, Cast, Mul and Add.

    The model evaluates ``Y = X * ((1 - alpha) * ind + alpha)`` where
    ``ind = Cast(Greater(X, 0), float)`` is the indicator of ``X > 0``.

    :param alpha: slope applied to the non-positive part of ``X``
    :param target_opset: ONNX opset used for every node and the model
    :return: ONNX model with one float input ``X`` and one float output ``Y``
    """
    def cst(value):
        # One-element float32 constant, broadcast against X by ONNX.
        return numpy.array([value], dtype=numpy.float32)

    opv = target_opset
    # Greater yields booleans; cast them to float to get 0 / 1.
    indicator = OnnxCast(
        OnnxGreater('X', cst(0), op_version=opv),
        to=TensorProto.FLOAT, op_version=opv)
    factor = OnnxAdd(
        OnnxMul(indicator, cst(1 - alpha), op_version=opv),
        cst(alpha), op_version=opv)
    model = OnnxMul('X', factor, op_version=opv, output_names=['Y'])
    return model.to_onnx({'X': FloatTensorType()},
                         outputs={'Y': FloatTensorType()},
                         target_opset=opv)


# Greater-based decomposition (alpha=0.5, opset 15), printed as text.
onx_leaky_dec_greater = build_leaky_relu_decomposed_greater(
    alpha=0.5, target_opset=15)
print(onnx_simple_text_plot(onx_leaky_dec_greater))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=None
init: name='Gr_Greatercst' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.5], dtype=float32)
Greater(X, Gr_Greatercst) -> Gr_C0
  Cast(Gr_C0, to=1) -> Ca_output0
    Mul(Ca_output0, Mu_Mulcst) -> Mu_C0
Identity(Mu_Mulcst) -> Ad_Addcst
  Add(Mu_C0, Ad_Addcst) -> Ad_C0
    Mul(X, Ad_C0) -> Y
output: name='Y' type=dtype('float32') shape=None

Visually

# Side-by-side drawing of the three graphs (kept as the last expression so
# the gallery shows the returned axes).
plot_onnxs(
    onx_leaky, onx_leaky_dec, onx_leaky_dec_greater,
    title=["One operator",
           "Decomposed\nLeakyRelu",
           "Decomposed\nLeakyRelu Greater"])
One operator, Decomposed LeakyRelu, Decomposed LeakyRelu Greater
array([<AxesSubplot: title={'center': 'One operator'}>,
       <AxesSubplot: title={'center': 'Decomposed\nLeakyRelu'}>,
       <AxesSubplot: title={'center': 'Decomposed\nLeakyRelu Greater'}>],
      dtype=object)

Check that all graphs return equivalent results

# One CPU session per graph, created in the same order as the models.
sess1, sess_dec, sess_dec_greater = (
    InferenceSession(model.SerializeToString(),
                     providers=['CPUExecutionProvider'])
    for model in (onx_leaky, onx_leaky_dec, onx_leaky_dec_greater))

# The decomposed graphs must match the single operator on random inputs
# of several ranks.
for shape in [(1, ), (10, ), (5, 5), (7, 2, 4)]:
    rnd = numpy.random.randn(*shape).astype(numpy.float32)
    feeds = {'X': rnd}
    res1, res_dec, res_dec_greater = (
        session.run(None, feeds)[0]
        for session in (sess1, sess_dec, sess_dec_greater))
    for other in (res_dec, res_dec_greater):
        assert_almost_equal(res1, other)

Benchmark

# Time every session on inputs of shape (N, dim) and (dim, N).
fcts = [('leakyrelu', sess1), ('dec', sess_dec),
        ('dec_greater', sess_dec_greater)]

N = 100
data = []
for dim in tqdm([10, 128, 256, 512, 1000, 2000]):
    # Loop invariant for this dim: larger inputs are timed fewer times.
    repeat = int(4001 / dim)
    for shape in [(N, dim), (dim, N)]:
        rnd = numpy.random.randn(*shape).astype(numpy.float32)
        for name, sess in fcts:
            obs = measure_time(
                lambda: sess.run(None, {'X': rnd}),
                context=dict(rnd=rnd, sess=sess),
                div_by_number=True, repeat=repeat, number=200)
            obs['name'] = name
            obs['N'] = N
            obs['dim'] = dim
            # True for shape (N, dim), False for (dim, N).
            obs['orient'] = shape[0] == N
            obs['shape'] = "%dx%d" % shape
            data.append(obs)

df = DataFrame(data)
# 'dim' alone is ambiguous (both orientations share it), so the shape is
# shown too; 'deviation' gives an idea of the timing noise.
print(df[['name', 'N', 'dim', 'shape', 'average', 'deviation']])
  0%|          | 0/6 [00:00<?, ?it/s]
 17%|#6        | 1/6 [00:28<02:20, 28.05s/it]
 33%|###3      | 2/6 [00:32<00:57, 14.37s/it]
 50%|#####     | 3/6 [00:36<00:28,  9.56s/it]
 67%|######6   | 4/6 [00:40<00:14,  7.32s/it]
 83%|########3 | 5/6 [00:44<00:06,  6.01s/it]
100%|##########| 6/6 [00:47<00:00,  4.99s/it]
100%|##########| 6/6 [00:47<00:00,  7.88s/it]
           name    N   dim   average
0     leakyrelu  100    10  0.000047
1           dec  100    10  0.000068
2   dec_greater  100    10  0.000060
3     leakyrelu  100    10  0.000047
4           dec  100    10  0.000068
5   dec_greater  100    10  0.000060
6     leakyrelu  100   128  0.000076
7           dec  100   128  0.000194
8   dec_greater  100   128  0.000117
9     leakyrelu  100   128  0.000074
10          dec  100   128  0.000193
11  dec_greater  100   128  0.000116
12    leakyrelu  100   256  0.000095
13          dec  100   256  0.000349
14  dec_greater  100   256  0.000198
15    leakyrelu  100   256  0.000093
16          dec  100   256  0.000348
17  dec_greater  100   256  0.000194
18    leakyrelu  100   512  0.000142
19          dec  100   512  0.000769
20  dec_greater  100   512  0.000470
21    leakyrelu  100   512  0.000142
22          dec  100   512  0.000770
23  dec_greater  100   512  0.000470
24    leakyrelu  100  1000  0.000300
25          dec  100  1000  0.001098
26  dec_greater  100  1000  0.000865
27    leakyrelu  100  1000  0.000394
28          dec  100  1000  0.001098
29  dec_greater  100  1000  0.000828
30    leakyrelu  100  2000  0.000613
31          dec  100  2000  0.001821
32  dec_greater  100  2000  0.001316
33    leakyrelu  100  2000  0.000576
34          dec  100  2000  0.001818
35  dec_greater  100  2000  0.001269

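Another way to look at it: the speedup of each implementation relative to the single LeakyRelu operator (values below 1 mean slower).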

def speedup(piv, baseline='leakyrelu'):
    """Express every column of *piv* as a speedup relative to *baseline*.

    Each non-baseline column ``c`` becomes ``piv[baseline] / piv[c]``.
    With timings as input, values above 1 mean ``c`` is faster than the
    baseline.  The baseline column itself is set to 1.

    :param piv: DataFrame whose columns are implementations and whose values
        are timings (e.g. the result of ``df.pivot(..., columns='name',
        values='average')``)
    :param baseline: name of the reference column, ``'leakyrelu'`` by default
    :return: a new DataFrame with the same index and columns; *piv* itself
        is left unmodified
    :raises KeyError: if *baseline* is not a column of *piv*
    """
    # Work on a copy: the caller's frame must not be overwritten.
    result = piv.copy()
    reference = piv[baseline]
    for col in piv.columns:
        if col != baseline:
            result[col] = reference / piv[col]
    result[baseline] = 1
    return result


# Speedup of every implementation relative to LeakyRelu, one row per shape.
# DataFrame.pivot only accepts keyword arguments since pandas 2.0.
piv = speedup(df.pivot(index='shape', columns='name', values='average'))
piv
name dec dec_greater leakyrelu
shape
1000x100 0.358262 0.475129 1
100x10 0.689093 0.780442 1
100x1000 0.273181 0.346878 1
100x128 0.390485 0.649343 1
100x2000 0.336393 0.465464 1
100x256 0.272989 0.480707 1
100x512 0.183923 0.300987 1
10x100 0.692456 0.783562 1
128x100 0.384721 0.640376 1
2000x100 0.316868 0.453896 1
256x100 0.266726 0.477558 1
512x100 0.184956 0.302847 1


Graph.

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
# Left: inputs of shape (N, dim); right: inputs of shape (dim, N).
# DataFrame.pivot only accepts keyword arguments since pandas 2.0.
speedup(df[df.orient].pivot(
    index='dim', columns='name', values='average')).plot(ax=ax[0])
ax[0].set_title("LeakyRelu speedup, shape=(%d,dim)"
                "\nThe higher the better" % N)
speedup(df[~df.orient].pivot(
    index='dim', columns='name', values='average')).plot(ax=ax[1])
ax[1].set_title("LeakyRelu speedup, shape=(dim,%d)"
                "\nThe higher the better" % N)
LeakyRelu speedup, shape=(100,dim) The higher the better, LeakyRelu speedup, shape=(dim,100) The higher the better
Text(0.5, 1.0, 'LeakyRelu speedup, shape=(dim,100)\nThe higher the better')

This kind of benchmark helps find better implementations of operator runtimes. Here the single LeakyRelu operator is faster than both decompositions for every shape tested.

# plt.show()

Total running time of the script: ( 0 minutes 49.961 seconds)

Gallery generated by Sphinx-Gallery