Choose appropriate output of a classifier#

A scikit-learn classifier usually returns a matrix of probabilities. By default, sklearn-onnx converts that matrix into a list of dictionaries where each probability is mapped to its class id or name. That mechanism retains the class names but is slower. Let’s see what other options are available.

Train a model and convert it#

from timeit import repeat
import numpy
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import onnx
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import to_onnx
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Load iris; the features are cast to float32, the type used for the ONNX
# input below.
iris = load_iris()
X = iris.data.astype(numpy.float32)
# Remap the targets 0, 1, 2 to 10, 12, 14 so that class labels differ from
# class indices, which makes it visible what each ONNX output exposes.
y = iris.target * 2 + 10
X_train, X_test, y_train, y_test = train_test_split(X, y)

clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)

# Default conversion (no options): probabilities are returned through the
# ZipMap operator, i.e. as a list of dictionaries keyed by class label.
onx = to_onnx(clr, X_train,
              target_opset={'': 14, 'ai.onnx.ml': 2})
LogisticRegression(max_iter=500)

Default behaviour: zipmap=True#

The output type for the probabilities is a list of dictionaries.

# Run the default model. The second output holds the probabilities, which
# ZipMap turns into one dictionary per observation.
sess = rt.InferenceSession(
    onx.SerializeToString(), providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test})
zipped_probas = res[1]
print(zipped_probas[:2])
print("probabilities type:", type(zipped_probas))
print("type for the first observations:", type(zipped_probas[0]))
[{10: 1.1451697901065927e-05, 12: 0.020342448726296425, 14: 0.9796461462974548}, {10: 0.01027061976492405, 12: 0.7854514718055725, 14: 0.20427796244621277}]
probabilities type: <class 'list'>
type for the first observations: <class 'dict'>

Option zipmap=False#

Probabilities are now a matrix.

# Convert again with ZipMap disabled for this model: the probabilities output
# becomes a plain matrix with one column per class instead of a list of
# dictionaries. The input type is inferred from X_train, so the ONNX input is
# fed under the name 'X' (no explicit initial_types is needed).
options = {id(clr): {'zipmap': False}}
onx2 = to_onnx(clr, X_train, options=options,
               target_opset={'': 14, 'ai.onnx.ml': 2})

sess2 = rt.InferenceSession(onx2.SerializeToString(),
                            providers=['CPUExecutionProvider'])
res2 = sess2.run(None, {'X': X_test})
print(res2[1][:2])
print("probabilities type:", type(res2[1]))
print("type for the first observations:", type(res2[1][0]))
[[1.1451698e-05 2.0342449e-02 9.7964615e-01]
 [1.0270620e-02 7.8545147e-01 2.0427796e-01]]
probabilities type: <class 'numpy.ndarray'>
type for the first observations: <class 'numpy.ndarray'>

Option zipmap=’columns’#

This option removes the final ZipMap operator and splits the probabilities into columns. The final model produces one output for the label, and one output per class.

# zipmap='columns': ZipMap is dropped and the probability matrix is split
# into one output per class, in addition to the label output.
options = {id(clr): {'zipmap': 'columns'}}
onx3 = to_onnx(clr, X_train, options=options,
               target_opset={'': 14, 'ai.onnx.ml': 2})

sess3 = rt.InferenceSession(
    onx3.SerializeToString(), providers=['CPUExecutionProvider'])
res3 = sess3.run(None, {'X': X_test})
# Pair each output description with its computed value.
for out, values in zip(sess3.get_outputs(), res3):
    print(
        f"output: '{out.name}' shape={values.shape} values={values[:2]}...")
output: 'output_label' shape=(38,) values=[14 12]...
output: 'i10' shape=(38,) values=[1.1451698e-05 1.0270620e-02]...
output: 'i12' shape=(38,) values=[0.02034245 0.7854515 ]...
output: 'i14' shape=(38,) values=[0.97964615 0.20427796]...

Let’s compare prediction time#

# Each printed figure is the total duration of 100 consecutive calls to
# ``run``, averaged over 10 repetitions (it is not a per-call time).
for _label, _session in (
        ("Average time with ZipMap:", sess),
        ("Average time without ZipMap:", sess2),
        ("Average time without ZipMap but with columns:", sess3)):
    print(_label)
    # The lambda is consumed by repeat() within the same iteration, so the
    # late binding of _session is harmless here.
    _timings = repeat(lambda: _session.run(None, {'X': X_test}),
                      number=100, repeat=10)
    print(sum(_timings) / 10)

# In this example, prediction is much faster without ZipMap.
# The gain is even larger when classes are strings rather than
# integers, because building the final list of dictionaries
# copies the same labels many times with onnxruntime.
Average time with ZipMap:
0.011111683095805347
Average time without ZipMap:
0.005920915305614472
Average time without ZipMap but with columns:
0.00971748341107741

Option zipmap=False and output_class_labels=True#

Option zipmap=False seems the better choice because it is much faster, but the class labels are lost in the process. Option output_class_labels can be used to expose the class labels as a third output.

# ZipMap disabled and output_class_labels enabled: the probabilities stay a
# plain matrix, and the model gains a third output listing the class labels
# (here [10 12 14]), so the labels lost by zipmap=False are recovered.
# The input type is inferred from X_train; the ONNX input is fed as 'X'.
options = {id(clr): {'zipmap': False, 'output_class_labels': True}}
onx4 = to_onnx(clr, X_train, options=options,
               target_opset={'': 14, 'ai.onnx.ml': 2})

