Coverage for mlprodict/onnx_conv/operator_converters/conv_transfer_transformer.py: 96%

57 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Converters for models from :epkg:`mlinsights`. 

4""" 

5from sklearn.base import is_classifier 

6from skl2onnx import get_model_alias 

7from skl2onnx.common._registration import ( 

8 get_shape_calculator, _converter_pool, _shape_calculator_pool) 

9from skl2onnx._parse import _parse_sklearn 

10from skl2onnx.common._apply_operation import apply_identity 

11from skl2onnx.common._topology import Scope, Variable # pylint: disable=E0611,E0001 

12from skl2onnx._supported_operators import sklearn_operator_name_map 

13 

14 

def parser_transfer_transformer(scope, model, inputs, custom_parsers=None):
    """
    Parser for :epkg:`TransferTransformer`.

    Declares a single operator for *model* in *scope*. Its only output is a
    new local variable whose name depends on ``model.method``. When
    *custom_parsers* holds an entry for *model*, parsing is delegated to it.

    :param scope: scope in which the variable and the operator are declared
    :param model: fitted ``TransferTransformer``
    :param inputs: list which must contain exactly one input variable
    :param custom_parsers: optional mapping from models to parsers
    :return: the ``outputs`` list of the declared operator
    :raises RuntimeError: if *inputs* does not contain exactly one element
    :raises NotImplementedError: if ``model.method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    n_inputs = len(inputs)
    if n_inputs != 1:
        raise RuntimeError(  # pragma: no cover
            "Only one input (not %d) is allowed for model type %r."
            "" % (n_inputs, type(model)))

    if custom_parsers is not None and model in custom_parsers:
        delegate = custom_parsers[model]
        return delegate(scope, model, inputs, custom_parsers=custom_parsers)

    method = model.method
    if method == 'transform':
        output_name = 'variable'
    elif method == 'predict_proba':
        output_name = 'probabilities'
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to defined the output for method='{}' and model='{}'."
            "".format(method, model.__class__.__name__))

    # The output variable is created with a fresh instance of the input's
    # type class; the shape calculator refines it later.
    input_type = inputs[0].type
    output = scope.declare_local_variable(output_name, input_type.__class__())
    declared = scope.declare_local_operator(get_model_alias(type(model)), model)
    declared.inputs = inputs
    declared.outputs.append(output)
    return declared.outputs

42 

43 

def shape_calculator_transfer_transformer(operator):
    """
    Shape calculator for :epkg:`TransferTransformer`.

    Parses the wrapped estimator ``estimator_`` into a temporary scope and
    runs that estimator's own shape calculator. It then copies the type of
    the output selected by ``method`` onto ``operator.outputs[0]``.

    :param operator: operator whose ``raw_operator`` is the fitted
        ``TransferTransformer``
    :raises RuntimeError: if the operator does not have exactly one input
    :raises NotImplementedError: if ``method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    if len(operator.inputs) != 1:
        raise RuntimeError(  # pragma: no cover
            "Only one input (not %d) is allowed for model %r."
            "" % (len(operator.inputs), operator))
    op = operator.raw_operator
    alias = get_model_alias(type(op.estimator_))
    calc = get_shape_calculator(alias)

    # Reuse the options of the caller's scope when it defines any.
    options = getattr(operator.scope, 'options', None)
    if is_classifier(op.estimator_):
        # Probabilities must be a plain tensor, not a ZipMap. The override
        # is merged into a copy so that options already defined in the
        # caller's scope are kept (previously they were discarded here) and
        # the caller's dictionaries are never mutated.
        options = dict(options) if options else {}
        est_key = id(op.estimator_)
        est_options = dict(options.get(est_key) or {})
        est_options['zipmap'] = False
        options[est_key] = est_options

    registered_models = dict(
        conv=_converter_pool, shape=_shape_calculator_pool,
        aliases=sklearn_operator_name_map)
    scope = Scope('temp', options=options,
                  registered_models=registered_models)
    inputs = [
        Variable(v.onnx_name, v.onnx_name, type=v.type, scope=scope)
        for v in operator.inputs]
    res = _parse_sklearn(scope, op.estimator_, inputs)
    # Relies on the private ``_parent`` attribute, presumably the operator
    # that produced the first parsed output.
    this_operator = res[0]._parent
    calc(this_operator)

    if op.method == 'predict_proba':
        # Second output of the wrapped estimator (index 1); presumably the
        # label comes first and the probabilities second.
        index = 1
    elif op.method == 'transform':
        index = 0
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to defined the output for method='{}' and model='{}'.".format(
                op.method, op.__class__.__name__))
    operator.outputs[0].type = this_operator.outputs[index].type

86 

87 

def convert_transfer_transformer(scope, operator, container):
    """
    Converter for :epkg:`TransferTransformer`.

    Registers the options of the ``TransferTransformer`` for its wrapped
    estimator ``estimator_``, with ZipMap disabled for classifiers. It then
    converts that estimator in the current scope and connects the output
    chosen by ``method`` to the operator's output through an Identity node.

    :param scope: current conversion scope
    :param operator: operator whose ``raw_operator`` is the fitted
        ``TransferTransformer``
    :param container: container receiving the ONNX nodes and options
    :raises NotImplementedError: if ``method`` is neither
        ``'predict_proba'`` nor ``'transform'``
    """
    transfer = operator.raw_operator

    estimator_options = scope.get_options(transfer)
    if estimator_options is None:
        estimator_options = {}
    if is_classifier(transfer.estimator_):
        # Probabilities are needed as a tensor, not as a ZipMap.
        estimator_options['zipmap'] = False
    # Same options are registered on the container first, then the scope.
    for registry in (container, scope):
        registry.add_options(id(transfer.estimator_), estimator_options)

    converted = _parse_sklearn(scope, transfer.estimator_, operator.inputs)

    if transfer.method == 'transform':
        selected = converted[0]
    elif transfer.method == 'predict_proba':
        # Second output of the wrapped estimator (index 1).
        selected = converted[1]
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unable to defined the output for method='{}' and model='{}'."
            "".format(transfer.method, transfer.__class__.__name__))

    apply_identity(
        scope, selected.onnx_name, operator.outputs[0].full_name, container,
        operator_name=scope.get_unique_operator_name("IdentityTT"))