Coverage for mlprodict/onnx_conv/register_rewritten_converters.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Rewrites some of the converters implemented in 

4:epkg:`sklearn-onnx`. 

5""" 

6from sklearn.compose import TransformedTargetRegressor 

7from skl2onnx.common._registration import ( 

8 _converter_pool, _shape_calculator_pool) 

9try: 

10 from skl2onnx.common._registration import RegisteredConverter 

11except ImportError: # pragma: no cover 

12 # sklearn-onnx <= 1.6.0 

13 RegisteredConverter = lambda fct, opts: fct 

14from skl2onnx import update_registered_converter 

15from .sklconv.tree_converters import ( 

16 new_convert_sklearn_decision_tree_classifier, 

17 new_convert_sklearn_decision_tree_regressor, 

18 new_convert_sklearn_gradient_boosting_classifier, 

19 new_convert_sklearn_gradient_boosting_regressor, 

20 new_convert_sklearn_random_forest_classifier, 

21 new_convert_sklearn_random_forest_regressor) 

22from .sklconv.svm_converters import ( 

23 new_convert_sklearn_svm_classifier, 

24 new_convert_sklearn_svm_regressor) 

25from .sklconv.function_transformer_converters import ( 

26 new_calculate_sklearn_function_transformer_output_shapes, 

27 new_convert_sklearn_function_transformer) 

28from .sklconv.transformed_target_regressor import ( 

29 transformer_target_regressor_shape_calculator, 

30 transformer_target_regressor_converter) 

31 

32 

# Converters from sklearn-onnx replaced by the implementations of this
# package. Every entry keeps the options the original skl2onnx converter
# registered under the same alias allows.
_overwritten_operators = {
    alias: RegisteredConverter(
        converter, _converter_pool[alias].get_allowed_options())
    for alias, converter in (
        # support vector machines
        ('SklearnOneClassSVM', new_convert_sklearn_svm_regressor),
        ('SklearnSVR', new_convert_sklearn_svm_regressor),
        ('SklearnSVC', new_convert_sklearn_svm_classifier),
        # decision trees
        ('SklearnDecisionTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnDecisionTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # single extra trees
        ('SklearnExtraTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnExtraTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra trees ensembles
        ('SklearnExtraTreesRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnExtraTreesClassifier',
         new_convert_sklearn_random_forest_classifier),
        # function transformer
        ('SklearnFunctionTransformer',
         new_convert_sklearn_function_transformer),
        # gradient boosting
        ('SklearnGradientBoostingRegressor',
         new_convert_sklearn_gradient_boosting_regressor),
        ('SklearnGradientBoostingClassifier',
         new_convert_sklearn_gradient_boosting_classifier),
        # histogram gradient boosting, converted with the same code path
        # as random forests
        ('SklearnHistGradientBoostingRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnHistGradientBoostingClassifier',
         new_convert_sklearn_random_forest_classifier),
        # random forests
        ('SklearnRandomForestRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnRandomForestClassifier',
         new_convert_sklearn_random_forest_classifier),
    )
}

91 

# Shape calculators from sklearn-onnx replaced by the implementations of
# this package, keyed by the skl2onnx alias of the model.
_overwritten_shape_calculator = dict(
    SklearnFunctionTransformer=(
        new_calculate_sklearn_function_transformer_output_shapes),
)

96 

97 

def _check_registered(pool, names):
    """
    Raises KeyError if one of *names* is not a key of *pool*.

    :param pool: one of the skl2onnx registries
    :param names: aliases expected to be registered
    """
    for name in names:
        if name not in pool:
            raise KeyError(  # pragma: no cover
                f"skl2onnx was not imported and '{name}' was not registered.")


def register_rewritten_operators(new_converters=None,
                                 new_shape_calculators=None):
    """
    Registers modified operators and returns the old values.

    :param new_converters: converters to rewrite, given as a mapping
        from a skl2onnx alias to the new converter, or None to rewrite
        the default ones defined in this module
    :param new_shape_calculators: shape calculators to rewrite, given as
        a mapping from a skl2onnx alias to the new shape calculator, or
        None to rewrite the default ones defined in this module
    @return old converters, old shape calculators: two dictionaries with
        the previous entries for every replaced alias; passing them back
        to this function restores the previous registration
    :raises KeyError: if one alias is not already registered in
        skl2onnx; both registries are checked before either is modified
    """
    if new_converters is None:
        new_converters = _overwritten_operators
    if new_shape_calculators is None:
        new_shape_calculators = _overwritten_shape_calculator

    # Validate both registries before touching either one. Otherwise a
    # missing shape calculator would raise after the converters were
    # already replaced, and the caller would never get the old values
    # needed to undo that partial update.
    _check_registered(_converter_pool, new_converters)
    _check_registered(_shape_calculator_pool, new_shape_calculators)

    old_conv = {k: _converter_pool[k] for k in new_converters}
    old_shape = {k: _shape_calculator_pool[k] for k in new_shape_calculators}

    _converter_pool.update(new_converters)
    _shape_calculator_pool.update(new_shape_calculators)
    return old_conv, old_shape

144 

145 

def register_new_operators():
    """
    Registers converters for models handled by this package on top of
    :epkg:`sklearn-onnx`, relying on pieces implemented here such as the
    numpy API for ONNX.

    Only :class:`sklearn.compose.TransformedTargetRegressor` is registered,
    under the alias ``SklearnTransformedTargetRegressor``. Any previous
    registration for that model is overwritten (``overwrite=True``).
    """
    model_alias = "SklearnTransformedTargetRegressor"
    update_registered_converter(
        TransformedTargetRegressor,
        model_alias,
        transformer_target_regressor_shape_calculator,
        transformer_target_regressor_converter,
        options=None,
        overwrite=True,
    )