Coverage for mlprodict/onnx_conv/register_rewritten_converters.py: 100%
35 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Rewrites some of the converters implemented in
4:epkg:`sklearn-onnx`.
5"""
6from sklearn.compose import TransformedTargetRegressor
7from skl2onnx.common._registration import (
8 _converter_pool, _shape_calculator_pool)
9try:
10 from skl2onnx.common._registration import RegisteredConverter
11except ImportError: # pragma: no cover
12 # sklearn-onnx <= 1.6.0
13 RegisteredConverter = lambda fct, opts: fct
14from skl2onnx import update_registered_converter
15from .sklconv.tree_converters import (
16 new_convert_sklearn_decision_tree_classifier,
17 new_convert_sklearn_decision_tree_regressor,
18 new_convert_sklearn_gradient_boosting_classifier,
19 new_convert_sklearn_gradient_boosting_regressor,
20 new_convert_sklearn_random_forest_classifier,
21 new_convert_sklearn_random_forest_regressor)
22from .sklconv.svm_converters import (
23 new_convert_sklearn_svm_classifier,
24 new_convert_sklearn_svm_regressor)
25from .sklconv.function_transformer_converters import (
26 new_calculate_sklearn_function_transformer_output_shapes,
27 new_convert_sklearn_function_transformer)
28from .sklconv.transformed_target_regressor import (
29 transformer_target_regressor_shape_calculator,
30 transformer_target_regressor_converter)
# Converters of sklearn-onnx replaced by the implementations of this package.
# Each replacement is wrapped with the options the original converter allowed
# (``get_allowed_options()``), so the same conversion options stay accepted.
# An alias missing from sklearn-onnx's registry makes the
# ``_converter_pool[alias]`` lookup fail when this module is imported.
_overwritten_operators = {
    alias: RegisteredConverter(
        converter, _converter_pool[alias].get_allowed_options())
    for alias, converter in (
        # support vector machines
        ('SklearnOneClassSVM', new_convert_sklearn_svm_regressor),
        ('SklearnSVR', new_convert_sklearn_svm_regressor),
        ('SklearnSVC', new_convert_sklearn_svm_classifier),
        # single decision trees
        ('SklearnDecisionTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnDecisionTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # single extra trees, converted like decision trees
        ('SklearnExtraTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnExtraTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra-trees ensembles, converted like random forests
        ('SklearnExtraTreesRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnExtraTreesClassifier',
         new_convert_sklearn_random_forest_classifier),
        # function transformer
        ('SklearnFunctionTransformer',
         new_convert_sklearn_function_transformer),
        # gradient boosting
        ('SklearnGradientBoostingRegressor',
         new_convert_sklearn_gradient_boosting_regressor),
        ('SklearnGradientBoostingClassifier',
         new_convert_sklearn_gradient_boosting_classifier),
        # histogram gradient boosting, converted with the forest converters
        ('SklearnHistGradientBoostingRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnHistGradientBoostingClassifier',
         new_convert_sklearn_random_forest_classifier),
        # random forests
        ('SklearnRandomForestRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnRandomForestClassifier',
         new_convert_sklearn_random_forest_classifier),
    )
}

# Shape calculators that register_rewritten_operators installs by default,
# keyed by the same sklearn-onnx aliases as the converters.
_overwritten_shape_calculator = {
    "SklearnFunctionTransformer":
        new_calculate_sklearn_function_transformer_output_shapes,
}
def _check_registered(pool, names):
    """
    Raises KeyError for the first name in *names* that is not a key of *pool*.

    :param pool: sklearn-onnx registry (converters or shape calculators)
    :param names: iterable of aliases expected to be registered already
    """
    for name in names:
        if name not in pool:
            raise KeyError(  # pragma: no cover
                f"skl2onnx was not imported and '{name}' was not registered.")


def register_rewritten_operators(new_converters=None,
                                 new_shape_calculators=None):
    """
    Registers modified operators and returns the old values.

    :param new_converters: converters to rewrite or None
        to rewrite default ones
    :param new_shape_calculators: shape calculators to rewrite or
        None to rewrite default ones
    @return old converters, old shape calculators

    Both mappings are validated before either registry is modified, so a
    KeyError leaves ``_converter_pool`` and ``_shape_calculator_pool``
    unchanged. The returned dictionaries hold the previous entries under the
    same keys and can be passed back to this function to restore them.
    """
    if new_converters is None:
        new_converters = _overwritten_operators
    if new_shape_calculators is None:
        new_shape_calculators = _overwritten_shape_calculator

    # Validate everything first. Raising after one pool has been updated
    # would leave it modified without handing the old values to the caller.
    _check_registered(_converter_pool, new_converters)
    _check_registered(_shape_calculator_pool, new_shape_calculators)

    old_conv = {k: _converter_pool[k] for k in new_converters}
    old_shape = {k: _shape_calculator_pool[k] for k in new_shape_calculators}
    _converter_pool.update(new_converters)
    _shape_calculator_pool.update(new_shape_calculators)
    return old_conv, old_shape
def register_new_operators():
    """
    Registers new operators relying on pieces implemented in this package,
    such as the numpy API for ONNX.

    Currently this registers scikit-learn's *TransformedTargetRegressor*
    under the alias ``SklearnTransformedTargetRegressor``. It uses the shape
    calculator and converter from ``.sklconv.transformed_target_regressor``.
    Any previous registration is replaced (``overwrite=True``), and the call
    passes ``options=None``.
    """
    shape_calculator = transformer_target_regressor_shape_calculator
    converter = transformer_target_regressor_converter
    update_registered_converter(
        TransformedTargetRegressor,
        "SklearnTransformedTargetRegressor",
        shape_calculator,
        converter,
        overwrite=True,
        options=None,
    )