Coverage for mlprodict/onnx_conv/scorers/register.py: 100%
50 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Registers new converters.
5"""
6import copy
7from sklearn.base import BaseEstimator, TransformerMixin
8from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
9from skl2onnx import (
10 update_registered_converter,
11 update_registered_parser)
12from skl2onnx.common.data_types import guess_tensor_type
13from skl2onnx.common._apply_operation import apply_identity
class CustomScorerTransform(BaseEstimator, TransformerMixin):
    """
    Wraps a scoring function into a transformer. Function @see fn
    register_scorers must be called to register the converter
    associated to this transform. It takes two inputs, expected values
    and predicted values, and returns a score for each observation.
    """

    def __init__(self, name, fct, kwargs):
        """
        @param      name        function name
        @param      fct         python function
        @param      kwargs      parameters function
        """
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        self.name_fct = name
        self._fct = fct
        self.kwargs = kwargs

    # scikit-learn's get_params / set_params / clone read and write one
    # attribute per __init__ argument name. The constructor stores *name*
    # and *fct* under different attribute names, so expose aliases which
    # keep both spellings in sync (otherwise get_params raises
    # AttributeError).
    @property
    def name(self):
        "Function name given to the constructor (alias of *name_fct*)."
        return self.name_fct

    @name.setter
    def name(self, value):
        self.name_fct = value

    @property
    def fct(self):
        "Python function given to the constructor (alias of *_fct*)."
        return self._fct

    @fct.setter
    def fct(self, value):
        self._fct = value

    def __repr__(self):  # pylint: disable=W0222
        # functools.partial objects and callable instances have no
        # __name__; fall back to their repr instead of raising.
        fct_name = getattr(self._fct, '__name__', repr(self._fct))
        return "{}('{}', {}, {})".format(
            self.__class__.__name__, self.name_fct,
            fct_name, self.kwargs)
def custom_scorer_transform_parser(scope, model, inputs, custom_parsers=None):
    """
    Declares, inside *scope*, the operator and the output variable
    of a @see cl CustomScorerTransform.

    :param scope: Scope object receiving the new operator and variable
    :param model: a scikit-learn object (e.g., *OneHotEncoder*
        or *LogisticRegression*), never a string
    :param inputs: list of exactly two variables
    :param custom_parsers: parsers determine which outputs are expected
        for which particular task; default parsers are defined for
        classifiers, regressors, pipeline but they can be rewritten,
        *custom_parsers* is a dictionary
        ``{ type: fct_parser(scope, model, inputs, custom_parsers=None) }``;
        only *None* is supported here
    :return: the list of output variables of the declared operator
        (one variable named ``'scores'``), passed to the next stage
    :raises NotImplementedError: if *custom_parsers* is not None
    :raises RuntimeError: if *model* is a string or *inputs* does not
        hold exactly two variables
    """
    # Guard clauses, checked in this order.
    if custom_parsers is not None:  # pragma: no cover
        raise NotImplementedError(
            "Case custom_parsers not empty is not implemented yet.")
    if isinstance(model, str):
        raise RuntimeError(  # pragma: no cover
            f"Parameter model must be an object not a string '{model}'.")
    n_inputs = len(inputs)
    if n_inputs != 2:
        raise RuntimeError(  # pragma: no cover
            f"Two inputs expected not {n_inputs}.")

    # Operator type is 'Mlprodict' followed by the class name, which gives
    # 'MlprodictCustomScorerTransform' for the class registered in
    # register_scorers.
    op_type = f"Mlprodict{model.__class__.__name__}"
    new_operator = scope.declare_local_operator(op_type, model)
    new_operator.inputs = inputs

    # Output type is derived from the first input through guess_tensor_type.
    new_operator.outputs.append(
        scope.declare_local_variable(
            'scores', guess_tensor_type(inputs[0].type)))
    return new_operator.outputs
def custom_scorer_transform_shape_calculator(operator):
    """
    Computes the output type of a @see cl CustomScorerTransform.

    The output type is a deep copy of the first input type whose shape is
    replaced by ``[N, 1]``, *N* being the first dimension of that input.

    :param operator: operator with exactly two inputs and one output
    :raises RuntimeError: if the number of inputs or outputs is wrong
    """
    inputs = operator.inputs
    if len(inputs) != 2:
        raise RuntimeError("Two inputs expected.")  # pragma: no cover
    outputs = operator.outputs
    if len(outputs) != 1:
        raise RuntimeError("One output expected.")  # pragma: no cover

    source_type = inputs[0].type
    n_rows = source_type.shape[0]
    # Deep copy so that changing the output shape leaves the input type
    # untouched.
    result_type = copy.deepcopy(source_type)
    result_type.shape = [n_rows, 1]
    outputs[0].type = result_type
def custom_scorer_transform_converter(scope, operator, container):
    """
    Converts a @see cl CustomScorerTransform by delegating the work to a
    local operator of type ``'fct_' + name_fct``.

    For example, ``'fct_score_cdist_sum'`` is the type register_scorers
    registers a converter for. The delegate operator shares the inputs
    and the raw model of *operator*. Its single output ``'scores'`` is
    then forwarded to the output of *operator* through an Identity node.

    :param scope: Scope object receiving the delegate operator
    :param operator: operator whose raw model is a
        @see cl CustomScorerTransform
    :param container: container receiving the Identity node
    """
    raw_model = operator.raw_operator
    delegate = scope.declare_local_operator('fct_' + raw_model.name_fct)
    delegate.raw_operator = raw_model
    delegate.inputs = operator.inputs

    # NOTE(review): the intermediate variable is declared with the first
    # input's type object itself (no copy, shape not changed to [N, 1]).
    scores = scope.declare_local_variable('scores', operator.inputs[0].type)
    delegate.outputs.append(scores)

    target_name = operator.outputs[0].full_name
    apply_identity(scope, scores.full_name, target_name, container)
def empty_shape_calculator(operator):  # pylint: disable=W0613
    """
    Shape calculator that leaves *operator* unchanged and returns None.
    register_scorers uses it for *score_cdist_sum*, whose output types
    therefore stay as declared.
    """
def register_scorers():
    """
    Registers the parser and the converter of
    @see cl CustomScorerTransform (operator type
    ``'MlprodictCustomScorerTransform'``). It also registers the converter
    of function *score_cdist_sum* (operator type ``'fct_score_cdist_sum'``,
    option *cdist* accepting *None* or ``'single-node'``).

    :return: list of the registered classes, here only
        @see cl CustomScorerTransform
    """
    # Import deferred until registration is actually requested.
    from .cdist_score import score_cdist_sum, convert_score_cdist_sum

    update_registered_parser(
        CustomScorerTransform, custom_scorer_transform_parser)
    update_registered_converter(
        CustomScorerTransform, 'MlprodictCustomScorerTransform',
        custom_scorer_transform_shape_calculator,
        custom_scorer_transform_converter)

    update_registered_converter(
        score_cdist_sum, 'fct_score_cdist_sum',
        empty_shape_calculator, convert_score_cdist_sum,
        options={'cdist': [None, 'single-node']})

    return [CustomScorerTransform]