NeuralTreeNet and ONNX#


Converting a decision tree to ONNX can create differences between the original model and the converted one (see Issues when switching to float). The problem comes from a change of type: the decision thresholds are rounded to the nearest float32 of their float64 (double) value. What happens if the decision tree is first converted into a neural network?

In most cases, rounding the decision thresholds does not change much. However, comparing a variable to a rounded threshold can give the opposite result of the comparison with the unrounded threshold. In that case, the decision follows a different path in the tree.
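To make this concrete, here is a minimal sketch (values chosen for illustration) where rounding both the threshold and the feature to float32 flips the comparison:

import numpy

threshold64 = 0.1 + 0.2    # 0.30000000000000004 in float64
x64 = 0.30000001           # slightly above the float64 threshold
print(x64 <= threshold64)  # False: in float64, the observation goes one way

threshold32 = numpy.float32(threshold64)
x32 = numpy.float32(x64)   # both round to the same float32 value
print(x32 <= threshold32)  # True: in float32, it goes the other way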

from jyquickhelper import add_notebook_menu
add_notebook_menu()
%matplotlib inline
%load_ext mlprodict

Dataset#

We build a random dataset.

import numpy

X = numpy.random.randn(10000, 10)
y = X.sum(axis=1) / X.shape[1]
X = X.astype(numpy.float64)
y = y.astype(numpy.float64)
middle = X.shape[0] // 2
X_train, X_test = X[:middle], X[middle:]
y_train, y_test = y[:middle], y[middle:]

The scikit-learn part#

Fitting a decision tree#

from sklearn.tree import DecisionTreeRegressor

tree = DecisionTreeRegressor(max_depth=7)
tree.fit(X_train, y_train)
tree.score(X_train, y_train), tree.score(X_test, y_test)
(0.6179766027481131, 0.33709933420465643)
from sklearn.metrics import r2_score
r2_score(y_test, tree.predict(X_test))
0.33709933420465643

The depth of the tree is insufficient, but that is not what interests us here.

Conversion to ONNX#

from mlprodict.onnx_conv import to_onnx

onx = to_onnx(tree, X[:1].astype(numpy.float32))
from mlprodict.onnxrt import OnnxInference

x_exp = X_test

oinf = OnnxInference(onx, runtime='onnxruntime1')
expected = tree.predict(x_exp)

got = oinf.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668
from mlprodict.plotting.text_plot import onnx_simple_text_plot
print(onnx_simple_text_plot(onx))
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
TreeEnsembleRegressor(X, n_targets=1, nodes_falsenodeids=253:[128,65,34...252,0,0], nodes_featureids=253:[8,3,9...2,0,0], nodes_hitrates=253:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=253:[0,0,0...0,0,0], nodes_modes=253:[b'BRANCH_LEQ',b'BRANCH_LEQ'...b'LEAF',b'LEAF'], nodes_nodeids=253:[0,1,2...250,251,252], nodes_treeids=253:[0,0,0...0,0,0], nodes_truenodeids=253:[1,2,3...251,0,0], nodes_values=253:[0.00792999193072319,-0.12246682494878769...0.0,0.0], post_transform=b'NONE', target_ids=127:[0,0,0...0,0,0], target_nodeids=127:[7,8,10...249,251,252], target_treeids=127:[0,0,0...0,0,0], target_weights=127:[-0.9345570802688599,-0.6372960805892944...0.6169403195381165,1.0096807479858398]) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]

After the conversion to a neural network#

Conversion to a neural network#

A parameter makes it possible to vary the slope of the sigmoid functions being used.
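As a reminder of the mechanism (a sketch, not the library's code): each test x <= t is replaced by a sigmoid whose slope k controls how closely it approximates the step function.

import numpy

def soft_test(x, t, k):
    # sigmoid(k * (t - x)) tends to 1 when x <= t and to 0 otherwise as k grows
    return 1 / (1 + numpy.exp(-k * (t - x)))

x = numpy.array([-0.1, -0.01, 0.01, 0.1])
for k in [1, 10, 50]:
    print(k, soft_test(x, 0., k))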

from tqdm import tqdm
from pandas import DataFrame
from mlstatpy.ml.neural_tree import NeuralTreeNet

xe = x_exp[:500]
expected = tree.predict(xe)

data = []
trees = {}
for i in tqdm([0.3, 0.4, 0.5, 0.7, 0.9, 1] + list(range(5, 61, 5))):
    root = NeuralTreeNet.create_from_tree(tree, k=i, arch='compact')
    got = root.predict(xe)[:, -1]
    me = numpy.abs(got - expected).mean()
    mx = numpy.abs(got - expected).max()
    obs = dict(k=i, max=mx, mean=me)
    data.append(obs)
    trees[i] = root
100%|██████████| 18/18 [00:01<00:00, 12.49it/s]
df = DataFrame(data)
df
k max mean
0 0.3 0.568981 0.158758
1 0.4 0.608304 0.132576
2 0.5 0.692657 0.128525
3 0.7 0.780543 0.131497
4 0.9 0.809866 0.128368
5 1.0 0.813889 0.124802
6 5.0 0.392482 0.022466
7 10.0 0.341749 0.006350
8 15.0 0.270649 0.002939
9 20.0 0.299713 0.002110
10 25.0 0.305493 0.001842
11 30.0 0.306111 0.001767
12 35.0 0.299371 0.001665
13 40.0 0.233556 0.001011
14 45.0 0.233606 0.000801
15 50.0 0.233614 0.000547
16 55.0 0.233615 0.000499
17 60.0 0.233615 0.000484
df.set_index('k').plot(title="Accuracy of the conversion\nto a neural network");
(figure: max and mean error of the neural-network conversion as a function of k)

The error is better, but the experiment should be repeated several times before drawing conclusions, so as to obtain a confidence interval for this type of dataset. That will be for another time. The result depends on the dataset and above all on how close the decision thresholds are to each other. Nevertheless, let's compute the error on the whole test set; it was truncated above to speed things up.

expected = tree.predict(x_exp)
got = trees[50].predict(x_exp)[:, -1]
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(0.2336143002078063, 0.0002511855017989173)

We can see that the error can be quite large. It nevertheless remains smaller than the conversion error introduced by ONNX.

Conversion to ONNX#

We first create a class that follows the scikit-learn API and wraps the tree just built; this wrapper is then converted to ONNX.

from mlstatpy.ml.neural_tree import NeuralTreeNetRegressor

reg = NeuralTreeNetRegressor(trees[50])
onx2 = to_onnx(reg, X[:1].astype(numpy.float32))
print(onnx_simple_text_plot(onx2))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)
init: name='Ad_Addcst' type=dtype('float32') shape=(126,)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)
init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)
init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)
init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)
init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
MatMul(X, Ma_MatMulcst) -> Ma_Y02
  Add(Ma_Y02, Ad_Addcst) -> Ad_C02
    Mul(Ad_C02, Mu_Mulcst) -> Mu_C01
      Sigmoid(Mu_C01) -> Si_Y01
        MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01
          Add(Ma_Y01, Ad_Addcst1) -> Ad_C01
            Mul(Ad_C01, Mu_Mulcst) -> Mu_C0
              Sigmoid(Mu_C0) -> Si_Y0
                MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0
                  Add(Ma_Y0, Ad_Addcst2) -> Ad_C0
                    Identity(Ad_C0) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]
oinf2 = OnnxInference(onx2, runtime='onnxruntime1')
expected = tree.predict(x_exp)

got = oinf2.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668

The error is the same.

Computation time#

x_exp32 = x_exp.astype(numpy.float32)

First, the computation time for scikit-learn.

%timeit tree.predict(x_exp32)
513 µs ± 7.52 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Then the computation time for the decision tree in ONNX format.

%timeit oinf.run({'X': x_exp32})['variable']
186 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

And the computation time for the neural network in ONNX format.

%timeit oinf2.run({'X': x_exp32})['variable']
3.75 ms ± 311 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This much longer computation time is expected: the model involves a very large matrix multiplication and, above all, every threshold of the tree is evaluated for every observation. Where the decision tree implementation evaluates d thresholds (d being the depth of the tree), the new implementation evaluates all 2^d of them for each of the 2^d leaves. Even with sparsity, since each leaf only depends on the d thresholds along its path, the computation can only be reduced to d * 2^d, which still means many useless operations.
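An order-of-magnitude check of this argument for the tree above (d = 7, assuming a complete tree):

d = 7                                       # depth of the tree fitted above
print("tree evaluates:", d)                 # thresholds on one root-to-leaf path
print("dense layer   :", 2 ** d * 2 ** d)   # every threshold for every leaf
print("sparse layer  :", d * 2 ** d)        # only the d thresholds per leaf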

for node in trees[50].nodes:
    print(node.coef.shape, node.bias.shape)
(126, 11) (126,)
(127, 127) (127,)
(128,) ()

That said, the largest matrix is sparse and can be shrunk considerably.

from scipy.sparse import csr_matrix

for node in trees[50].nodes:
    csr = csr_matrix(node.coef)
    print(f"coef.shape={node.coef.shape}, size dense={node.coef.size}, "
          f"size sparse={csr.size}, ratio={csr.size / node.coef.size}")
coef.shape=(126, 11), size dense=1386, size sparse=252, ratio=0.18181818181818182
coef.shape=(127, 127), size dense=16129, size sparse=1015, ratio=0.06293012586025172
coef.shape=(128,), size dense=128, size sparse=127, ratio=0.9921875
r = numpy.random.randn(trees[50].nodes[1].coef.shape[0])
mat = trees[50].nodes[1].coef
%timeit mat @ r
49.8 µs ± 1.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
csr = csr_matrix(mat)
%timeit csr @ r
7.08 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

It would be much faster with a sparse matrix, and all the more so as the tree gets deeper. The ONNX model breaks down as follows.

print(onnx_simple_text_plot(onx2))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)
init: name='Ad_Addcst' type=dtype('float32') shape=(126,)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)
init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)
init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)
init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)
init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
MatMul(X, Ma_MatMulcst) -> Ma_Y02
  Add(Ma_Y02, Ad_Addcst) -> Ad_C02
    Mul(Ad_C02, Mu_Mulcst) -> Mu_C01
      Sigmoid(Mu_C01) -> Si_Y01
        MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01
          Add(Ma_Y01, Ad_Addcst1) -> Ad_C01
            Mul(Ad_C01, Mu_Mulcst) -> Mu_C0
              Sigmoid(Mu_C0) -> Si_Y0
                MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0
                  Add(Ma_Y0, Ad_Addcst2) -> Ad_C0
                    Identity(Ad_C0) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]

Let's see how the computation time is spread out.

oinfpr = OnnxInference(onx2, runtime="onnxruntime1",
                       runtime_options={"enable_profiling": True})
for i in range(0, 43):
    oinfpr.run({"X": x_exp32})
df = oinfpr.get_profiling(as_df=True)
df
cat pid tid dur ts ph name args_op_name args_parameter_size args_graph_index args_provider args_exec_plan_index args_activation_size args_output_size args_input_type_shape args_output_type_shape args_thread_scheduling_stats
0 Session 78116 8820 387 4 X model_loading_array NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 Session 78116 8820 2532 428 X session_initialization NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 Node 78116 8820 0 3294 X gemm_fence_before Gemm NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 Node 78116 8820 1315 3300 X gemm_kernel_time Gemm 5544 11 CPUExecutionProvider 11 200000 2520000 [{'float': [5000, 10]}, {'float': [10, 126]}, ... [{'float': [5000, 126]}] {'main_thread': {'thread_pool_name': 'session-...
4 Node 78116 8820 0 4635 X gemm_fence_after Gemm NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
986 Node 78116 8820 0 210170 X Ma_MatMul2_fence_before MatMul NaN NaN NaN NaN NaN NaN NaN NaN NaN
987 Node 78116 8820 124 210172 X Ma_MatMul2_kernel_time MatMul 508 8 CPUExecutionProvider 8 2540000 20000 [{'float': [5000, 127]}, {'float': [127, 1]}] [{'float': [5000, 1]}] {'main_thread': {'thread_pool_name': 'session-...
988 Node 78116 8820 0 210305 X Ma_MatMul2_fence_after MatMul NaN NaN NaN NaN NaN NaN NaN NaN NaN
989 Session 78116 8820 4378 205930 X SequentialExecutor::Execute NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
990 Session 78116 8820 4388 205925 X model_run NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

991 rows × 17 columns

set(df['args_provider'])
{'CPUExecutionProvider', nan}
dfp = df[df.args_provider == 'CPUExecutionProvider'].copy()
dfp['name'] = dfp['name'].apply(lambda s: s.replace("_kernel_time", ""))
gr_dur = dfp[['dur', "args_op_name", "name"]].groupby(["args_op_name", "name"]).sum().sort_values('dur')
gr_dur
dur
args_op_name name
MatMul Ma_MatMul2 6778
Mul Mu_Mul 12923
Sigmoid Si_Sigmoid 14849
Mul Mu_Mul1 15151
Sigmoid Si_Sigmoid1 15608
Gemm gemm 31763
gemm_token_0 99047
gr_n = dfp[['dur', "args_op_name", "name"]].groupby(["args_op_name", "name"]).count().sort_values('dur')
gr_n = gr_n.loc[gr_dur.index, :]
gr_n
dur
args_op_name name
MatMul Ma_MatMul2 43
Mul Mu_Mul 43
Sigmoid Si_Sigmoid 43
Mul Mu_Mul1 43
Sigmoid Si_Sigmoid1 43
Gemm gemm 43
gemm_token_0 43
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurrences");
(figure: duration and number of occurrences per ONNX operator)

onnxruntime spends most of its time in one matrix product. Let's check in more detail.

df[(df.args_op_name == 'Gemm') & (df.dur > 0)].sort_values('dur', ascending=False).head(n=2).T
127 12
cat Node Node
pid 78116 78116
tid 8820 8820
dur 4603 4083
ts 37173 5949
ph X X
name gemm_token_0_kernel_time gemm_token_0_kernel_time
args_op_name Gemm Gemm
args_parameter_size 64516 64516
args_graph_index 12 12
args_provider CPUExecutionProvider CPUExecutionProvider
args_exec_plan_index 12 12
args_activation_size 2520000 2520000
args_output_size 2540000 2540000
args_input_type_shape [{'float': [5000, 126]}, {'float': [126, 127]}... [{'float': [5000, 126]}, {'float': [126, 127]}...
args_output_type_shape [{'float': [5000, 127]}] [{'float': [5000, 127]}]
args_thread_scheduling_stats {'main_thread': {'thread_pool_name': 'session-... {'main_thread': {'thread_pool_name': 'session-...

It is a matrix product of about 5000x126 by 126x127, as the shapes above show.
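As a sanity check, we can time a dense product with the same shapes as this Gemm outside of onnxruntime (a sketch on random data):

import numpy

A = numpy.random.randn(5000, 126).astype(numpy.float32)
B = numpy.random.randn(126, 127).astype(numpy.float32)
%timeit A @ B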

gr_dur / gr_dur.dur.sum()
dur
args_op_name name
MatMul Ma_MatMul2 0.034561
Mul Mu_Mul 0.065894
Sigmoid Si_Sigmoid 0.075714
Mul Mu_Mul1 0.077254
Sigmoid Si_Sigmoid1 0.079584
Gemm gemm 0.161958
gemm_token_0 0.505035
r = (gr_dur / gr_dur.dur.sum()).dur.max()
r
0.5050352082154203

It accounts for about half of the time and, according to the previous experiment, its execution time could be divided by roughly a factor of ten by switching to a sparse matrix. That alone will not be enough to speed up this neural network: it takes 3.75 ms compared to 186 µs for the decision tree in ONNX. With this optimization, the runtime could drop to about:

t = 3.75  # ms
t * (1 - r) + r * t / 12
2.013941471759493

That is, about 2 ms, a real reduction of the computation time. Not bad, but not enough.

Hummingbird#

hummingbird is a library that converts a decision tree into a neural network. Let's look at its performance.

from hummingbird.ml import convert

model = convert(tree, 'torch')

expected = tree.predict(x_exp)
got = model.predict(x_exp)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
C:\xavierdupre\__home_\github_fork\scikit-learn\sklearn\utils\deprecation.py:103: FutureWarning: The attribute n_features_ is deprecated in 1.0 and will be removed in 1.2. Use n_features_in_ instead.
  warnings.warn(msg, category=FutureWarning)
(4.3419181139370266e-08, 4.430287026515114e-09)

The result is much more faithful to the model.

%timeit model.predict(x_exp)
1.17 ms ± 34.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

It is still slower than the decision tree but much faster than the manual solution proposed in the previous paragraphs. The converted object holds a model attribute.

from torch.nn import Module
isinstance(model.model, Module)
True

Let's convert this model to ONNX.

import torch.onnx

x = torch.randn(x_exp.shape[0], x_exp.shape[1], requires_grad=True)
torch.onnx.export(model.model, x, 'tree_torch.onnx', opset_version=15,
                  input_names=['X'], output_names=['variable'],
                  dynamic_axes={
                      'X' : {0 : 'batch_size'},
                      'variable' : {0 : 'batch_size'}})
import onnx

onxh = onnx.load('tree_torch.onnx')
print(onnx_simple_text_plot(onxh, raise_exc=False))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=['batch_size', 10]
init: name='_operators.0.root_nodes' type=dtype('int64') shape=(0,) -- array([8], dtype=int64)
init: name='_operators.0.root_biases' type=dtype('float32') shape=(0,) -- array([0.00792999], dtype=float32)
init: name='_operators.0.tree_indices' type=dtype('int64') shape=(0,) -- array([0], dtype=int64)
init: name='_operators.0.leaf_nodes' type=dtype('float32') shape=(0,) -- array([ 1.0096807 ,  0.6169403 ,  0.61055773,  0.37810475,  0.31796893,
        0.13317925,  0.0193846 , -0.2317742 ,  0.39089343,  0.23506087,
        0.3711936 ,  0.10317916,  0.14956598, -0.14193445, -0.05965868,
       -0.27377078,  0.4128183 ,  0.19658326,  0.25545415,  0.08118545,
        0.08400188, -0.1502193 , -0.36846825, -0.79687625,  0.35822242,
        0.49021915,  0.30870998,  0.01033915,  0.6740977 ,  0.6740977 ,
       -0.15315758, -0.41128033,  0.42920846,  0.13145493,  0.21853392,
       -0.10986731,  0.4493652 ,  0.11318789,  0.12666471, -0.0623082 ,
        0.2872893 ,  0.09948976,  0.11439473, -0.08801427,  0.16091613,
       -0.02319027, -0.10097775, -0.37583745,  0.18612385, -0.00453244,
        0.3287116 , -0.1499349 ,  0.7919218 ,  0.04704398, -0.15423109,
       -0.43160027,  0.10802375, -0.1073833 , -0.07759219, -0.29175794,
       -0.1528881 , -0.4909434 , -0.23361537, -0.43578717,  0.7831867 ,
        0.45349318,  0.34956965, -0.3199535 ,  0.3061573 , -0.34267113,
        0.34963542,  0.04491445,  0.35399815,  0.14815213,  0.06678926,
       -0.16095412,  0.3214274 ,  0.01484008, -0.1012276 , -0.3257699 ,
        0.26727676,  0.01970094,  0.10760042, -0.09169976,  0.20044112,
       -0.0324069 , -0.11015374, -0.28358367,  0.8083656 ,  0.13358633,
       -0.07912118, -0.27182895, -0.07054728, -0.24895027, -0.20600456,
       -0.42033467,  0.34701794, -0.0638995 ,  0.14252576, -0.06025055,
        0.4228329 ,  0.06789401,  0.03919645, -0.17267554,  0.07274943,
       -0.487512  ,  0.04517636, -0.18857062, -0.03975222, -0.2652712 ,
       -0.30853328, -0.50844556,  0.03321444, -0.15481217, -0.20701212,
       -0.40578464, -0.25884995, -0.46550158, -0.4797585 , -0.7324234 ,
        0.43939307, -0.06170902, -0.51546025, -0.19215119, -0.3705445 ,
       -0.57504356, -0.6372961 , -0.9345571 ], dtype=float32)
init: name='_operators.0.nodes.0' type=dtype('int64') shape=(0,) -- array([0, 3], dtype=int64)
init: name='_operators.0.nodes.1' type=dtype('int64') shape=(0,) -- array([1, 2, 5, 9], dtype=int64)
init: name='_operators.0.nodes.2' type=dtype('int64') shape=(0,) -- array([5, 6, 3, 7, 2, 0, 7, 1], dtype=int64)
init: name='_operators.0.nodes.3' type=dtype('int64') shape=(0,) -- array([3, 9, 5, 3, 6, 4, 1, 3, 6, 6, 1, 6, 5, 4, 6, 2], dtype=int64)
init: name='_operators.0.nodes.4' type=dtype('int64') shape=(0,) -- array([3, 2, 7, 6, 2, 4, 7, 8, 9, 5, 7, 8, 9, 4, 6, 9, 7, 9, 0, 7, 7, 9,
       2, 7, 6, 4, 6, 5, 4, 0, 6, 0], dtype=int64)
init: name='_operators.0.nodes.5' type=dtype('int64') shape=(0,) -- array([2, 8, 7, 6, 6, 3, 4, 9, 7, 3, 2, 6, 3, 3, 0, 1, 1, 0, 4, 7, 9, 5,
       7, 9, 5, 3, 5, 9, 0, 5, 1, 4, 9, 4, 7, 7, 1, 9, 1, 1, 6, 2, 7, 7,
       6, 1, 4, 4, 0, 0, 9, 8, 8, 2, 6, 2, 0, 3, 4, 2, 5, 6, 7, 3],
      dtype=int64)
init: name='_operators.0.biases.0' type=dtype('float32') shape=(0,) -- array([ 0.19169255, -0.12246682], dtype=float32)
init: name='_operators.0.biases.1' type=dtype('float32') shape=(0,) -- array([-0.40610337, -0.1467492 , -0.01880287,  0.15879431], dtype=float32)
init: name='_operators.0.biases.2' type=dtype('float32') shape=(0,) -- array([ 0.736786  , -0.32427853,  0.30860555,  0.17994082,  0.6917758 ,
       -0.00594712,  0.35950053, -0.9819274 ], dtype=float32)
init: name='_operators.0.biases.3' type=dtype('float32') shape=(0,) -- array([-1.3495584 , -1.082793  , -0.6906011 , -0.08978076, -0.4007622 ,
        0.10756078, -0.68507075,  0.15814054,  0.5132364 , -0.18426335,
        0.13685235,  0.10721841,  0.01814443, -0.41644228, -0.59770894,
        0.607365  ], dtype=float32)
init: name='_operators.0.biases.4' type=dtype('float32') shape=(0,) -- array([ 1.4203796 , -0.49269757, -0.12210988, -0.09692484,  0.5076643 ,
       -1.3609421 ,  1.154743  ,  2.8748922 , -0.08181615,  0.7741028 ,
        0.20604724,  0.666296  , -0.6474025 ,  0.6459148 ,  0.02262808,
       -0.42282397,  0.46360654, -0.10058792,  0.25486696,  0.60041225,
       -0.06933744,  0.21294908,  0.96443814,  0.07923891,  0.4797698 ,
        1.2852331 ,  0.24348404, -0.3404966 , -0.07175394, -0.8248828 ,
       -0.74071133, -1.2140133 ], dtype=float32)
init: name='_operators.0.biases.5' type=dtype('float32') shape=(0,) -- array([ 1.0626682 ,  1.4745288 ,  0.01898679,  0.5451088 ,  0.15444604,
        1.0631477 , -0.7555804 , -1.7192128 , -0.20905146,  0.19752283,
       -0.40471953,  0.13069782,  0.60331047,  1.5060809 ,  0.        ,
       -1.8283446 , -0.8124372 , -1.381897  ,  0.59209645,  0.3239226 ,
       -0.42840806, -0.43624896,  0.58229303, -1.0196047 , -0.5632828 ,
        0.91483426,  1.8038778 , -0.5665638 , -1.2530733 , -0.6500004 ,
       -1.3069727 ,  0.48267984,  0.73503745, -1.871724  , -1.4965518 ,
        1.3147466 ,  0.03919952, -0.885836  ,  0.5479692 , -0.8086383 ,
       -0.74240863,  0.14582941,  0.6496967 , -0.00911551,  2.4541488 ,
       -0.90482277,  0.26108736,  0.7569448 , -1.0786855 , -0.45229852,
        1.2146595 , -0.6756766 , -2.3066258 ,  0.7911504 ,  0.57490873,
       -0.40741247,  0.24633038, -1.2022957 , -0.65162694, -0.04244827,
        1.558136  , -1.6220782 ,  0.1574643 , -1.4209061 ], dtype=float32)
Constant(value=[-1]) -> onnx::Reshape_27
Gather(X, _operators.0.root_nodes, axis=1) -> onnx::LessOrEqual_17
  LessOrEqual(onnx::LessOrEqual_17, _operators.0.root_biases) -> onnx::Cast_18
    Cast(onnx::Cast_18, to=7) -> onnx::Add_19
      Add(onnx::Add_19, _operators.0.tree_indices) -> onnx::Reshape_20
Constant(value=[-1]) -> onnx::Reshape_21
  Reshape(onnx::Reshape_20, onnx::Reshape_21, allowzero=0) -> onnx::Gather_22
    Gather(_operators.0.nodes.0, onnx::Gather_22, axis=0) -> onnx::Reshape_23
Constant(value=[-1, 1]) -> onnx::Reshape_24
  Reshape(onnx::Reshape_23, onnx::Reshape_24, allowzero=0) -> onnx::GatherElements_25
    GatherElements(X, onnx::GatherElements_25, axis=1) -> onnx::Reshape_26
  Reshape(onnx::Reshape_26, onnx::Reshape_27, allowzero=0) -> onnx::LessOrEqual_28
Constant(value=2) -> onnx::Mul_29
  Mul(onnx::Gather_22, onnx::Mul_29) -> onnx::Add_30
Gather(_operators.0.biases.0, onnx::Gather_22, axis=0) -> onnx::LessOrEqual_31
  LessOrEqual(onnx::LessOrEqual_28, onnx::LessOrEqual_31) -> onnx::Cast_32
    Cast(onnx::Cast_32, to=7) -> onnx::Add_33
    Add(onnx::Add_30, onnx::Add_33) -> onnx::Gather_34
      Gather(_operators.0.nodes.1, onnx::Gather_34, axis=0) -> onnx::Reshape_35
Constant(value=[-1, 1]) -> onnx::Reshape_36
  Reshape(onnx::Reshape_35, onnx::Reshape_36, allowzero=0) -> onnx::GatherElements_37
    GatherElements(X, onnx::GatherElements_37, axis=1) -> onnx::Reshape_38
Constant(value=[-1]) -> onnx::Reshape_39
  Reshape(onnx::Reshape_38, onnx::Reshape_39, allowzero=0) -> onnx::LessOrEqual_40
Constant(value=2) -> onnx::Mul_41
  Mul(onnx::Gather_34, onnx::Mul_41) -> onnx::Add_42
Gather(_operators.0.biases.1, onnx::Gather_34, axis=0) -> onnx::LessOrEqual_43
  LessOrEqual(onnx::LessOrEqual_40, onnx::LessOrEqual_43) -> onnx::Cast_44
    Cast(onnx::Cast_44, to=7) -> onnx::Add_45
    Add(onnx::Add_42, onnx::Add_45) -> onnx::Gather_46
      Gather(_operators.0.nodes.2, onnx::Gather_46, axis=0) -> onnx::Reshape_47
Constant(value=[-1, 1]) -> onnx::Reshape_48
  Reshape(onnx::Reshape_47, onnx::Reshape_48, allowzero=0) -> onnx::GatherElements_49
    GatherElements(X, onnx::GatherElements_49, axis=1) -> onnx::Reshape_50
Constant(value=[-1]) -> onnx::Reshape_51
  Reshape(onnx::Reshape_50, onnx::Reshape_51, allowzero=0) -> onnx::LessOrEqual_52
Constant(value=2) -> onnx::Mul_53
  Mul(onnx::Gather_46, onnx::Mul_53) -> onnx::Add_54
Gather(_operators.0.biases.2, onnx::Gather_46, axis=0) -> onnx::LessOrEqual_55
  LessOrEqual(onnx::LessOrEqual_52, onnx::LessOrEqual_55) -> onnx::Cast_56
    Cast(onnx::Cast_56, to=7) -> onnx::Add_57
    Add(onnx::Add_54, onnx::Add_57) -> onnx::Gather_58
      Gather(_operators.0.nodes.3, onnx::Gather_58, axis=0) -> onnx::Reshape_59
Constant(value=[-1, 1]) -> onnx::Reshape_60
  Reshape(onnx::Reshape_59, onnx::Reshape_60, allowzero=0) -> onnx::GatherElements_61
    GatherElements(X, onnx::GatherElements_61, axis=1) -> onnx::Reshape_62
Constant(value=[-1]) -> onnx::Reshape_63
  Reshape(onnx::Reshape_62, onnx::Reshape_63, allowzero=0) -> onnx::LessOrEqual_64
Constant(value=2) -> onnx::Mul_65
  Mul(onnx::Gather_58, onnx::Mul_65) -> onnx::Add_66
Gather(_operators.0.biases.3, onnx::Gather_58, axis=0) -> onnx::LessOrEqual_67
  LessOrEqual(onnx::LessOrEqual_64, onnx::LessOrEqual_67) -> onnx::Cast_68
    Cast(onnx::Cast_68, to=7) -> onnx::Add_69
    Add(onnx::Add_66, onnx::Add_69) -> onnx::Gather_70
      Gather(_operators.0.nodes.4, onnx::Gather_70, axis=0) -> onnx::Reshape_71
Constant(value=[-1, 1]) -> onnx::Reshape_72
  Reshape(onnx::Reshape_71, onnx::Reshape_72, allowzero=0) -> onnx::GatherElements_73
    GatherElements(X, onnx::GatherElements_73, axis=1) -> onnx::Reshape_74
Constant(value=[-1]) -> onnx::Reshape_75
  Reshape(onnx::Reshape_74, onnx::Reshape_75, allowzero=0) -> onnx::LessOrEqual_76
Constant(value=2) -> onnx::Mul_77
  Mul(onnx::Gather_70, onnx::Mul_77) -> onnx::Add_78
Gather(_operators.0.biases.4, onnx::Gather_70, axis=0) -> onnx::LessOrEqual_79
  LessOrEqual(onnx::LessOrEqual_76, onnx::LessOrEqual_79) -> onnx::Cast_80
    Cast(onnx::Cast_80, to=7) -> onnx::Add_81
    Add(onnx::Add_78, onnx::Add_81) -> onnx::Gather_82
      Gather(_operators.0.nodes.5, onnx::Gather_82, axis=0) -> onnx::Reshape_83
Constant(value=[-1, 1]) -> onnx::Reshape_84
  Reshape(onnx::Reshape_83, onnx::Reshape_84, allowzero=0) -> onnx::GatherElements_85
    GatherElements(X, onnx::GatherElements_85, axis=1) -> onnx::Reshape_86
Constant(value=[-1]) -> onnx::Reshape_87
  Reshape(onnx::Reshape_86, onnx::Reshape_87, allowzero=0) -> onnx::LessOrEqual_88
Constant(value=2) -> onnx::Mul_89
  Mul(onnx::Gather_82, onnx::Mul_89) -> onnx::Add_90
Gather(_operators.0.biases.5, onnx::Gather_82, axis=0) -> onnx::LessOrEqual_91
  LessOrEqual(onnx::LessOrEqual_88, onnx::LessOrEqual_91) -> onnx::Cast_92
    Cast(onnx::Cast_92, to=7) -> onnx::Add_93
    Add(onnx::Add_90, onnx::Add_93) -> onnx::Gather_94
      Gather(_operators.0.leaf_nodes, onnx::Gather_94, axis=0) -> onnx::Reshape_95
Constant(value=[-1, 1, 1]) -> onnx::Reshape_96
  Reshape(onnx::Reshape_95, onnx::Reshape_96, allowzero=0) -> output
Constant(value=[1]) -> onnx::ReduceSum_98
  ReduceSum(output, onnx::ReduceSum_98, keepdims=0) -> variable
output: name='variable' type=dtype('float32') shape=['batch_size', 'ReduceSumvariable_dim_1']
%onnxview onxh

The library reimplements the decision tree traversal as one matrix operation per level of the tree. All the thresholds are evaluated. The matrices do not need to be sparse because only the required features are gathered. The decision threshold is implemented with a test rather than a sigmoid, so in terms of predictions this model is identical to the initial one, up to float32 rounding.
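A numpy sketch of what one level of this traversal computes (illustrative names, not hummingbird internals), matching the Gather / GatherElements / LessOrEqual / Cast / Mul / Add pattern in the graph above:

import numpy

def next_level(X, node, features, biases):
    # features[node]: feature tested by each observation's current node
    x = X[numpy.arange(X.shape[0]), features[node]]  # GatherElements
    test = (x <= biases[node]).astype(numpy.int64)   # LessOrEqual + Cast
    return 2 * node + test                           # child index: Mul + Add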

oinfh = OnnxInference(onxh, runtime='onnxruntime1')
expected = tree.predict(x_exp)

got = oinfh.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668

The conversion remains imperfect as well.

%timeit oinfh.run({'X': x_exp32})['variable']
3.13 ms ± 445 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

And the computation time is also longer.

Training#

The idea behind all this is also to be able to re-estimate the coefficients of the neural network once it has been converted.

x_train = X_train[:100]
expected = tree.predict(x_train)
reg = NeuralTreeNetRegressor(trees[1], verbose=1, max_iter=10, lr=1e-4)
got = reg.predict(x_train)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(1.0246115055833722, 0.24094382754240642)

The difference is large.

reg.fit(x_train, expected)
0/10: loss: 3.201 lr=0.0001 max(coef): 6.5 l1=0/1.5e+03 l2=0/2.5e+03
1/10: loss: 2.593 lr=9.95e-06 max(coef): 6.5 l1=2e+03/1.5e+03 l2=1.3e+03/2.5e+03
2/10: loss: 2.506 lr=7.05e-06 max(coef): 6.5 l1=1.4e+02/1.5e+03 l2=6.2/2.5e+03
3/10: loss: 2.461 lr=5.76e-06 max(coef): 6.5 l1=1.2e+03/1.5e+03 l2=6.8e+02/2.5e+03
4/10: loss: 2.429 lr=4.99e-06 max(coef): 6.5 l1=6.5e+02/1.5e+03 l2=2.1e+02/2.5e+03
5/10: loss: 2.405 lr=4.47e-06 max(coef): 6.5 l1=1.9e+02/1.5e+03 l2=13/2.5e+03
6/10: loss: 2.392 lr=4.08e-06 max(coef): 6.5 l1=1.6e+02/1.5e+03 l2=6.8/2.5e+03
7/10: loss: 2.375 lr=3.78e-06 max(coef): 6.5 l1=1.8e+02/1.5e+03 l2=9.5/2.5e+03
8/10: loss: 2.358 lr=3.53e-06 max(coef): 6.5 l1=1.1e+02/1.5e+03 l2=7/2.5e+03
9/10: loss: 2.345 lr=3.33e-06 max(coef): 6.5 l1=3.7e+02/1.5e+03 l2=56/2.5e+03
10/10: loss: 2.333 lr=3.16e-06 max(coef): 6.5 l1=6.1e+02/1.5e+03 l2=1.3e+02/2.5e+03
NeuralTreeNetRegressor(estimator=None, lr=0.0001, max_iter=10, verbose=1)
got = reg.predict(x_train)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(1.256860512819292, 0.25663312220721907)

It does not work as well as expected. It would probably take more iterations and some tuning of the training parameters.