sess4 = rt.InferenceSession(onx4.SerializeToString(),
                            providers=['CPUExecutionProvider'])
res4 = sess4.run(None, {'X': X_test})
print(res4[1][:2])
print("probabilities type:", type(res4[1]))
print("class labels:", res4[2])
[[1.1451698e-05 2.0342449e-02 9.7964615e-01]
 [1.0270620e-02 7.8545147e-01 2.0427796e-01]]
probabilities type: <class 'numpy.ndarray'>
class labels: [10 12 14]

Processing time.

# Same measurement as above: total time of 100 calls, averaged over 10 runs.
print("Average time without ZipMap but with output_class_labels:")
_timings_cl = repeat(lambda: sess4.run(None, {'X': X_test}),
                     number=100, repeat=10)
print(sum(_timings_cl) / 10)
Average time without ZipMap but with output_class_labels:
0.006411535281222314

MultiOutputClassifier#

This model is equivalent to several classifiers, one for every label to predict. Instead of returning a matrix of probabilities, it returns a sequence of matrices. Let’s first modify the labels to get a problem for a MultiOutputClassifier.

# Build a two-column target: the original labels, and the same labels
# shifted by 100.
y = numpy.column_stack([y, y + 100])
# Relabel every fifth sample in the second column as 1000 so that this
# column has a fourth class.
y[::5, 1] = 1000
print(y[:5])
[[  10 1000]
 [  10  110]
 [  10  110]
 [  10  110]
 [  10  110]]

Let’s train a MultiOutputClassifier.

X_train, X_test, y_train, y_test = train_test_split(X, y)
# One LogisticRegression per target column.
clr = MultiOutputClassifier(LogisticRegression(max_iter=500))
clr.fit(X_train, y_train)
print(clr)

# Default conversion. ZipMap does not apply to this model, so the converter
# emits a UserWarning saying the option is ignored.
onx5 = to_onnx(clr, X_train,
               target_opset={'': 14, 'ai.onnx.ml': 2})
sess5 = rt.InferenceSession(
    onx5.SerializeToString(), providers=['CPUExecutionProvider'])
first_rows = X_test[:3]
res5 = sess5.run(None, {'X': first_rows})
print(res5)
MultiOutputClassifier(estimator=LogisticRegression(max_iter=500))
somewhere/workspace/onnxcustom/onnxcustom_UT_39_std/_venv/lib/python3.9/site-packages/skl2onnx/_parse.py:529: UserWarning: Option zipmap is ignored for model <class 'sklearn.multioutput.MultiOutputClassifier'>. Set option zipmap to False to remove this message.
  warnings.warn(
[array([[ 14, 114],
       [ 14, 114],
       [ 14, 114]], dtype=int64), [array([[4.3274203e-04, 1.8162604e-01, 8.1794125e-01],
       [7.1996351e-04, 3.5074779e-01, 6.4853227e-01],
       [2.2366168e-05, 4.4211719e-02, 9.5576590e-01]], dtype=float32), array([[1.4372485e-03, 2.3755349e-01, 6.5266651e-01, 1.0834280e-01],
       [1.7852996e-03, 3.8362366e-01, 4.7284558e-01, 1.4174546e-01],
       [2.8334750e-04, 1.5673934e-01, 7.3674095e-01, 1.0623635e-01]],
      dtype=float32)]]

Option zipmap is ignored for this model, as the warning above says. The class labels are missing, but they can be added back as a third output.

# Options given without a model id apply globally: ZipMap off, plus a third
# output exposing the class labels of every sub-classifier.
onx6 = to_onnx(clr, X_train, target_opset={'': 14, 'ai.onnx.ml': 2},
               options={'zipmap': False, 'output_class_labels': True})

sess6 = rt.InferenceSession(
    onx6.SerializeToString(), providers=['CPUExecutionProvider'])
res6 = sess6.run(None, {'X': X_test[:3]})
predicted, probabilities, class_labels = res6[0], res6[1], res6[2]
print("predicted labels", predicted)
print("predicted probabilies", probabilities)
print("class labels", class_labels)
predicted labels [[ 14 114]
 [ 14 114]
 [ 14 114]]
predicted probabilies [array([[4.3274203e-04, 1.8162604e-01, 8.1794125e-01],
       [7.1996351e-04, 3.5074779e-01, 6.4853227e-01],
       [2.2366168e-05, 4.4211719e-02, 9.5576590e-01]], dtype=float32), array([[1.4372485e-03, 2.3755349e-01, 6.5266651e-01, 1.0834280e-01],
       [1.7852996e-03, 3.8362366e-01, 4.7284558e-01, 1.4174546e-01],
       [2.8334750e-04, 1.5673934e-01, 7.3674095e-01, 1.0623635e-01]],
      dtype=float32)]
class labels [array([10, 12, 14], dtype=int64), array([ 110,  112,  114, 1000], dtype=int64)]

Versions used for this example

# Report the version of every package this example depends on.
for _label, _module in [("numpy:", numpy),
                        ("scikit-learn:", sklearn),
                        ("onnx: ", onnx),
                        ("onnxruntime: ", rt),
                        ("skl2onnx: ", skl2onnx)]:
    print(_label, _module.__version__)
numpy: 1.24.1
scikit-learn: 1.2.0
onnx:  1.13.0
onnxruntime:  1.14.92+cpu
skl2onnx:  1.13.1

Total running time of the script: ( 0 minutes 0.926 seconds)

Gallery generated by Sphinx-Gallery