Note
Click here to download the full example code
A converter for a TransformedTargetRegressor
There is no easy way to convert a
`sklearn.preprocessing.FunctionTransformer`
or
a `sklearn.compose.TransformedTargetRegressor` to ONNX unless
the function is written in such a way that the conversion is implicit.
from typing import Any
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LinearRegression
from mlprodict.onnx_conv import to_onnx
from mlprodict import __max_supported_opset__ as TARGET_OPSET
from mlprodict.npy import onnxnumpy_default, NDArray
from mlprodict.onnxrt import OnnxInference
import mlprodict.npy.numpy_onnx_impl as npnx
TransformedTargetRegressor
@onnxnumpy_default
def onnx_log_1(x: NDArray[Any, np.float32]) -> NDArray[(None, None), np.float32]:
    """Return ``log(1 + x)`` computed with the ONNX-aware numpy API.

    The function is written with ``mlprodict.npy.numpy_onnx_impl`` so that
    it can be used both as a regular callable by scikit-learn (as ``func``
    of a ``TransformedTargetRegressor`` or a ``FunctionTransformer``) and
    be picked up when the model is converted with ``to_onnx``.

    :param x: float32 input array
    :return: float32 array holding ``log(1 + x)``
    """
    return npnx.log1p(x)
@onnxnumpy_default
def onnx_exp_1(x: NDArray[Any, np.float32]) -> NDArray[(None, None), np.float32]:
    """Return ``exp(x) - 1``, the inverse of ``log(1 + x)``.

    Written with ``mlprodict.npy.numpy_onnx_impl`` so that it can serve as
    the ``inverse_func`` of a ``TransformedTargetRegressor`` and still be
    convertible with ``to_onnx``.

    :param x: float32 input array
    :return: float32 array holding ``exp(x) - 1``
    """
    # The constant is spelled as float32 to match the annotated element type.
    return npnx.exp(x) - np.float32(1)
# Toy data: 18 consecutive values arranged as 6 rows of 3 float32 features;
# the target is the sum of each row.
x = np.arange(18).reshape((-1, 3)).astype(np.float32)
y = x.sum(axis=1)

# The target is transformed with onnx_log_1 before fitting the linear
# regression and mapped back with onnx_exp_1 at prediction time.
model = TransformedTargetRegressor(
    regressor=LinearRegression(),
    func=onnx_log_1,
    inverse_func=onnx_exp_1,
)
model.fit(x, y)

# Reference predictions from scikit-learn, to compare with the ONNX output.
expected = model.predict(x)
print(expected)
[ 5.3555384 9.108676 15.0781555 24.572792 39.67432 63.693733 ]
Conversion to ONNX
# Convert the fitted model to ONNX; x is handed over as the sample input.
onx = to_onnx(
    model,
    x,
    rewrite_ops=True,
    target_opset=TARGET_OPSET,
)

# Run the ONNX graph on the same data, fed under the input name 'X'.
oinf = OnnxInference(onx)
got = oinf.run({'X': x})
print(got)
{'variable': array([[ 5.3555384],
[ 9.108676 ],
[15.0781555],
[24.572792 ],
[39.67432 ],
[63.693733 ]], dtype=float32)}
FunctionTransformer
# Wrap the ONNX-aware log(1 + x) function into a scikit-learn transformer.
model = FunctionTransformer(func=onnx_log_1)
model.fit(x, y)

# Reference output from scikit-learn, to compare with the ONNX output.
expected = model.transform(x)
print(expected)
[[0. 0.6931472 1.0986123]
[1.3862944 1.609438 1.7917595]
[1.9459101 2.0794415 2.1972246]
[2.3025851 2.3978953 2.4849067]
[2.5649493 2.6390574 2.7080503]
[2.7725887 2.8332133 2.8903718]]
Conversion to ONNX
# Convert the fitted transformer to ONNX; x is handed over as the sample input.
onx = to_onnx(
    model,
    x,
    rewrite_ops=True,
    target_opset=TARGET_OPSET,
)

# Run the ONNX graph on the same data, fed under the input name 'X'.
oinf = OnnxInference(onx)
got = oinf.run({'X': x})
print(got)
{'variable': array([[0. , 0.6931472, 1.0986123],
[1.3862944, 1.609438 , 1.7917595],
[1.9459101, 2.0794415, 2.1972246],
[2.3025851, 2.3978953, 2.4849067],
[2.5649493, 2.6390574, 2.7080503],
[2.7725887, 2.8332133, 2.8903718]], dtype=float32)}
Total running time of the script: ( 0 minutes 3.255 seconds)