.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_fbegin_investigate.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_fbegin_investigate.py:


Intermediate results and investigation
======================================

.. index:: investigate, intermediate results

There are many reasons why a user may want more than the model converted
into ONNX. Intermediate results may be needed, that is, the output of every
node in the graph. The ONNX graph may need to be altered to remove some
nodes; transfer learning, for example, usually removes the last layers of
a deep neural network. Another reason is debugging. It often happens that
the runtime fails to compute the predictions because of a shape mismatch.
It is then useful to get the shape of every intermediate result.
This example looks into two ways of doing it.

.. contents::
    :local:

Look into pipeline steps
++++++++++++++++++++++++

The first way is a tricky one: it overloads the methods *transform*,
*predict* and *predict_proba* to keep a copy of their inputs and outputs.
It then goes through every step of the pipeline. If the pipeline has
*n* steps, it converts the pipeline with step 1, then the pipeline with
steps 1, 2, then with steps 1, 2, 3, and so on.

.. GENERATED FROM PYTHON SOURCE LINES 31-43

.. code-block:: default


    from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
    from mlprodict.onnxrt import OnnxInference
    import numpy
    from onnxruntime import InferenceSession
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.cluster import KMeans
    from sklearn.datasets import load_iris
    from skl2onnx import to_onnx
    from skl2onnx.helpers import collect_intermediate_steps
    from skl2onnx.common.data_types import FloatTensorType

.. GENERATED FROM PYTHON SOURCE LINES 44-45

The pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 45-55

.. code-block:: default


    data = load_iris()
    X = data.data
    pipe = Pipeline(steps=[
        ('std', StandardScaler()),
        ('km', KMeans(3, n_init=3))
    ])
    pipe.fit(X)

.. raw:: html

    <pre>Pipeline(steps=[('std', StandardScaler()),
                    ('km', KMeans(n_clusters=3, n_init=3))])</pre>


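The overloading trick can be sketched without ONNX. The snippet below is a
minimal illustration under stated assumptions, not the code of
``collect_intermediate_steps``. The helper ``record_transform_calls`` and its
cache layout are hypothetical, and every step of the pipeline is assumed to
expose *transform*. The helper shadows *transform* on each fitted step with a
wrapper that stores the input received and the output returned.

.. code-block:: python

    import numpy
    from sklearn.cluster import KMeans
    from sklearn.datasets import load_iris
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler


    def record_transform_calls(pipeline):
        """Hypothetical helper: shadow ``transform`` on every step so each
        call stores the input it received and the output it returned.

        The cache layout ``cache[step_name] = (inputs, outputs)`` is an
        assumption of this sketch, not the structure used by skl2onnx.
        """
        cache = {}
        for name, step in pipeline.steps:
            original = step.transform  # bound method of the fitted step

            # Default arguments freeze ``name`` and ``original`` for this step.
            def wrapped(X, *args, _name=name, _original=original, **kwargs):
                result = _original(X, *args, **kwargs)
                cache[_name] = (X, result)  # references, not copies
                return result

            # An instance attribute shadows the class method for this object only.
            step.transform = wrapped
        return cache


    X, _ = load_iris(return_X_y=True)
    pipe = Pipeline(steps=[
        ('std', StandardScaler()),
        ('km', KMeans(3, n_init=3, random_state=0)),
    ]).fit(X)

    cache = record_transform_calls(pipe)
    final = pipe.transform(X)  # fills the cache, one entry per step

    std_in, std_out = cache['std']
    km_in, km_out = cache['km']
    # The scaler output is exactly what KMeans received.
    print(numpy.array_equal(std_out, km_in), km_out.shape)

Storing references is enough here because ``StandardScaler`` (with its default
``copy=True``) and ``KMeans`` return new arrays. A step that modified its input
in place would require explicit copies.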
.. GENERATED FROM PYTHON SOURCE LINES 56-59

The function goes through every step, overloads the *transform* methods
and returns an ONNX graph for every step.

.. GENERATED FROM PYTHON SOURCE LINES 59-64

.. code-block:: default


    steps = collect_intermediate_steps(
        pipe, "pipeline",
        [("X", FloatTensorType([None, X.shape[1]]))],
        target_opset=17)

.. GENERATED FROM PYTHON SOURCE LINES 65-67

We call method *transform* to populate the cache kept by the
overloaded *transform* methods.

.. GENERATED FROM PYTHON SOURCE LINES 67-69

.. code-block:: default

    pipe.transform(X)



.. rst-class:: sphx-glr-script-out

.. code-block:: none


    array([[0.21295824, 3.93798775, 3.10754476],
           [0.99604549, 3.95382074, 2.63656042],
           [0.65198444, 4.13117447, 2.94387269],
           [0.9034561 , 4.13271765, 2.84486296],
           [0.40215457, 4.06144609, 3.28962799],
           [1.21154793, 3.86223792, 3.51313646],
           [0.50244932, 4.15677572, 3.12819036],
           [0.09132468, 3.91801321, 2.97331633],
           [1.42174651, 4.3382174 , 2.8801353 ],
           [0.78993078, 3.99583006, 2.75905459],
           [0.78999385, 3.87890998, 3.31837573],
           [0.27618123, 4.02120307, 3.02948398],
           [1.03497888, 4.10011155, 2.76595612],
           [1.33482453, 4.56305254, 3.17050153],
           [1.63865558, 4.11847936, 3.89656355],
           [2.39898792, 4.47638882, 4.51935428],
           [1.20748818, 3.99261962, 3.61560711],
           [0.21618828, 3.86198022, 3.04390949],
           [1.20986655, 3.69134594, 3.35084273],
           [0.86706182, 4.05714483, 3.4998904 ],
           [0.50401564, 3.6138549 , 2.79419741],
           [0.66826437, 3.89803055, 3.27565102],
           [0.68658071, 4.45629399, 3.57494522],
           [0.47945627, 3.52156831, 2.54264558],
           [0.36345425, 3.93070404, 2.94590654],
           [0.99023912, 3.82090185, 2.51845424],
           [0.22683089, 3.73456384, 2.81300086],
           [0.2947186 , 3.84523973, 3.04731246],
           [0.25361098, 3.82824961, 2.93723696],
           [0.65019824, 4.03594026, 2.84760069],
           [0.80138328, 3.95437119, 2.69811903],
           [0.52309257, 3.5225036 , 2.72242829],
           [1.57658655, 4.46954142, 4.12376611],
           [1.87652483, 4.41909369, 4.2394859 ],
           [0.76858489, 3.91652237, 2.68072354],
           [0.54896332, 3.96064006, 2.83659725],
           [0.63079314, 3.75488844, 3.04726953],
           [0.45982568, 4.20049965, 3.38988959],
           [1.2336976 , 4.35265263, 2.96684876],
           [0.14580827, 3.85471765, 2.93706958],
           [0.20261743, 3.95886277, 3.10997274],
           [2.67055552, 4.57078291, 2.83602543],
           [0.90927099, 4.35694733, 3.1224242 ],
           [0.50081008, 3.65022755, 2.85446137],
           [0.92159916, 3.86390737, 3.3486771 ],
           [1.01946042, 3.9484868 , 2.61399316],
           [0.86953764, 4.09883868, 3.53005539],
           [0.72275914, 4.1725665 , 2.96744987],
           [0.72324305, 3.93062743, 3.33732825],
           [0.30295342, 3.91489912, 2.87886068],
           [3.43619989, 0.95750736, 1.95387566],
           [2.97232682, 0.9466811 , 1.47821452],
           [3.51850037, 0.71871864, 1.74151636],
           [3.33264308, 2.62604617, 0.90594068],
           [3.35747592, 1.0514057 , 0.96210827],
           [2.77550662, 1.73850861, 0.38001323],
           [3.01808184, 0.96451982, 1.65670177],
           [2.77360088, 3.23578838, 1.43712469],
           [3.21148368, 1.09377472, 1.1603633 ],
           [2.66294828, 2.35055888, 0.751242  ],
           [3.62389817, 3.66373389, 1.90332417],
           [2.70011145, 1.38542327, 0.83030561],
           [3.53658932, 2.681228  , 1.20130773],
           [2.98813829, 1.21797039, 0.72085649],
           [2.32311723, 1.9782459 , 0.74295345],
           [3.14311522, 0.9551773 , 1.51813446],
           [2.68234835, 1.60094448, 0.84854221],
           [2.63954211, 2.05708912, 0.52775884],
           [3.97369206, 2.279978  , 1.17560102],
           [2.87494798, 2.39366852, 0.60474619],
           [3.03853641, 1.12271439, 1.40894213],
           [2.8022861 , 1.56352928, 0.55727595],
           [3.68305664, 1.59156824, 0.80815932],
           [2.96833851, 1.47908056, 0.60504529],
           [2.9760862 , 1.23493556, 0.94838793],
           [3.13002382, 0.98428558, 1.2761158 ],
           [3.56679427, 1.05707167, 1.31005896],
           [3.5903606 , 0.47318724, 1.48525791],
           [2.93839428, 1.27132776, 0.65520603],
           [2.58203512, 2.37083163, 0.70326188],
           [2.99796537, 2.61709488, 0.82409272],
           [2.92597852, 2.69730785, 0.91137003],
           [2.68907313, 1.95551469, 0.35629162],
           [3.42215998, 1.34739865, 0.60036521],
           [2.62771445, 1.81182563, 0.94032904],
           [2.75915071, 1.35365664, 1.76682038],
           [3.30075052, 0.74728786, 1.54027482],
           [3.73017167, 2.15189351, 1.03483069],
           [2.37943811, 1.80338336, 0.84423689],
           [2.98789866, 2.33447813, 0.53323013],
           [2.89079656, 2.19216038, 0.43280167],
           [2.86642713, 1.18139498, 0.90190562],
           [2.86642575, 2.04560794, 0.32536203],
           [2.96966239, 3.27492589, 1.46387138],
           [2.77003779, 1.97639214, 0.27418362],
           [2.38255534, 1.76008279, 0.8415383 ],
           [2.55559903, 1.72969503, 0.59040569],
           [2.8455521 , 1.33226942, 0.77059821],
           [2.56987887, 2.98655062, 1.22724181],
           [2.64007308, 1.82288546, 0.38893008],
           [4.24274589, 1.05539208, 2.38646488],
           [3.57067982, 1.44261206, 0.82146892],
           [4.44150237, 0.58069397, 2.21282662],
           [3.69480186, 0.70205827, 1.24534544],
           [4.11613683, 0.49647223, 1.80759366],
           [5.03326801, 1.26891296, 2.84115224],
           [3.3503222 , 2.66587618, 1.2031793 ],
           [4.577021  , 0.96542298, 2.2981266 ],
           [4.363498  , 1.37097986, 1.54127294],
           [4.79334275, 1.54449866, 3.30225324],
           [3.62749566, 0.43950677, 1.79090042],
           [3.89360823, 0.97220888, 1.1628307 ],
           [4.1132966 , 0.27949769, 1.87366791],
           [3.82688169, 1.84740408, 0.97308603],
           [3.91538879, 1.43413837, 1.4534224 ],
           [3.89835633, 0.66092371, 1.98097779],
           [3.70128288, 0.39475286, 1.46658814],
           [5.18341242, 2.17187773, 3.92063363],
           [5.58136629, 1.86919068, 3.00612708],
           [4.02615768, 2.30984352, 1.15354861],
           [4.31907679, 0.57932392, 2.34636886],
           [3.4288432 , 1.5450542 , 0.96691967],
           [5.19031307, 1.52115666, 2.83120791],
           [3.64273089, 1.0897639 , 0.90651751],
           [4.00723617, 0.55112377, 2.22940092],
           [4.2637671 , 0.7361751 , 2.39962985],
           [3.45930032, 1.00725521, 0.86750338],
           [3.27575645, 0.87030875, 1.09591797],
           [4.05342943, 0.78591862, 1.45005755],
           [4.1585729 , 0.79296371, 2.09415614],
           [4.71100584, 1.10815582, 2.33553379],
           [5.12224641, 2.24518287, 3.95317717],
           [4.13401784, 0.82658273, 1.53936501],
           [3.39830644, 1.04816813, 0.85464265],
           [3.63719075, 1.53567185, 0.83082003],
           [5.08776655, 1.32694219, 2.92169281],
           [4.00416552, 1.06808926, 2.38371464],
           [3.58815834, 0.4499276 , 1.54100839],
           [3.19454679, 0.99768379, 1.04298328],
           [4.09907253, 0.28498177, 2.04279992],
           [4.28416057, 0.58561639, 2.15181579],
           [4.17402084, 0.53688622, 2.13603497],
           [3.57067982, 1.44261206, 0.82146892],
           [4.32128686, 0.58535117, 2.31981484],
           [4.3480018 , 0.87974172, 2.49921263],
           [4.1240495 , 0.50849909, 1.89629069],
           [3.97564407, 1.4623514 , 1.06655593],
           [3.7539635 , 0.36817752, 1.52091659],
           [3.7969924 , 1.06805486, 2.2485655 ],
           [3.25638099, 1.06004591, 1.07639373]])



.. GENERATED FROM PYTHON SOURCE LINES 70-72

We compute every step and compare the ONNX and scikit-learn outputs.

.. GENERATED FROM PYTHON SOURCE LINES 72-93

.. code-block:: default


    for step in steps:

        print('----------------------------')
        print(step['model'])
        onnx_step = step['onnx_step']
        sess = InferenceSession(onnx_step.SerializeToString(),
                                providers=['CPUExecutionProvider'])
        onnx_outputs = sess.run(None, {'X': X.astype(numpy.float32)})
        onnx_output = onnx_outputs[-1]
        # output cached by the overloaded method transform
        skl_outputs = step['model']._debug.outputs['transform']

        # comparison
        diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
        print("difference", diff)

    # That was the first way: dynamically overwrite
    # every method transform or predict in a scikit-learn
    # pipeline to capture the input and output of every step,
    # and compare them to the outputs produced by truncated ONNX
    # graphs built from the first one.
    #




.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ----------------------------
    StandardScaler()
    difference 4.799262827148709e-07
    ----------------------------
    KMeans(n_clusters=3, n_init=3)
    difference 4.332024853268002e-06




.. GENERATED FROM PYTHON SOURCE LINES 94-101

Python runtime to look into every node
++++++++++++++++++++++++++++++++++++++

The Python runtime makes it easy to look into every node of the ONNX
graph. It helps to find out where the computation fails, for example
because of NaN values or a dimension mismatch.

.. GENERATED FROM PYTHON SOURCE LINES 101-109

.. code-block:: default


    onx = to_onnx(pipe, X[:1].astype(numpy.float32))
    oinf = OnnxInference(onx)
    oinf.run({'X': X[:2].astype(numpy.float32)},
             verbose=1, fLOG=print)




.. rst-class:: sphx-glr-script-out

.. code-block:: none

    +ki='Ad_Addcst': (3,) (dtype=float32 min=1.0327403545379639 max=5.035177230834961)
    +ki='Ge_Gemmcst': (3, 4) (dtype=float32 min=-1.3049873113632202 max=1.0688906908035278)
    +ki='Mu_Mulcst': (1,) (dtype=float32 min=0.0 max=0.0)
    -- OnnxInference: run 8 nodes with 1 inputs
    Onnx-Scaler(X) -> variable (name='Scaler')
    +kr='variable': (2, 4) (dtype=float32 min=-1.340226411819458 max=1.0190045833587646)
    Onnx-ReduceSumSquare(variable) -> Re_reduced0 (name='Re_ReduceSumSquare')
    +kr='Re_reduced0': (2, 1) (dtype=float32 min=4.850505828857422 max=5.376197338104248)
    Onnx-Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0 (name='Mu_Mul')
    +kr='Mu_C0': (2, 1) (dtype=float32 min=0.0 max=0.0)
    Onnx-Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0 (name='Ge_Gemm')
    +kr='Ge_Y0': (2, 3) (dtype=float32 min=-10.366023063659668 max=7.692880630493164)
    Onnx-Add(Re_reduced0, Ge_Y0) -> Ad_C01 (name='Ad_Add')
    +kr='Ad_C01': (2, 3) (dtype=float32 min=-4.98982572555542 max=12.543386459350586)
    Onnx-Add(Ad_Addcst, Ad_C01) -> Ad_C0 (name='Ad_Add1')
    +kr='Ad_C0': (2, 3) (dtype=float32 min=0.045351505279541016 max=15.632697105407715)
    Onnx-Sqrt(Ad_C0) -> scores (name='Sq_Sqrt')
    +kr='scores': (2, 3) (dtype=float32 min=0.2129589319229126 max=3.9538204669952393)
    Onnx-ArgMin(Ad_C0) -> label (name='Ar_ArgMin')
    +kr='label': (2,) (dtype=int64 min=0 max=0)

    {'label': array([0, 0]), 'scores': array([[0.21295893, 3.9379876 , 3.107545  ],
           [0.99604493, 3.9538205 , 2.6365604 ]], dtype=float32)}



.. GENERATED FROM PYTHON SOURCE LINES 110-111

To print the values of the intermediate results as well, the verbosity
is increased to 3.

.. GENERATED FROM PYTHON SOURCE LINES 111-118

.. code-block:: default


    oinf.run({'X': X[:2].astype(numpy.float32)},
             verbose=3, fLOG=print)

    # This way is usually better if you need to investigate
    # issues within the code of the runtime for an operator.
    #




.. rst-class:: sphx-glr-script-out

.. code-block:: none

    +ki='Ad_Addcst': (3,) (dtype=float32 min=1.0327403545379639 max=5.035177230834961
    [5.035177  3.0893104 1.0327404]
    +ki='Ge_Gemmcst': (3, 4) (dtype=float32 min=-1.3049873113632202 max=1.0688906908035278
    [[-1.0145789   0.85326266 -1.3049873  -1.2548935 ]
     [ 1.0688907   0.05759433  0.9689332   1.0023146 ]
     [-0.07723422 -0.9306213   0.32313818  0.23727821]]
    +ki='Mu_Mulcst': (1,) (dtype=float32 min=0.0 max=0.0
    [0.]
    -kv='X' shape=(2, 4) dtype=float32 min=0.20000000298023224 max=5.099999904632568
    -- OnnxInference: run 8 nodes with 1 inputs
    Onnx-Scaler(X) -> variable (name='Scaler')
    +kr='variable': (2, 4) (dtype=float32 min=-1.340226411819458 max=1.0190045833587646)
    [[-0.9006812   1.0190046  -1.3402264  -1.3154442 ]
     [-1.1430167  -0.13197924 -1.3402264  -1.3154442 ]]
    Onnx-ReduceSumSquare(variable) -> Re_reduced0 (name='Re_ReduceSumSquare')
    +kr='Re_reduced0': (2, 1) (dtype=float32 min=4.850505828857422 max=5.376197338104248)
    [[5.3761973]
     [4.850506 ]]
    Onnx-Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0 (name='Mu_Mul')
    +kr='Mu_C0': (2, 1) (dtype=float32 min=0.0 max=0.0)
    [[0.]
     [0.]]
    Onnx-Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0 (name='Ge_Gemm')
    +kr='Ge_Y0': (2, 3) (dtype=float32 min=-10.366023063659668 max=7.692880630493164)
    [[-10.366023    7.042239    3.2478971]
     [ -8.893578    7.6928806   1.0682037]]
    Onnx-Add(Re_reduced0, Ge_Y0) -> Ad_C01 (name='Ad_Add')
    +kr='Ad_C01': (2, 3) (dtype=float32 min=-4.98982572555542 max=12.543386459350586)
    [[-4.9898257 12.418436   8.624094 ]
     [-4.0430717 12.543386   5.9187098]]
    Onnx-Add(Ad_Addcst, Ad_C01) -> Ad_C0 (name='Ad_Add1')
    +kr='Ad_C0': (2, 3) (dtype=float32 min=0.045351505279541016 max=15.632697105407715)
    [[ 0.04535151 15.507747    9.656835  ]
     [ 0.9921055  15.632697    6.9514503 ]]
    Onnx-Sqrt(Ad_C0) -> scores (name='Sq_Sqrt')
    +kr='scores': (2, 3) (dtype=float32 min=0.2129589319229126 max=3.9538204669952393)
    [[0.21295893 3.9379876  3.107545  ]
     [0.99604493 3.9538205  2.6365604 ]]
    Onnx-ArgMin(Ad_C0) -> label (name='Ar_ArgMin')
    +kr='label': (2,) (dtype=int64 min=0 max=0)
    [0 0]
    [VALIDATE] type
    [VALIDATE] mis={}

    {'label': array([0, 0]), 'scores': array([[0.21295893, 3.9379876 , 3.107545  ],
           [0.99604493, 3.9538205 , 2.6365604 ]], dtype=float32)}



.. GENERATED FROM PYTHON SOURCE LINES 119-121

Final graph
+++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 121-125

.. code-block:: default


    ax = plot_graphviz(oinf.to_dot())
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)



.. image-sg:: /gyexamples/images/sphx_glr_plot_fbegin_investigate_001.png
   :alt: plot fbegin investigate
   :srcset: /gyexamples/images/sphx_glr_plot_fbegin_investigate_001.png
   :class: sphx-glr-single-img





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.490 seconds)


.. _sphx_glr_download_gyexamples_plot_fbegin_investigate.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_fbegin_investigate.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_fbegin_investigate.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_