.. _onnxsklearnconsortiumrst:

===========================================
ONNX, scikit-learn, persistence, deployment
===========================================

.. only:: html

    **Links:** :download:`notebook `, :downloadlink:`html `, :download:`PDF `, :download:`python `, :downloadlink:`slides `, :githublink:`GitHub|_doc/notebooks/2019/sklearn/onnx_sklearn_consortium.ipynb|*`

The notebook explains what ONNX is and how it can be used together with
`sklearn-onnx `__ and `onnxruntime `__. `ONNX `__ is a serialization format
for machine learning models.

**Xavier Dupré** - Senior Data Scientist at Microsoft - Computer Science
Teacher at `ENSAE `__, `github/xadupre `__, `github/sdpython `__.

.. code:: ipython3

    from jyquickhelper import add_notebook_menu
    add_notebook_menu(last_level=2)

.. contents::
    :local:

.. code:: ipython3

    import numpy as np
    from pyquickhelper.helpgen import NbImage
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier
    from jupytalk.talk_examples.sklearn2019 import (
        graph_persistence_pickle,
        graph_persistence_pickle_issues,
        graph_persistence_onnx,
        profile_fct_graph,
        onnx2str, onnx2dotnb,
        onnxdocstring2html,
        rename_input_output,
        graph_three_components)
    from mlinsights.plotting import pipeline2dot
    %matplotlib inline

.. code:: ipython3

    from logging import getLogger
    logger = getLogger('skl2onnx')
    logger.disabled = True

Many functions are implemented in `sklearn2019.py `__.

Persistence and predictions
---------------------------

Persistence with pickle
~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    graph_persistence_pickle()

.. image:: onnx_sklearn_consortium_7_0.png

Main issues
~~~~~~~~~~~

- Unpickle is unstable
- Predictions are not fast (scikit-learn is optimized for batch prediction)

.. code:: ipython3

    graph_persistence_pickle_issues()

.. image:: onnx_sklearn_consortium_9_0.png

Iris dataset
~~~~~~~~~~~~

.. code:: ipython3

    data = load_iris()
    X, y = data.data, data.target

Example with logistic regression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    clr = LogisticRegression(multi_class="auto", solver="liblinear").fit(X, y)
    profile_fct_graph(lambda: [clr.predict(X) for i in range(0, 1000)],
                      "Cumulated time inside functions when predicting\nLogisticRegression",
                      ["safe_sparse_dot", "dot", "sum"]);

.. image:: onnx_sklearn_consortium_13_0.png

Persistence with ONNX
~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    graph_persistence_onnx()

.. image:: onnx_sklearn_consortium_15_0.png

Three components for ONNX
~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    graph_three_components()

.. image:: onnx_sklearn_consortium_17_0.png

ONNX specifications
-------------------

- `ONNX `__ = **Set of mathematical operations** assembled into a **graph**.
- It is versioned and **stable**: backward compatibility.
- It is optimized for deep learning and works with **single floats**.

Example with matrix operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    X.shape

.. parsed-literal::

    (150, 4)

.. code:: ipython3

    beta = np.random.randn(4, 3)
    M = (X @ beta)
    expM = np.exp(M)
    pred = expM / (expM + 1)
    pred[:5]

.. parsed-literal::

    array([[0.9993822 , 0.45859843, 0.99672386],
           [0.99938059, 0.477194  , 0.99527649],
           [0.99892886, 0.46590839, 0.99484022],
           [0.99878037, 0.50803645, 0.99322234],
           [0.99918539, 0.46178468, 0.9964578 ]])

Conversion to single float:

.. code:: ipython3

    X32 = X.astype(np.float32)
    beta32 = beta.astype(np.float32)

Let’s write the ONNX function.

.. code:: ipython3

    from skl2onnx.algebra.onnx_ops import OnnxMatMul, OnnxExp, OnnxAdd, OnnxDiv
    onnxExpM = OnnxExp(OnnxMatMul('X', beta32, op_version=12), op_version=12)
    cst = np.ones((1, 3), dtype=np.float32)
    onnxExpM1 = OnnxAdd(onnxExpM, cst, op_version=12)  # use of broadcasting
    onnxPred = OnnxDiv(onnxExpM, onnxExpM1, op_version=12)

Let’s convert it to ONNX format:

.. code:: ipython3

    inputs = {'X': X[:1].astype(np.float32)}
    model_onnx = onnxPred.to_onnx(inputs)
    print(onnx2str(model_onnx))

.. parsed-literal::

    ir_version: 4
    producer_name: "skl2onnx"
    producer_version: "1.7.0"
    domain: "ai.onnx"
    model_version: 0
    graph {
      node {
        input: "X"
        input: "Ma_MatMulcst"
        output: "Ma_Y0"
        name: "Ma_MatMul"
        op_type: "MatMul"
        domain: ""
      }
      node {
        ...

Let’s save it in a file.

.. code:: ipython3

    with open("model-1.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())

Let’s load it back into Python:

.. code:: ipython3

    import onnx
    model2 = onnx.load("model-1.onnx")

Let’s be more visual:

.. code:: ipython3

    onnx2dotnb(model_onnx, orientation='LR')
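The graph built above chains four operators (MatMul, Exp, Add, Div) on single-precision floats. The sketch below is not part of the original notebook: it replays the same chain in plain Python on one row and rounds every intermediate result to float32 with ``struct``, which roughly imitates what a single-float runtime does. The helper names (``to_float32``, ``predict_row``) and the small coefficient matrix are made up for illustration.

```python
import math
import struct


def to_float32(value):
    """Round a Python float (double precision) to the nearest float32."""
    return struct.unpack("f", struct.pack("f", value))[0]


def predict_row(row, beta, single=True):
    """Apply MatMul -> Exp -> Add(1) -> Div to one row of features.

    beta is a list of rows (n_features x n_outputs), a hypothetical
    stand-in for the random matrix used in the notebook.
    """
    cast = to_float32 if single else float
    n_outputs = len(beta[0])
    result = []
    for j in range(n_outputs):
        # MatMul: dot product of the row with column j of beta
        m = cast(sum(cast(x) * cast(b[j]) for x, b in zip(row, beta)))
        e = cast(math.exp(m))        # Exp
        e1 = cast(e + 1.0)           # Add, the constant 1 is broadcast
        result.append(cast(e / e1))  # Div
    return result


row = [5.1, 3.5, 1.4, 0.2]  # first observation of the iris dataset
beta = [[0.1, -0.2, 0.3], [0.4, 0.0, -0.1], [-0.3, 0.2, 0.1], [0.2, 0.5, -0.4]]

pred64 = predict_row(row, beta, single=False)
pred32 = predict_row(row, beta, single=True)
max_gap = max(abs(a - b) for a, b in zip(pred64, pred32))
print(pred64)
print(pred32)
print("largest discrepancy:", max_gap)
```

The two results agree to several decimals but are not expected to be identical, which is why predictions from a converted model are usually compared with a tolerance rather than with strict equality.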
Case of a simple linear regression: dedicated operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    from skl2onnx.algebra.onnx_ops import OnnxLinearRegressor
    onnxdocstring2html(OnnxLinearRegressor.__doc__)

Attributes

  • coefficients: Weights of the model(s).

  • intercepts: Weights of the intercepts, if used.

  • post_transform: Indicates the transform to apply to the regression output vector. One of 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', or 'PROBIT'. Default value is ``name: "post_transform" s: "NONE" type: STRING``

  • targets: The total number of regression targets, 1 if not defined. Default value is ``name: "targets" i: 1 type: INT``

Inputs

  • X (heterogeneous)T: Data to be regressed.

Outputs

  • Y (heterogeneous)tensor(float): Regression outputs (one per target, per example).

Type Constraints

  • T tensor(float), tensor(double), tensor(int64), tensor(int32): The input must be a tensor of a numeric type.
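
To make the attributes listed above concrete, here is a minimal pure-Python sketch of what a linear regressor operator computes when ``post_transform`` is ``NONE`` or ``LOGISTIC``. It is an illustration, not the ONNX reference code: the function name ``linear_regressor`` is made up, and the layout of ``coefficients`` (all weights of the first target, then all weights of the second one, and so on) is an assumption to check against the operator specification.

```python
import math


def linear_regressor(rows, coefficients, intercepts=None, targets=1,
                     post_transform="NONE"):
    """Return one list of `targets` scores per input row.

    Assumption: `coefficients` is a flat list laid out target by target,
    so its length must be targets * n_features.
    """
    if len(coefficients) % targets != 0:
        raise ValueError("len(coefficients) must be a multiple of targets")
    n_features = len(coefficients) // targets
    if intercepts is None:
        intercepts = [0.0] * targets
    outputs = []
    for row in rows:
        if len(row) != n_features:
            raise ValueError(
                "expected %d features, got %d" % (n_features, len(row)))
        scores = []
        for t in range(targets):
            weights = coefficients[t * n_features:(t + 1) * n_features]
            score = sum(w * x for w, x in zip(weights, row)) + intercepts[t]
            if post_transform == "LOGISTIC":
                score = 1.0 / (1.0 + math.exp(-score))
            elif post_transform != "NONE":
                raise NotImplementedError(post_transform)
            scores.append(score)
        outputs.append(scores)
    return outputs


# Two targets, two features: y0 = x0 + 2*x1 + 0.5 and y1 = -x0 + 1
print(linear_regressor([[1.0, 2.0]], [1.0, 2.0, -1.0, 0.0],
                       intercepts=[0.5, 1.0], targets=2))  # [[5.5, 0.0]]
```

Under that assumption the number of input features is ``len(coefficients) / targets``, which is worth comparing with the input shape whenever shape inference complains.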

.. code:: ipython3

    lin_reg = OnnxLinearRegressor('input', coefficients=list(beta.ravel().astype(np.float64)), targets=2)
    inputs = {'input': X[:1].astype(np.float32)}
    try:
        model_onnx = lin_reg.to_onnx(inputs)
    except Exception as e:
        print(str(e).split("\n")[0])

.. parsed-literal::

    Shape inference fails.

.. code:: ipython3

    from onnxconverter_common.data_types import FloatTensorType
    model_onnx = lin_reg.to_onnx(inputs, outputs=[('Yp', FloatTensorType((1, 1)))])
    onnx2dotnb(model_onnx, width="80%")
Conversion to ONNX
------------------

The goal is to describe the prediction function of a machine-learned model
with `Onnx Operators `__.

scikit-learn to ONNX: sklearn-onnx
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    clr = LogisticRegression(multi_class="auto", solver="liblinear").fit(X, y)
    clr

.. parsed-literal::

    LogisticRegression(solver='liblinear')

.. code:: ipython3

    from skl2onnx import to_onnx
    model_onnx = to_onnx(clr, X.astype(np.float32), target_opset=12)
    onnx2dotnb(model_onnx)
Conversion of a pipeline
~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    from sklearn.pipeline import Pipeline
    from sklearn.decomposition import PCA
    pipe = Pipeline([('pca', PCA(n_components=2)),
                     ('lr', LogisticRegression())])
    pipe.fit(X, y)

.. parsed-literal::

    Pipeline(steps=[('pca', PCA(n_components=2)), ('lr', LogisticRegression())])

.. code:: ipython3

    model_onnx = to_onnx(pipe, X.astype(np.float32), target_opset=12)
    onnx2dotnb(model_onnx, orientation="TB", width="30%")
Runtime
-------

Compute the predictions of a machine-learned model from its `ONNX `__
definition.

onnxruntime
~~~~~~~~~~~

There are `multiple runtimes `__. For CPU, GPU and ARM, `onnxruntime `__ is
one option.

.. code:: ipython3

    from onnxruntime import InferenceSession
    sess = InferenceSession(model_onnx.SerializeToString())
    label, proba = sess.run(None, {'X': X32})
    label[:3]

.. parsed-literal::

    array([0, 0, 0], dtype=int64)

.. code:: ipython3

    proba[:3]

.. parsed-literal::

    [{0: 0.9813900589942932, 1: 0.018609998747706413, 2: 7.118746925272035e-09},
     {0: 0.9760100245475769, 1: 0.023990022018551826, 2: 1.9314878585419137e-08},
     {0: 0.9847068190574646, 1: 0.015293179079890251, 2: 6.281324793633303e-09}]

.. code:: ipython3

    pipe.predict_proba(X32)[:3]

.. parsed-literal::

    array([[9.81390001e-01, 1.86099916e-02, 7.11872743e-09],
           [9.76009954e-01, 2.39900265e-02, 1.93148667e-08],
           [9.84706803e-01, 1.52931912e-02, 6.28132306e-09]])

.. code:: ipython3

    import pandas
    pandas.DataFrame(proba).head()
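==  ========  ========  ============
\   0         1         2
==  ========  ========  ============
0   0.981390  0.018610  7.118747e-09
1   0.976010  0.023990  1.931488e-08
2   0.984707  0.015293  6.281325e-09
3   0.975605  0.024395  2.240644e-08
4   0.983403  0.016597  5.354823e-09
==  ========  ========  ============
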
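
onnxruntime returns the probabilities above as a list of dictionaries (one per observation, keyed by class label), whereas scikit-learn returns a dense array. The following sketch, which is not in the original notebook, shows one dependency-free way to line the two up; ``proba_to_rows`` is a hypothetical helper and the values are the three dictionaries printed earlier.

```python
def proba_to_rows(proba, classes=None):
    """Turn [{label: p, ...}, ...] into a list of rows ordered by class label."""
    if classes is None:
        # Collect every label seen in the output and sort them.
        classes = sorted({label for obs in proba for label in obs})
    return [[obs.get(label, 0.0) for label in classes] for obs in proba]


proba = [
    {0: 0.9813900589942932, 1: 0.018609998747706413, 2: 7.118746925272035e-09},
    {0: 0.9760100245475769, 1: 0.023990022018551826, 2: 1.9314878585419137e-08},
    {0: 0.9847068190574646, 1: 0.015293179079890251, 2: 6.281324793633303e-09},
]
rows = proba_to_rows(proba)
# Predicted label = index of the largest probability in each row.
labels = [max(range(len(row)), key=row.__getitem__) for row in rows]
print(rows[0])
print(labels)
```

``pandas.DataFrame(proba)``, as used in the notebook, performs the same alignment and is the more convenient choice when pandas is available.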
ONNX.js
~~~~~~~

`API ONNX.js `__

The example is not working yet.

.. code:: ipython3

    clr = LogisticRegression(multi_class="auto", solver="liblinear").fit(X, y)
    model_onnx = to_onnx(clr, X[:1].astype(np.float32))

.. code:: ipython3

    model_js = ("var myarr = new Uint8Array([%s]);" %
                ",".join(map(str, model_onnx.SerializeToString())))
    model_js[:200] + " ... " + model_js[-20:]

.. parsed-literal::

    'var myarr = new Uint8Array([8,4,18,8,115,107,108,50,111,110,110,120,26,5,49,46,55,46,48,34,7,97,105,46,111,110,110,120,40,0,50,0,58,229,4,10,141,2,10,1,88,18,5,108,97,98,101,108,18,18,112,114,111,98,9 ... 0,46,109,108,16,1]);'

.. code:: ipython3

    script = """
    %s

    var myOnnxSession = new onnx.InferenceSession({ backendHint: 'webgl' });
    var inferenceInputs = [
        new onnx.Tensor(new Float32Array([5.1, 3.5, 1.4, 0.2]), "float32", [1, 4])
    ];
    myOnnxSession.loadModel(myarr).then(() => {
        myOnnxSession.run(inferenceInputs).then(output => {
            const outputTensor = output.values().next().value;
            document.getElementById("__ID__").innerHTML = String(outputTensor);
        }).catch(function(err) {
            document.getElementById("__ID__").innerHTML = err.message;
        });
    }).catch(function(err) {
        document.getElementById("__ID__").innerHTML = err.message;
    });
    """ % model_js

.. code:: ipython3

    from jyquickhelper import RenderJS
    jr = RenderJS(script,
                  libs=[dict(path="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js",
                             name="onnx", exports="onnx")])
    jr

Benchmark
---------

LogisticRegression
~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    clr = LogisticRegression(multi_class="auto", solver="liblinear").fit(X, y)

.. code:: ipython3

    %timeit clr.predict_proba(X[:1])

.. parsed-literal::

    80.9 µs ± 5.95 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

.. code:: ipython3

    sess = InferenceSession(model_onnx.SerializeToString())
    X32 = X.astype(np.float32)
    %timeit sess.run(None, {'X': X32[:1]})

.. parsed-literal::

    22 µs ± 4.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

RandomForestClassifier
~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    clr = RandomForestClassifier(n_estimators=10).fit(X, y)

.. code:: ipython3

    %timeit clr.predict_proba(X[:1])

.. parsed-literal::

    890 µs ± 63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

.. code:: ipython3

    # NOTE: model_onnx is not reconverted here, so it still holds the
    # LogisticRegression converted in the ONNX.js section; convert clr
    # with to_onnx first to time the random forest itself.
    sess = InferenceSession(model_onnx.SerializeToString())
    X32 = X.astype(np.float32)
    %timeit sess.run(None, {'X': X32[:1]})

.. parsed-literal::

    19.5 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Transfer Learning
-----------------

Insert an ONNX model into a pipeline.

OnnxTransformer
~~~~~~~~~~~~~~~

.. code:: ipython3

    pipe = Pipeline([('pca', PCA(n_components=2)),
                     ('lr', LogisticRegression(multi_class="auto"))])
    pipe.fit(X, y)
    model_onnx = to_onnx(pipe, X[:1].astype(np.float32))

.. code:: ipython3

    from mlprodict.sklapi import OnnxTransformer
    tr = OnnxTransformer(model_onnx.SerializeToString(), output_name="output_probability")
    tr.fit()
    tr.transform(X)[:5]
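==  ========  ========  ============
\   0         1         2
==  ========  ========  ============
0   0.981390  0.018610  7.118747e-09
1   0.976010  0.023990  1.931488e-08
2   0.984707  0.015293  6.281325e-09
3   0.975605  0.024395  2.240644e-08
4   0.983403  0.016597  5.354823e-09
==  ========  ========  ============
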
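
The idea behind the cell above is that anything exposing ``fit`` and ``transform`` can act as a pipeline step, even when the underlying model is already trained. The sketch below illustrates that contract in plain Python; it is not mlprodict's implementation, and ``FrozenModelTransformer`` as well as the toy ``toy_scores`` function are invented for the example.

```python
class FrozenModelTransformer:
    """Wrap an already trained prediction function as a transformer."""

    def __init__(self, predict_fn):
        self.predict_fn = predict_fn

    def fit(self, X=None, y=None):
        # Nothing to learn: the wrapped model is already trained.
        return self

    def transform(self, X):
        # Features passed to the next step are the wrapped model's outputs.
        return [self.predict_fn(row) for row in X]


def toy_scores(row):
    """Hypothetical stand-in for a converted model returning two scores."""
    total = sum(row)
    return [total, -total]


tr_sketch = FrozenModelTransformer(toy_scores).fit()
print(tr_sketch.transform([[1.0, 2.0], [0.5, 0.5]]))  # [[3.0, -3.0], [1.0, -1.0]]
```

The outputs of the frozen model become the input features of whatever estimator follows it, which is the essence of transfer learning in this setting.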
Within a pipeline
~~~~~~~~~~~~~~~~~

.. code:: ipython3

    pipe = Pipeline([('onnx', OnnxTransformer(model_onnx.SerializeToString(), output_name="output_probability")),
                     ('lr', LogisticRegression(multi_class="auto"))])
    pipe.fit(X, y)
    dot = pipeline2dot(pipe, X)
    from jyquickhelper import RenderJsDot
    RenderJsDot(dot)
Appendix
--------

Profile RandomForestClassifier
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    from sklearn.ensemble import RandomForestClassifier
    clr = RandomForestClassifier(n_estimators=2).fit(X, y)
    profile_fct_graph(lambda: [clr.predict(X) for i in range(0, 1000)],
                      nb=30, figsize=(15, 3),
                      title="Cumulated time inside functions when predicting\nRandomForestClassifier");

.. image:: onnx_sklearn_consortium_74_0.png

Open source tools in this talk
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

    import onnx, skl2onnx, sklearn, onnxruntime, mlprodict
    mods = [onnx, skl2onnx, onnxruntime, sklearn, mlprodict]
    for m in mods:
        print(m.__name__, m.__version__)

.. parsed-literal::

    onnx 1.7.105
    skl2onnx 1.7.0
    onnxruntime 1.3.993
    sklearn 0.24.dev0
    mlprodict 0.3.1134
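
Importing every package just to read its version fails as soon as one of them is missing. As an alternative sketch (not from the original notebook), the standard library's ``importlib.metadata`` can query installed distributions by name; ``installed_versions`` is a hypothetical helper, and note that the distribution providing ``sklearn`` is published as ``scikit-learn``.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(distributions):
    """Map each distribution name to its version, or None if not installed."""
    found = {}
    for name in distributions:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


names = ["onnx", "skl2onnx", "onnxruntime", "scikit-learn", "mlprodict"]
for name, ver in installed_versions(names).items():
    print(name, ver if ver is not None else "not installed")
```

Recording these versions next to a saved model helps explain later differences in conversion or prediction results.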