Coverage for mlprodict/onnx_conv/scorers/register.py: 100%

50 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Registers new converters. 

5""" 

6import copy 

7from sklearn.base import BaseEstimator, TransformerMixin 

8from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version 

9from skl2onnx import ( 

10 update_registered_converter, 

11 update_registered_parser) 

12from skl2onnx.common.data_types import guess_tensor_type 

13from skl2onnx.common._apply_operation import apply_identity 

14 

15 

class CustomScorerTransform(BaseEstimator, TransformerMixin):
    """
    Wraps a scoring function into a transformer. Function @see fn
    register_scorers must be called to register the converter
    associated to this transform. It takes two inputs, expected values
    and predicted values and returns a score for each observation.
    """

    def __init__(self, name, fct, kwargs):
        """
        @param      name        function name, stored as ``name_fct``
        @param      fct         python function, stored as ``_fct``
        @param      kwargs      parameters function, stored as ``kwargs``
        """
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        self.name_fct = name
        self._fct = fct
        self.kwargs = kwargs

    def __repr__(self):  # pylint: disable=W0222
        """
        Returns ``ClassName('<name_fct>', <function name>, <kwargs>)``.
        """
        # Not every callable exposes __name__ (functools.partial objects,
        # instances implementing __call__): fall back to their repr
        # instead of failing with AttributeError.
        fct_name = getattr(self._fct, '__name__', None)
        if fct_name is None:
            fct_name = repr(self._fct)
        return "{}('{}', {}, {})".format(
            self.__class__.__name__, self.name_fct,
            fct_name, self.kwargs)

40 

41 

def custom_scorer_transform_parser(scope, model, inputs, custom_parsers=None):
    """
    Parser for a @see cl CustomScorerTransform: declares the operator
    wrapping *model* in *scope*, binds the two inputs to it and
    declares a single output variable named ``scores``.

    :param scope: Scope object
    :param model: the object to convert, it cannot be a string
    :param inputs: A list of exactly two variables
    :param custom_parsers: must be None, any other value raises
        *NotImplementedError*
    :return: the list of outputs of the declared operator
        (it receives the ``scores`` variable)
    """
    if custom_parsers is not None:  # pragma: no cover
        raise NotImplementedError(
            "Case custom_parsers not empty is not implemented yet.")
    if isinstance(model, str):
        raise RuntimeError(  # pragma: no cover
            f"Parameter model must be an object not a string '{model}'.")
    n_inputs = len(inputs)
    if n_inputs != 2:
        raise RuntimeError(  # pragma: no cover
            f"Two inputs expected not {n_inputs}.")

    # The alias is the class name prefixed with 'Mlprodict'.
    declared = scope.declare_local_operator(
        'Mlprodict' + model.__class__.__name__, model)
    declared.inputs = inputs

    # The output type is guessed from the type of the first input.
    output_type = guess_tensor_type(inputs[0].type)
    declared.outputs.append(
        scope.declare_local_variable('scores', output_type))
    return declared.outputs

76 

77 

def custom_scorer_transform_shape_calculator(operator):
    """
    Computes the output shapes for a @see cl CustomScorerTransform.
    The output type is a copy of the type of the first input
    with shape ``[N, 1]`` where *N* is the first dimension of
    that input, or None if the input has no known dimension.

    @param      operator    operator with two inputs and one output
    @raise      RuntimeError if the number of inputs or outputs
                            is unexpected
    """
    if len(operator.inputs) != 2:
        raise RuntimeError("Two inputs expected.")  # pragma: no cover
    if len(operator.outputs) != 1:
        raise RuntimeError("One output expected.")  # pragma: no cover

    input_type = operator.inputs[0].type
    # An empty or missing shape means the number of rows is unknown:
    # keep the first dimension undefined rather than failing on shape[0].
    input_shape = input_type.shape
    if input_shape is not None and len(input_shape) > 0:
        n_rows = input_shape[0]
    else:
        n_rows = None

    # Deep copy so that changing the output shape leaves the input type as is.
    operator.outputs[0].type = copy.deepcopy(input_type)
    operator.outputs[0].type.shape = [n_rows, 1]

90 

91 

def custom_scorer_transform_converter(scope, operator, container):
    """
    Converter for a @see cl CustomScorerTransform. It delegates
    the conversion to the operator registered under alias
    ``'fct_' + name_fct``: a local operator with that alias is declared
    with the same raw operator and the same inputs, its output is
    a new variable ``scores`` which is then linked to the output
    of *operator* through *apply_identity*.

    @param      scope       scope used to declare the operator
                            and the variable
    @param      operator    operator being converted
    @param      container   container given to *apply_identity*
    """
    scorer = operator.raw_operator
    delegate = scope.declare_local_operator('fct_' + scorer.name_fct)
    delegate.raw_operator = scorer
    delegate.inputs = operator.inputs

    # The intermediate result has the same type as the first input.
    scores = scope.declare_local_variable(
        'scores', operator.inputs[0].type)
    delegate.outputs.append(scores)

    apply_identity(
        scope, scores.full_name,
        operator.outputs[0].full_name, container)

107 

108 

def empty_shape_calculator(operator):
    """
    Shape calculator which leaves *operator* untouched
    and returns None.

    @param      operator    ignored
    """
    return None

114 

115 

def register_scorers():
    """
    Registers operators for @see cl CustomScorerTransform:
    its parser, its converter under alias
    ``'MlprodictCustomScorerTransform'`` and the converter of
    function *score_cdist_sum* under alias ``'fct_score_cdist_sum'``.

    @return         list of registered classes, it only contains
                    @see cl CustomScorerTransform
    """
    # Imported here and not at module level, as in the original code.
    from .cdist_score import score_cdist_sum, convert_score_cdist_sum

    update_registered_parser(
        CustomScorerTransform, custom_scorer_transform_parser)
    update_registered_converter(
        CustomScorerTransform, 'MlprodictCustomScorerTransform',
        custom_scorer_transform_shape_calculator,
        custom_scorer_transform_converter)
    registered = [CustomScorerTransform]

    # The scoring function has its own converter and accepts option 'cdist'.
    update_registered_converter(
        score_cdist_sum, 'fct_score_cdist_sum',
        empty_shape_calculator, convert_score_cdist_sum,
        options={'cdist': [None, 'single-node']})

    return registered