Converting a decision tree to ONNX can introduce differences between the original model and the converted one (see Issues when switching to float). The problem comes from a change of type: the decision thresholds are rounded to the float32 closest to their float64 (double) value. What happens if the decision tree is first converted into a neural network?
In most cases, approximating the decision thresholds changes very little. However, comparing a feature to a rounded threshold can give the opposite result from comparing it to the unrounded one. In that case, the decision follows a different path in the tree.
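A minimal illustration of such a path flip, assuming a single threshold of 0.1 learned in float64 (an arbitrary value, not taken from the tree trained below): the float32 rounding of 0.1 is slightly larger than 0.1, so a feature value lying between the two thresholds goes to opposite branches depending on the precision used for the comparison.

```python
import numpy

threshold64 = 0.1                          # threshold as learned by scikit-learn (float64)
threshold32 = numpy.float32(threshold64)   # same threshold rounded to float32, as stored in ONNX

# Pick a feature value lying between the two thresholds:
# float32(0.1) is slightly above 0.1, so its exact value works.
x = float(threshold32)

goes_left_float64 = x <= threshold64                        # comparison done in float64
goes_left_float32 = bool(numpy.float32(x) <= threshold32)   # comparison done in float32

print(float(threshold32) - threshold64)   # tiny positive rounding gap
print(goes_left_float64, goes_left_float32)   # False True
```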
from jyquickhelper import add_notebook_menu
add_notebook_menu()
%matplotlib inline
%load_ext mlprodict
We build a random dataset.
import numpy
X = numpy.random.randn(10000, 10)
y = X.sum(axis=1) / X.shape[1]
X = X.astype(numpy.float64)
y = y.astype(numpy.float64)
middle = X.shape[0] // 2
X_train, X_test = X[:middle], X[middle:]
y_train, y_test = y[:middle], y[middle:]
from sklearn.tree import DecisionTreeRegressor
tree = DecisionTreeRegressor(max_depth=7)
tree.fit(X_train, y_train)
tree.score(X_train, y_train), tree.score(X_test, y_test)
(0.6179766027481131, 0.33709933420465643)
from sklearn.metrics import r2_score
r2_score(y_test, tree.predict(X_test))
0.33709933420465643
The tree is not deep enough, but that is not what we are interested in here.
from mlprodict.onnx_conv import to_onnx
onx = to_onnx(tree, X[:1].astype(numpy.float32))
from mlprodict.onnxrt import OnnxInference
x_exp = X_test
oinf = OnnxInference(onx, runtime='onnxruntime1')
expected = tree.predict(x_exp)
got = oinf.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668
from mlprodict.plotting.text_plot import onnx_simple_text_plot
print(onnx_simple_text_plot(onx))
opset: domain='ai.onnx.ml' version=1
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
TreeEnsembleRegressor(X, n_targets=1, nodes_falsenodeids=253:[128,65,34...252,0,0], nodes_featureids=253:[8,3,9...2,0,0], nodes_hitrates=253:[1.0,1.0...1.0,1.0], nodes_missing_value_tracks_true=253:[0,0,0...0,0,0], nodes_modes=253:[b'BRANCH_LEQ',b'BRANCH_LEQ'...b'LEAF',b'LEAF'], nodes_nodeids=253:[0,1,2...250,251,252], nodes_treeids=253:[0,0,0...0,0,0], nodes_truenodeids=253:[1,2,3...251,0,0], nodes_values=253:[0.00792999193072319,-0.12246682494878769...0.0,0.0], post_transform=b'NONE', target_ids=127:[0,0,0...0,0,0], target_nodeids=127:[7,8,10...249,251,252], target_treeids=127:[0,0,0...0,0,0], target_weights=127:[-0.9345570802688599,-0.6372960805892944...0.6169403195381165,1.0096807479858398]) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]
A parameter, `k`, controls the slope of the sigmoid functions that replace the decision thresholds.
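To see why the slope matters, here is a sketch of a soft threshold, assuming the test `x <= t` is replaced by `sigmoid(k * (t - x))`; the exact parametrisation used by `NeuralTreeNet` may differ (for instance by a constant scaling factor). The larger `k`, the closer the gate gets to a hard 0/1 decision, except for observations very close to the threshold.

```python
import numpy

def soft_leq(x, threshold, k):
    """Smooth version of (x <= threshold): close to 1 left of the threshold, 0 right of it."""
    return 1.0 / (1.0 + numpy.exp(-k * (threshold - x)))

threshold = 0.0
x = numpy.array([-0.1, -0.01, 0.0, 0.01, 0.1])   # points around the threshold
for k in [1, 5, 60]:
    print(k, numpy.round(soft_leq(x, threshold, k), 3))
```

In this parametrisation, even with `k=60` a point 0.01 away from the threshold gets a gate value of about 0.65 instead of 1, which may explain why the maximum error stays high while the mean error keeps decreasing.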
from tqdm import tqdm
from pandas import DataFrame
from mlstatpy.ml.neural_tree import NeuralTreeNet
xe = x_exp[:500]
expected = tree.predict(xe)
data = []
trees = {}
for i in tqdm([0.3, 0.4, 0.5, 0.7, 0.9, 1] + list(range(5, 61, 5))):
    # convert the tree into a neural network with sigmoid slope k=i
    root = NeuralTreeNet.create_from_tree(tree, k=i, arch='compact')
    got = root.predict(xe)[:, -1]
    me = numpy.abs(got - expected).mean()
    mx = numpy.abs(got - expected).max()
    obs = dict(k=i, max=mx, mean=me)
    data.append(obs)
    trees[i] = root
100%|██████████| 18/18 [00:01<00:00, 12.49it/s]
df = DataFrame(data)
df
| | k | max | mean |
---|---|---|---|
0 | 0.3 | 0.568981 | 0.158758 |
1 | 0.4 | 0.608304 | 0.132576 |
2 | 0.5 | 0.692657 | 0.128525 |
3 | 0.7 | 0.780543 | 0.131497 |
4 | 0.9 | 0.809866 | 0.128368 |
5 | 1.0 | 0.813889 | 0.124802 |
6 | 5.0 | 0.392482 | 0.022466 |
7 | 10.0 | 0.341749 | 0.006350 |
8 | 15.0 | 0.270649 | 0.002939 |
9 | 20.0 | 0.299713 | 0.002110 |
10 | 25.0 | 0.305493 | 0.001842 |
11 | 30.0 | 0.306111 | 0.001767 |
12 | 35.0 | 0.299371 | 0.001665 |
13 | 40.0 | 0.233556 | 0.001011 |
14 | 45.0 | 0.233606 | 0.000801 |
15 | 50.0 | 0.233614 | 0.000547 |
16 | 55.0 | 0.233615 | 0.000499 |
17 | 60.0 | 0.233615 | 0.000484 |
df.set_index('k').plot(title="Conversion accuracy\nto a neural network");
The error decreases as `k` grows, but the experiment should be repeated several times on the same kind of dataset to obtain a confidence interval before drawing any conclusion. That will be for another time. The result depends on the dataset and above all on the proximity of the decision thresholds. The experiment above only used the first 500 rows of the test set to go faster; we now compute the error on the whole test set.
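Repeating the whole experiment on new random datasets is the approach suggested above. A cheaper stand-in, sketched here as an assumption and not something this notebook does, is a bootstrap confidence interval on the mean absolute error for a fixed test set. It only captures the variability due to the test sample, not the variability due to the training data or the tree itself.

```python
import numpy

def bootstrap_mean_ci(errors, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for the mean of `errors`."""
    rng = numpy.random.default_rng(seed)
    errors = numpy.asarray(errors, dtype=numpy.float64)
    # Each row of `idx` is one resample (with replacement) of the observations.
    idx = rng.integers(0, len(errors), size=(n_boot, len(errors)))
    means = errors[idx].mean(axis=1)
    low, high = numpy.quantile(means, [alpha / 2, 1 - alpha / 2])
    return float(low), float(high)

# Illustrative errors; in the notebook one would pass numpy.abs(got - expected).
demo_errors = numpy.abs(numpy.random.default_rng(1).normal(size=500))
print(bootstrap_mean_ci(demo_errors))
```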
expected = tree.predict(x_exp)
got = trees[50].predict(x_exp)[:, -1]
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(0.2336143002078063, 0.0002511855017989173)
The maximum error can still be large (about 0.23), but it remains smaller than the conversion error introduced by ONNX (about 1.74).
We first create a class that follows the scikit-learn API and wraps the neural network we just built, so that it can then be converted to ONNX.
from mlstatpy.ml.neural_tree import NeuralTreeNetRegressor
reg = NeuralTreeNetRegressor(trees[50])
onx2 = to_onnx(reg, X[:1].astype(numpy.float32))
print(onnx_simple_text_plot(onx2))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)
init: name='Ad_Addcst' type=dtype('float32') shape=(126,)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)
init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)
init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)
init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)
init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
MatMul(X, Ma_MatMulcst) -> Ma_Y02
Add(Ma_Y02, Ad_Addcst) -> Ad_C02
Mul(Ad_C02, Mu_Mulcst) -> Mu_C01
Sigmoid(Mu_C01) -> Si_Y01
MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01
Add(Ma_Y01, Ad_Addcst1) -> Ad_C01
Mul(Ad_C01, Mu_Mulcst) -> Mu_C0
Sigmoid(Mu_C0) -> Si_Y0
MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0
Add(Ma_Y0, Ad_Addcst2) -> Ad_C0
Identity(Ad_C0) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]
oinf2 = OnnxInference(onx2, runtime='onnxruntime1')
expected = tree.predict(x_exp)
got = oinf2.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668
The error is the same (1.742) as for the tree converted directly to ONNX.
x_exp32 = x_exp.astype(numpy.float32)
First, the computation time for scikit-learn.
%timeit tree.predict(x_exp32)
513 µs ± 7.52 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
The computation time for the decision tree in ONNX format.
%timeit oinf.run({'X': x_exp32})['variable']
186 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
And the computation time for the neural network in ONNX format.
%timeit oinf2.run({'X': x_exp32})['variable']
3.75 ms ± 311 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
This long computation time is expected: the model contains a very large matrix multiplication and, above all, every threshold of the tree is evaluated for every observation. Where the decision tree implementation evaluates $d$ thresholds, $d$ being the depth of the tree, the new implementation evaluates all of them, about $2^d$, and then combines all of them for each of the $2^d$ leaves, which amounts to about $2^d \times 2^d$ operations. Even with a sparse matrix, since each leaf only depends on the $d$ thresholds on its path, the cost can only be reduced to $d \times 2^d$, which is still a lot of unnecessary computation.
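These orders of magnitude can be checked with a small count, assuming a perfect binary tree of depth d (2^d - 1 thresholds, 2^d leaves), which real trees such as this one are not exactly. The counts are per observation and ignore constant factors; the function name is illustrative only.

```python
def operation_counts(depth):
    """Rough per-observation cost for a perfect binary tree of the given depth."""
    n_leaves = 2 ** depth
    n_thresholds = n_leaves - 1
    return {
        # the tree only evaluates the thresholds on one root-to-leaf path
        "tree": depth,
        # first layer: every threshold; second layer: a dense leaves x thresholds matrix
        "network_dense": n_thresholds + n_leaves * n_thresholds,
        # second layer stored sparse: each leaf only reads the d thresholds on its path
        "network_sparse": n_thresholds + n_leaves * depth,
    }

for d in (3, 7, 10):
    print(d, operation_counts(d))
```

For depth 7, the dense leaf layer holds 128 × 127 coefficients, the same order of magnitude as the (127, 127) matrix inspected next.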
for node in trees[50].nodes:
    # shape of the weights and biases of each layer
    print(node.coef.shape, node.bias.shape)
(126, 11) (126,)
(127, 127) (127,)
(128,) ()
That said, the largest matrix is sparse: it can be shrunk considerably.
from scipy.sparse import csr_matrix
for node in trees[50].nodes:
    # compare the dense size with the number of stored non-zero coefficients
    csr = csr_matrix(node.coef)
    print(f"coef.shape={node.coef.shape}, size dense={node.coef.size}, "
          f"size sparse={csr.size}, ratio={csr.size / node.coef.size}")
coef.shape=(126, 11), size dense=1386, size sparse=252, ratio=0.18181818181818182
coef.shape=(127, 127), size dense=16129, size sparse=1015, ratio=0.06293012586025172
coef.shape=(128,), size dense=128, size sparse=127, ratio=0.9921875
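The ratio for the leaf layer (about 6% of non-zero coefficients) is what a sparse product exploits. A hand-written CSR (compressed sparse row) product, a sketch that mimics the layout stored by `scipy.sparse.csr_matrix` (function names are illustrative), makes it explicit: the work is proportional to the number of non-zero coefficients, not to the size of the matrix. A Python loop is of course slow; scipy does the same thing in compiled code.

```python
import numpy

def dense_to_csr(dense):
    """Return the (data, indices, indptr) arrays of the CSR format."""
    data, indices, indptr = [], [], [0]
    for row in dense:
        nz = numpy.flatnonzero(row)          # column indices of the non-zero entries
        indices.extend(nz.tolist())
        data.extend(row[nz].tolist())
        indptr.append(len(indices))          # where the next row starts
    return (numpy.array(data, dtype=dense.dtype),
            numpy.array(indices, dtype=numpy.int64),
            numpy.array(indptr, dtype=numpy.int64))

def csr_matvec(data, indices, indptr, v):
    """Sparse matrix-vector product: one dot product per row, over its non-zeros only."""
    out = numpy.zeros(len(indptr) - 1, dtype=numpy.result_type(data, v))
    for i in range(len(out)):
        start, end = indptr[i], indptr[i + 1]
        out[i] = data[start:end] @ v[indices[start:end]]
    return out

rng = numpy.random.default_rng(0)
mat = rng.normal(size=(127, 127)) * (rng.random((127, 127)) < 0.06)  # about 6% non-zeros
vec = rng.normal(size=127)
data, indices, indptr = dense_to_csr(mat)
print(len(data), mat.size)   # stored coefficients vs dense size
print(numpy.allclose(csr_matvec(data, indices, indptr, vec), mat @ vec))
```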
r = numpy.random.randn(trees[50].nodes[1].coef.shape[0])
mat = trees[50].nodes[1].coef
%timeit mat @ r
49.8 µs ± 1.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
csr = csr_matrix(mat)
%timeit csr @ r
7.08 µs ± 173 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
The product would be much faster with a sparse matrix (about 7 times faster here), and the gain grows with the depth of the tree. The ONNX model breaks down as follows.
print(onnx_simple_text_plot(onx2))
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='Ma_MatMulcst' type=dtype('float32') shape=(1260,)
init: name='Ad_Addcst' type=dtype('float32') shape=(126,)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([4.], dtype=float32)
init: name='Ma_MatMulcst1' type=dtype('float32') shape=(16002,)
init: name='Ad_Addcst1' type=dtype('float32') shape=(127,)
init: name='Ma_MatMulcst2' type=dtype('float32') shape=(127,)
init: name='Ad_Addcst2' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
MatMul(X, Ma_MatMulcst) -> Ma_Y02
Add(Ma_Y02, Ad_Addcst) -> Ad_C02
Mul(Ad_C02, Mu_Mulcst) -> Mu_C01
Sigmoid(Mu_C01) -> Si_Y01
MatMul(Si_Y01, Ma_MatMulcst1) -> Ma_Y01
Add(Ma_Y01, Ad_Addcst1) -> Ad_C01
Mul(Ad_C01, Mu_Mulcst) -> Mu_C0
Sigmoid(Mu_C0) -> Si_Y0
MatMul(Si_Y0, Ma_MatMulcst2) -> Ma_Y0
Add(Ma_Y0, Ad_Addcst2) -> Ad_C0
Identity(Ad_C0) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]
Let us see how the computation time is distributed.
oinfpr = OnnxInference(onx2, runtime="onnxruntime1",
runtime_options={"enable_profiling": True})
for i in range(0, 43):
    # run the model several times to accumulate profiling data
    oinfpr.run({"X": x_exp32})
df = oinfpr.get_profiling(as_df=True)
df
| | cat | pid | tid | dur | ts | ph | name | args_op_name | args_parameter_size | args_graph_index | args_provider | args_exec_plan_index | args_activation_size | args_output_size | args_input_type_shape | args_output_type_shape | args_thread_scheduling_stats |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Session | 78116 | 8820 | 387 | 4 | X | model_loading_array | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | Session | 78116 | 8820 | 2532 | 428 | X | session_initialization | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | Node | 78116 | 8820 | 0 | 3294 | X | gemm_fence_before | Gemm | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
3 | Node | 78116 | 8820 | 1315 | 3300 | X | gemm_kernel_time | Gemm | 5544 | 11 | CPUExecutionProvider | 11 | 200000 | 2520000 | [{'float': [5000, 10]}, {'float': [10, 126]}, ... | [{'float': [5000, 126]}] | {'main_thread': {'thread_pool_name': 'session-... |
4 | Node | 78116 | 8820 | 0 | 4635 | X | gemm_fence_after | Gemm | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
986 | Node | 78116 | 8820 | 0 | 210170 | X | Ma_MatMul2_fence_before | MatMul | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
987 | Node | 78116 | 8820 | 124 | 210172 | X | Ma_MatMul2_kernel_time | MatMul | 508 | 8 | CPUExecutionProvider | 8 | 2540000 | 20000 | [{'float': [5000, 127]}, {'float': [127, 1]}] | [{'float': [5000, 1]}] | {'main_thread': {'thread_pool_name': 'session-... |
988 | Node | 78116 | 8820 | 0 | 210305 | X | Ma_MatMul2_fence_after | MatMul | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
989 | Session | 78116 | 8820 | 4378 | 205930 | X | SequentialExecutor::Execute | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
990 | Session | 78116 | 8820 | 4388 | 205925 | X | model_run | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
991 rows × 17 columns
set(df['args_provider'])
{'CPUExecutionProvider', nan}
dfp = df[df.args_provider == 'CPUExecutionProvider'].copy()
dfp['name'] = dfp['name'].apply(lambda s: s.replace("_kernel_time", ""))
gr_dur = dfp[['dur', "args_op_name", "name"]].groupby(["args_op_name", "name"]).sum().sort_values('dur')
gr_dur
args_op_name | name | dur
---|---|---
MatMul | Ma_MatMul2 | 6778
Mul | Mu_Mul | 12923
Sigmoid | Si_Sigmoid | 14849
Mul | Mu_Mul1 | 15151
Sigmoid | Si_Sigmoid1 | 15608
Gemm | gemm | 31763
Gemm | gemm_token_0 | 99047
gr_n = dfp[['dur', "args_op_name", "name"]].groupby(["args_op_name", "name"]).count().sort_values('dur')
gr_n = gr_n.loc[gr_dur.index, :]
gr_n
args_op_name | name | dur
---|---|---
MatMul | Ma_MatMul2 | 43
Mul | Mu_Mul | 43
Sigmoid | Si_Sigmoid | 43
Mul | Mu_Mul1 | 43
Sigmoid | Si_Sigmoid1 | 43
Gemm | gemm | 43
Gemm | gemm_token_0 | 43
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurrences");
onnxruntime spends most of its time in one matrix product. Let us look at it more closely.
df[(df.args_op_name == 'Gemm') & (df.dur > 0)].sort_values('dur', ascending=False).head(n=2).T
| | 127 | 12 |
---|---|---|
cat | Node | Node |
pid | 78116 | 78116 |
tid | 8820 | 8820 |
dur | 4603 | 4083 |
ts | 37173 | 5949 |
ph | X | X |
name | gemm_token_0_kernel_time | gemm_token_0_kernel_time |
args_op_name | Gemm | Gemm |
args_parameter_size | 64516 | 64516 |
args_graph_index | 12 | 12 |
args_provider | CPUExecutionProvider | CPUExecutionProvider |
args_exec_plan_index | 12 | 12 |
args_activation_size | 2520000 | 2520000 |
args_output_size | 2540000 | 2540000 |
args_input_type_shape | [{'float': [5000, 126]}, {'float': [126, 127]}... | [{'float': [5000, 126]}, {'float': [126, 127]}... |
args_output_type_shape | [{'float': [5000, 127]}] | [{'float': [5000, 127]}] |
args_thread_scheduling_stats | {'main_thread': {'thread_pool_name': 'session-... | {'main_thread': {'thread_pool_name': 'session-... |
It is a matrix product of size 5000x126 by 126x127 (see `args_input_type_shape`).
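A back-of-the-envelope count of multiply-adds shows how much a sparse weight matrix could save on this product. It assumes the sparsity measured earlier for the (127, 127) coefficient matrix (1015 non-zeros) carries over to the Gemm weight, whose shape differs slightly.

```python
batch, n_in, n_out = 5000, 126, 127      # shapes read from args_input_type_shape
dense_macs = batch * n_in * n_out        # multiply-adds of the dense Gemm
nnz = 1015                               # non-zeros of the (127, 127) matrix reported earlier
sparse_macs = batch * nnz                # assumption: same sparsity for the Gemm weight
ratio = dense_macs / sparse_macs
print(dense_macs, sparse_macs, round(ratio, 2))
```

In theory the sparse product would need about 16 times fewer multiply-adds; the timing measured earlier only showed a factor of about 7, sparse kernels typically being less efficient per operation than dense ones.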
gr_dur / gr_dur.dur.sum()
args_op_name | name | dur
---|---|---
MatMul | Ma_MatMul2 | 0.034561
Mul | Mu_Mul | 0.065894
Sigmoid | Si_Sigmoid | 0.075714
Mul | Mu_Mul1 | 0.077254
Sigmoid | Si_Sigmoid1 | 0.079584
Gemm | gemm | 0.161958
Gemm | gemm_token_0 | 0.505035
r = (gr_dur / gr_dur.dur.sum()).dur.max()
r
0.5050352082154203
It accounts for about 50% of the time (`r` ≈ 0.505). According to the previous experiment, replacing it with a sparse matrix could divide its execution time by about 7 (the computation below uses a more optimistic factor of 12). That will not be enough to make this neural network fast: it takes 3.75 ms, compared with 186 µs for the decision tree in ONNX format. With this optimization, its computation time could go from 3.75 ms to:
t = 3.75 # ms
t * (1 - r) + r * t / 12
2.013941471759493
That is, a reduction from 3.75 ms to about 2 ms. Not bad, but not enough.
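The estimate above is an instance of Amdahl's law: if a fraction r of the total time t is accelerated by a factor s, the new time is t(1 - r) + rt/s. A small helper (the name is illustrative) compares the factor of 12 used above with the factor of about 7 actually measured, and gives the limit when s grows without bound.

```python
def amdahl(t, r, s):
    """Run time after speeding up a fraction r of the total time t by a factor s."""
    return t * (1 - r) + r * t / s

t = 3.75                    # ms, measured for the neural network in ONNX format
r = 0.5050352082154203      # share of the largest Gemm, from the profiling above
for s in (7, 12, float("inf")):
    print(s, round(amdahl(t, r, s), 3))
```

Even an infinitely fast sparse product would leave about 1.86 ms, roughly ten times the 186 µs of the tree in ONNX format.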
hummingbird is a library that converts a decision tree into a neural network. Let us look at its performance.
from hummingbird.ml import convert
model = convert(tree, 'torch')
expected = tree.predict(x_exp)
got = model.predict(x_exp)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
C:\xavierdupre\__home_\github_fork\scikit-learn\sklearn\utils\deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead. warnings.warn(msg, category=FutureWarning)
(4.3419181139370266e-08, 4.430287026515114e-09)
The result is much more faithful to the original model.
%timeit model.predict(x_exp)
1.17 ms ± 34.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
It remains slower than the original tree (513 µs), but much faster than the manual solution proposed in the previous paragraphs (1.17 ms versus 3.75 ms). It contains an attribute `model`.
from torch.nn import Module
isinstance(model.model, Module)
True
We convert this model to ONNX.
import torch.onnx
x = torch.randn(x_exp.shape[0], x_exp.shape[1], requires_grad=True)
torch.onnx.export(model.model, x, 'tree_torch.onnx', opset_version=15,
input_names=['X'], output_names=['variable'],
dynamic_axes={
'X' : {0 : 'batch_size'},
'variable' : {0 : 'batch_size'}})
import onnx
onxh = onnx.load('tree_torch.onnx')
print(onnx_simple_text_plot(onxh, raise_exc=False))
opset: domain='' version=15 input: name='X' type=dtype('float32') shape=['batch_size', 10] init: name='_operators.0.root_nodes' type=dtype('int64') shape=(0,) -- array([8], dtype=int64) init: name='_operators.0.root_biases' type=dtype('float32') shape=(0,) -- array([0.00792999], dtype=float32) init: name='_operators.0.tree_indices' type=dtype('int64') shape=(0,) -- array([0], dtype=int64) init: name='_operators.0.leaf_nodes' type=dtype('float32') shape=(0,) -- array([ 1.0096807 , 0.6169403 , 0.61055773, 0.37810475, 0.31796893, 0.13317925, 0.0193846 , -0.2317742 , 0.39089343, 0.23506087, 0.3711936 , 0.10317916, 0.14956598, -0.14193445, -0.05965868, -0.27377078, 0.4128183 , 0.19658326, 0.25545415, 0.08118545, 0.08400188, -0.1502193 , -0.36846825, -0.79687625, 0.35822242, 0.49021915, 0.30870998, 0.01033915, 0.6740977 , 0.6740977 , -0.15315758, -0.41128033, 0.42920846, 0.13145493, 0.21853392, -0.10986731, 0.4493652 , 0.11318789, 0.12666471, -0.0623082 , 0.2872893 , 0.09948976, 0.11439473, -0.08801427, 0.16091613, -0.02319027, -0.10097775, -0.37583745, 0.18612385, -0.00453244, 0.3287116 , -0.1499349 , 0.7919218 , 0.04704398, -0.15423109, -0.43160027, 0.10802375, -0.1073833 , -0.07759219, -0.29175794, -0.1528881 , -0.4909434 , -0.23361537, -0.43578717, 0.7831867 , 0.45349318, 0.34956965, -0.3199535 , 0.3061573 , -0.34267113, 0.34963542, 0.04491445, 0.35399815, 0.14815213, 0.06678926, -0.16095412, 0.3214274 , 0.01484008, -0.1012276 , -0.3257699 , 0.26727676, 0.01970094, 0.10760042, -0.09169976, 0.20044112, -0.0324069 , -0.11015374, -0.28358367, 0.8083656 , 0.13358633, -0.07912118, -0.27182895, -0.07054728, -0.24895027, -0.20600456, -0.42033467, 0.34701794, -0.0638995 , 0.14252576, -0.06025055, 0.4228329 , 0.06789401, 0.03919645, -0.17267554, 0.07274943, -0.487512 , 0.04517636, -0.18857062, -0.03975222, -0.2652712 , -0.30853328, -0.50844556, 0.03321444, -0.15481217, -0.20701212, -0.40578464, -0.25884995, -0.46550158, -0.4797585 , -0.7324234 , 0.43939307, -0.06170902, 
-0.51546025, -0.19215119, -0.3705445 , -0.57504356, -0.6372961 , -0.9345571 ], dtype=float32) init: name='_operators.0.nodes.0' type=dtype('int64') shape=(0,) -- array([0, 3], dtype=int64) init: name='_operators.0.nodes.1' type=dtype('int64') shape=(0,) -- array([1, 2, 5, 9], dtype=int64) init: name='_operators.0.nodes.2' type=dtype('int64') shape=(0,) -- array([5, 6, 3, 7, 2, 0, 7, 1], dtype=int64) init: name='_operators.0.nodes.3' type=dtype('int64') shape=(0,) -- array([3, 9, 5, 3, 6, 4, 1, 3, 6, 6, 1, 6, 5, 4, 6, 2], dtype=int64) init: name='_operators.0.nodes.4' type=dtype('int64') shape=(0,) -- array([3, 2, 7, 6, 2, 4, 7, 8, 9, 5, 7, 8, 9, 4, 6, 9, 7, 9, 0, 7, 7, 9, 2, 7, 6, 4, 6, 5, 4, 0, 6, 0], dtype=int64) init: name='_operators.0.nodes.5' type=dtype('int64') shape=(0,) -- array([2, 8, 7, 6, 6, 3, 4, 9, 7, 3, 2, 6, 3, 3, 0, 1, 1, 0, 4, 7, 9, 5, 7, 9, 5, 3, 5, 9, 0, 5, 1, 4, 9, 4, 7, 7, 1, 9, 1, 1, 6, 2, 7, 7, 6, 1, 4, 4, 0, 0, 9, 8, 8, 2, 6, 2, 0, 3, 4, 2, 5, 6, 7, 3], dtype=int64) init: name='_operators.0.biases.0' type=dtype('float32') shape=(0,) -- array([ 0.19169255, -0.12246682], dtype=float32) init: name='_operators.0.biases.1' type=dtype('float32') shape=(0,) -- array([-0.40610337, -0.1467492 , -0.01880287, 0.15879431], dtype=float32) init: name='_operators.0.biases.2' type=dtype('float32') shape=(0,) -- array([ 0.736786 , -0.32427853, 0.30860555, 0.17994082, 0.6917758 , -0.00594712, 0.35950053, -0.9819274 ], dtype=float32) init: name='_operators.0.biases.3' type=dtype('float32') shape=(0,) -- array([-1.3495584 , -1.082793 , -0.6906011 , -0.08978076, -0.4007622 , 0.10756078, -0.68507075, 0.15814054, 0.5132364 , -0.18426335, 0.13685235, 0.10721841, 0.01814443, -0.41644228, -0.59770894, 0.607365 ], dtype=float32) init: name='_operators.0.biases.4' type=dtype('float32') shape=(0,) -- array([ 1.4203796 , -0.49269757, -0.12210988, -0.09692484, 0.5076643 , -1.3609421 , 1.154743 , 2.8748922 , -0.08181615, 0.7741028 , 0.20604724, 0.666296 , -0.6474025 , 
0.6459148 , 0.02262808, -0.42282397, 0.46360654, -0.10058792, 0.25486696, 0.60041225, -0.06933744, 0.21294908, 0.96443814, 0.07923891, 0.4797698 , 1.2852331 , 0.24348404, -0.3404966 , -0.07175394, -0.8248828 , -0.74071133, -1.2140133 ], dtype=float32) init: name='_operators.0.biases.5' type=dtype('float32') shape=(0,) -- array([ 1.0626682 , 1.4745288 , 0.01898679, 0.5451088 , 0.15444604, 1.0631477 , -0.7555804 , -1.7192128 , -0.20905146, 0.19752283, -0.40471953, 0.13069782, 0.60331047, 1.5060809 , 0. , -1.8283446 , -0.8124372 , -1.381897 , 0.59209645, 0.3239226 , -0.42840806, -0.43624896, 0.58229303, -1.0196047 , -0.5632828 , 0.91483426, 1.8038778 , -0.5665638 , -1.2530733 , -0.6500004 , -1.3069727 , 0.48267984, 0.73503745, -1.871724 , -1.4965518 , 1.3147466 , 0.03919952, -0.885836 , 0.5479692 , -0.8086383 , -0.74240863, 0.14582941, 0.6496967 , -0.00911551, 2.4541488 , -0.90482277, 0.26108736, 0.7569448 , -1.0786855 , -0.45229852, 1.2146595 , -0.6756766 , -2.3066258 , 0.7911504 , 0.57490873, -0.40741247, 0.24633038, -1.2022957 , -0.65162694, -0.04244827, 1.558136 , -1.6220782 , 0.1574643 , -1.4209061 ], dtype=float32) Constant(value=[-1]) -> onnx::Reshape_27 Gather(X, _operators.0.root_nodes, axis=1) -> onnx::LessOrEqual_17 LessOrEqual(onnx::LessOrEqual_17, _operators.0.root_biases) -> onnx::Cast_18 Cast(onnx::Cast_18, to=7) -> onnx::Add_19 Add(onnx::Add_19, _operators.0.tree_indices) -> onnx::Reshape_20 Constant(value=[-1]) -> onnx::Reshape_21 Reshape(onnx::Reshape_20, onnx::Reshape_21, allowzero=0) -> onnx::Gather_22 Gather(_operators.0.nodes.0, onnx::Gather_22, axis=0) -> onnx::Reshape_23 Constant(value=[-1, 1]) -> onnx::Reshape_24 Reshape(onnx::Reshape_23, onnx::Reshape_24, allowzero=0) -> onnx::GatherElements_25 GatherElements(X, onnx::GatherElements_25, axis=1) -> onnx::Reshape_26 Reshape(onnx::Reshape_26, onnx::Reshape_27, allowzero=0) -> onnx::LessOrEqual_28 Constant(value=2) -> onnx::Mul_29 Mul(onnx::Gather_22, onnx::Mul_29) -> onnx::Add_30 
Gather(_operators.0.biases.0, onnx::Gather_22, axis=0) -> onnx::LessOrEqual_31 LessOrEqual(onnx::LessOrEqual_28, onnx::LessOrEqual_31) -> onnx::Cast_32 Cast(onnx::Cast_32, to=7) -> onnx::Add_33 Add(onnx::Add_30, onnx::Add_33) -> onnx::Gather_34 Gather(_operators.0.nodes.1, onnx::Gather_34, axis=0) -> onnx::Reshape_35 Constant(value=[-1, 1]) -> onnx::Reshape_36 Reshape(onnx::Reshape_35, onnx::Reshape_36, allowzero=0) -> onnx::GatherElements_37 GatherElements(X, onnx::GatherElements_37, axis=1) -> onnx::Reshape_38 Constant(value=[-1]) -> onnx::Reshape_39 Reshape(onnx::Reshape_38, onnx::Reshape_39, allowzero=0) -> onnx::LessOrEqual_40 Constant(value=2) -> onnx::Mul_41 Mul(onnx::Gather_34, onnx::Mul_41) -> onnx::Add_42 Gather(_operators.0.biases.1, onnx::Gather_34, axis=0) -> onnx::LessOrEqual_43 LessOrEqual(onnx::LessOrEqual_40, onnx::LessOrEqual_43) -> onnx::Cast_44 Cast(onnx::Cast_44, to=7) -> onnx::Add_45 Add(onnx::Add_42, onnx::Add_45) -> onnx::Gather_46 Gather(_operators.0.nodes.2, onnx::Gather_46, axis=0) -> onnx::Reshape_47 Constant(value=[-1, 1]) -> onnx::Reshape_48 Reshape(onnx::Reshape_47, onnx::Reshape_48, allowzero=0) -> onnx::GatherElements_49 GatherElements(X, onnx::GatherElements_49, axis=1) -> onnx::Reshape_50 Constant(value=[-1]) -> onnx::Reshape_51 Reshape(onnx::Reshape_50, onnx::Reshape_51, allowzero=0) -> onnx::LessOrEqual_52 Constant(value=2) -> onnx::Mul_53 Mul(onnx::Gather_46, onnx::Mul_53) -> onnx::Add_54 Gather(_operators.0.biases.2, onnx::Gather_46, axis=0) -> onnx::LessOrEqual_55 LessOrEqual(onnx::LessOrEqual_52, onnx::LessOrEqual_55) -> onnx::Cast_56 Cast(onnx::Cast_56, to=7) -> onnx::Add_57 Add(onnx::Add_54, onnx::Add_57) -> onnx::Gather_58 Gather(_operators.0.nodes.3, onnx::Gather_58, axis=0) -> onnx::Reshape_59 Constant(value=[-1, 1]) -> onnx::Reshape_60 Reshape(onnx::Reshape_59, onnx::Reshape_60, allowzero=0) -> onnx::GatherElements_61 GatherElements(X, onnx::GatherElements_61, axis=1) -> onnx::Reshape_62 Constant(value=[-1]) -> 
onnx::Reshape_63 Reshape(onnx::Reshape_62, onnx::Reshape_63, allowzero=0) -> onnx::LessOrEqual_64 Constant(value=2) -> onnx::Mul_65 Mul(onnx::Gather_58, onnx::Mul_65) -> onnx::Add_66 Gather(_operators.0.biases.3, onnx::Gather_58, axis=0) -> onnx::LessOrEqual_67 LessOrEqual(onnx::LessOrEqual_64, onnx::LessOrEqual_67) -> onnx::Cast_68 Cast(onnx::Cast_68, to=7) -> onnx::Add_69 Add(onnx::Add_66, onnx::Add_69) -> onnx::Gather_70 Gather(_operators.0.nodes.4, onnx::Gather_70, axis=0) -> onnx::Reshape_71 Constant(value=[-1, 1]) -> onnx::Reshape_72 Reshape(onnx::Reshape_71, onnx::Reshape_72, allowzero=0) -> onnx::GatherElements_73 GatherElements(X, onnx::GatherElements_73, axis=1) -> onnx::Reshape_74 Constant(value=[-1]) -> onnx::Reshape_75 Reshape(onnx::Reshape_74, onnx::Reshape_75, allowzero=0) -> onnx::LessOrEqual_76 Constant(value=2) -> onnx::Mul_77 Mul(onnx::Gather_70, onnx::Mul_77) -> onnx::Add_78 Gather(_operators.0.biases.4, onnx::Gather_70, axis=0) -> onnx::LessOrEqual_79 LessOrEqual(onnx::LessOrEqual_76, onnx::LessOrEqual_79) -> onnx::Cast_80 Cast(onnx::Cast_80, to=7) -> onnx::Add_81 Add(onnx::Add_78, onnx::Add_81) -> onnx::Gather_82 Gather(_operators.0.nodes.5, onnx::Gather_82, axis=0) -> onnx::Reshape_83 Constant(value=[-1, 1]) -> onnx::Reshape_84 Reshape(onnx::Reshape_83, onnx::Reshape_84, allowzero=0) -> onnx::GatherElements_85 GatherElements(X, onnx::GatherElements_85, axis=1) -> onnx::Reshape_86 Constant(value=[-1]) -> onnx::Reshape_87 Reshape(onnx::Reshape_86, onnx::Reshape_87, allowzero=0) -> onnx::LessOrEqual_88 Constant(value=2) -> onnx::Mul_89 Mul(onnx::Gather_82, onnx::Mul_89) -> onnx::Add_90 Gather(_operators.0.biases.5, onnx::Gather_82, axis=0) -> onnx::LessOrEqual_91 LessOrEqual(onnx::LessOrEqual_88, onnx::LessOrEqual_91) -> onnx::Cast_92 Cast(onnx::Cast_92, to=7) -> onnx::Add_93 Add(onnx::Add_90, onnx::Add_93) -> onnx::Gather_94 Gather(_operators.0.leaf_nodes, onnx::Gather_94, axis=0) -> onnx::Reshape_95 Constant(value=[-1, 1, 1]) -> 
onnx::Reshape_96 Reshape(onnx::Reshape_95, onnx::Reshape_96, allowzero=0) -> output Constant(value=[1]) -> onnx::ReduceSum_98 ReduceSum(output, onnx::ReduceSum_98, keepdims=0) -> variable output: name='variable' type=dtype('float32') shape=['batch_size', 'ReduceSumvariable_dim_1']
%onnxview onxh
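The library reimplements the decision of the tree as a sequence of tensor operations, one step per level of the tree. As the graph above shows, at each level the feature and the threshold of the current node are gathered (Gather, GatherElements), compared (LessOrEqual), and the index of the next node is computed as 2 × index + the result of the comparison. Only one threshold per level is evaluated for each observation, and the matrices do not need to be sparse because only the required features are retrieved. The decision is implemented with a test, not a sigmoid, so this model makes the same predictions as the original model.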
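The level-by-level traversal visible in the graph can be reproduced in numpy. The sketch below uses a small hand-made perfect tree of depth 2 with hypothetical arrays `features`, `thresholds` and `leaves` stored level by level; it mimics the structure of the graph above, not hummingbird's actual code, and its left/right encoding may differ from the library's.

```python
import numpy

# Hypothetical perfect tree of depth 2, stored level by level.
# Node i at level l has children 2*i (test false) and 2*i + 1 (test true) at level l + 1.
features = [numpy.array([0]), numpy.array([1, 2])]
thresholds = [numpy.array([0.0]), numpy.array([0.5, -0.5])]
leaves = numpy.array([10.0, 20.0, 30.0, 40.0])

def traverse(X):
    """Vectorised traversal: one gather and one comparison per level, for all rows at once."""
    rows = numpy.arange(X.shape[0])
    idx = numpy.zeros(X.shape[0], dtype=numpy.int64)
    for feat, thr in zip(features, thresholds):
        x = X[rows, feat[idx]]                         # like GatherElements
        cond = (x <= thr[idx]).astype(numpy.int64)     # like LessOrEqual + Cast
        idx = 2 * idx + cond                           # like Mul + Add
    return leaves[idx]                                 # like the final Gather

def predict_one(row):
    """Reference: the same tree written as explicit if/else tests."""
    if row[0] <= 0.0:                              # root test true -> node 1 at level 1
        return 40.0 if row[2] <= -0.5 else 30.0
    return 20.0 if row[1] <= 0.5 else 10.0         # root test false -> node 0 at level 1

X = numpy.random.default_rng(0).normal(size=(8, 3))
print(traverse(X))
print(numpy.array([predict_one(r) for r in X]))
```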
oinfh = OnnxInference(onxh, runtime='onnxruntime1')
expected = tree.predict(x_exp)
got = oinfh.run({'X': x_exp.astype(numpy.float32)})['variable']
numpy.abs(got - expected).max()
1.7421041873949668
The conversion is imperfect as well: once the inputs are cast to float32, the maximum error is again 1.742.
%timeit oinfh.run({'X': x_exp32})['variable']
3.13 ms ± 445 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
And the computation time is also longer (3.13 ms versus 186 µs for the tree converted directly to ONNX).
The idea behind all this is also to be able to re-estimate the coefficients of the neural network once it has been converted.
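As a minimal sketch of this idea, and not of the training procedure used by `NeuralTreeNetRegressor`, here is a depth-1 tree whose hard test is replaced by a sigmoid gate; its two leaf values are re-estimated by plain gradient descent on the mean squared error. The data, the slope `k` and the learning rate are arbitrary choices for the illustration.

```python
import numpy

def soft_stump(x, threshold, k, left, right):
    """Depth-1 tree where the test x <= threshold is replaced by a sigmoid gate."""
    gate = 1.0 / (1.0 + numpy.exp(-k * (threshold - x)))   # close to 1 when x <= threshold
    return gate * left + (1.0 - gate) * right, gate

rng = numpy.random.default_rng(0)
x = rng.normal(size=200)
y = numpy.where(x <= 0.3, -1.0, 2.0) + rng.normal(scale=0.1, size=200)

threshold, k = 0.3, 5.0     # kept fixed here; only the leaf values are trained
left, right = 0.0, 0.0      # leaf values to re-estimate
lr = 0.1
losses = []
for _ in range(200):
    pred, gate = soft_stump(x, threshold, k, left, right)
    err = pred - y
    losses.append(float(numpy.mean(err ** 2)))
    # gradient of the mean squared error with respect to each leaf value
    left -= lr * 2 * float(numpy.mean(err * gate))
    right -= lr * 2 * float(numpy.mean(err * (1.0 - gate)))

print(losses[0], losses[-1], left, right)
```

With only two parameters this toy problem converges easily, unlike the full network trained for 10 iterations in the next cells.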
x_train = X_train[:100]
expected = tree.predict(x_train)
reg = NeuralTreeNetRegressor(trees[1], verbose=1, max_iter=10, lr=1e-4)
got = reg.predict(x_train)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(1.0246115055833722, 0.24094382754240642)
The difference is large.
reg.fit(x_train, expected)
0/10: loss: 3.201 lr=0.0001 max(coef): 6.5 l1=0/1.5e+03 l2=0/2.5e+03
1/10: loss: 2.593 lr=9.95e-06 max(coef): 6.5 l1=2e+03/1.5e+03 l2=1.3e+03/2.5e+03
2/10: loss: 2.506 lr=7.05e-06 max(coef): 6.5 l1=1.4e+02/1.5e+03 l2=6.2/2.5e+03
3/10: loss: 2.461 lr=5.76e-06 max(coef): 6.5 l1=1.2e+03/1.5e+03 l2=6.8e+02/2.5e+03
4/10: loss: 2.429 lr=4.99e-06 max(coef): 6.5 l1=6.5e+02/1.5e+03 l2=2.1e+02/2.5e+03
5/10: loss: 2.405 lr=4.47e-06 max(coef): 6.5 l1=1.9e+02/1.5e+03 l2=13/2.5e+03
6/10: loss: 2.392 lr=4.08e-06 max(coef): 6.5 l1=1.6e+02/1.5e+03 l2=6.8/2.5e+03
7/10: loss: 2.375 lr=3.78e-06 max(coef): 6.5 l1=1.8e+02/1.5e+03 l2=9.5/2.5e+03
8/10: loss: 2.358 lr=3.53e-06 max(coef): 6.5 l1=1.1e+02/1.5e+03 l2=7/2.5e+03
9/10: loss: 2.345 lr=3.33e-06 max(coef): 6.5 l1=3.7e+02/1.5e+03 l2=56/2.5e+03
10/10: loss: 2.333 lr=3.16e-06 max(coef): 6.5 l1=6.1e+02/1.5e+03 l2=1.3e+02/2.5e+03
NeuralTreeNetRegressor(estimator=None, lr=0.0001, max_iter=10, verbose=1)
got = reg.predict(x_train)
numpy.abs(got - expected).max(), numpy.abs(got - expected).mean()
(1.256860512819292, 0.25663312220721907)
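It does not work as well as expected: the loss decreases, but the error with respect to the tree's predictions increases slightly (mean 0.241 before training, 0.257 after). More iterations and some tuning of the learning parameters would probably be needed.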