Coverage for mlinsights/mlmodel/transfer_transformer.py: 88%

41 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-28 08:46 +0100

1""" 

2@file 

3@brief Implements a transformer which wraps a predictor 

4to do transfer learning. 

5""" 

6import inspect 

7from sklearn.base import BaseEstimator, TransformerMixin 

8from .sklearn_testing import clone_with_fitted_parameters 

9 

10 

class TransferTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps a predictor or a transformer in a transformer.
    By default this model is frozen: it is not trained again and
    only computes the predictions of the wrapped model
    (see parameter *trainable* to train it in method *fit*).

    .. index:: transfer learning, frozen model
    """

    def __init__(self, estimator, method=None, copy_estimator=True,
                 trainable=False):
        """
        @param      estimator       estimator to wrap in a transformer, it is cloned
                                    with its fitted parameters (deep copy) when fitted
                                    if *copy_estimator* is True
        @param      method          if None, guess what method should be called,
                                    *transform* for a transformer,
                                    *predict_proba* for a classifier,
                                    *decision_function* if found,
                                    *predict* otherwise
        @param      copy_estimator  copy the model instead of taking a reference
        @param      trainable       the transferred model must be trained
                                    when *fit* is called

        Raises *AttributeError* if no suitable method can be found
        on *estimator* or if *method* is not an attribute of *estimator*.
        """
        TransformerMixin.__init__(self)
        BaseEstimator.__init__(self)
        self.estimator = estimator
        self.copy_estimator = copy_estimator
        self.trainable = trainable
        if method is None:
            # Guess the method in order of preference.
            if hasattr(estimator, "transform"):
                method = "transform"
            elif hasattr(estimator, "predict_proba"):
                method = "predict_proba"
            elif hasattr(estimator, "decision_function"):
                method = "decision_function"
            elif hasattr(estimator, "predict"):
                method = "predict"
            else:
                raise AttributeError(  # pragma: no cover
                    "Cannot find a method transform, predict_proba, decision_function, "
                    f"predict in object {type(estimator)}.")
        if not hasattr(estimator, method):
            raise AttributeError(  # pragma: no cover
                f"Cannot find method '{method}' in object {type(estimator)}")
        self.method = method

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Copies (or references) the wrapped estimator and, if
        *trainable* is True, trains it on the given data.

        :param X: training data, only used if *trainable* is True
        :param y: targets, only used if *trainable* is True and
            the wrapped estimator's *fit* has a parameter *y*
        :param sample_weight: sample weights, only used if *trainable*
            is True and the wrapped estimator's *fit* has a parameter
            *sample_weight*
        :return: self: returns an instance of self.

        Fitted attributes:

        * `estimator_`: wrapped estimator (a copy if *copy_estimator*
          is True), trained here if *trainable* is True
        """
        if self.copy_estimator:
            self.estimator_ = clone_with_fitted_parameters(self.estimator)
            from .sklearn_testing import assert_estimator_equal  # pylint: disable=C0415
            # Sanity check: the copy must keep the fitted parameters.
            assert_estimator_equal(self.estimator_, self.estimator)
        else:
            # The estimator is shared with the caller: if trainable is True,
            # the training below modifies the instance given by the user.
            self.estimator_ = self.estimator
        if self.trainable:
            pars = inspect.signature(self.estimator_.fit).parameters
            kwargs = {}
            if 'sample_weight' in pars:
                # Passed by keyword: the third positional parameter of
                # the wrapped fit is not necessarily sample_weight.
                kwargs['sample_weight'] = sample_weight
            if 'y' in pars:
                self.estimator_.fit(X, y, **kwargs)
            else:
                self.estimator_.fit(X, **kwargs)
        return self

    def transform(self, X):
        """
        Runs the predictions of the wrapped estimator.

        :param X: numpy array or sparse matrix of shape [n_samples,n_features],
            data to transform
        :return: transformed *X*, the output of the method
            named by *self.method* on the fitted estimator
        """
        meth = getattr(self.estimator_, self.method)
        return meth(X)