Coverage for mlinsights/mlbatch/pipeline_cache.py: 84%

55 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-09 08:45 +0200

1""" 

2@file 

3@brief Caches training. 

4""" 

5from distutils.version import StrictVersion # pylint: disable=W0402 

6from sklearn import __version__ as skl_version 

7from sklearn.base import clone 

8from sklearn.pipeline import Pipeline, _fit_transform_one 

9from sklearn.utils import _print_elapsed_time 

10from .cache_model import MLCache 

11 

12 

def isskl023(version=None):
    """
    Tells if :epkg:`scikit-learn` is at least version 0.23.

    Only the major and minor numbers are compared; any suffix
    (``.post1``, ``.dev0``, ``rc1``, ...) is ignored.

    :param version: version string to check, defaults to the
        installed :epkg:`scikit-learn` version
    :return: boolean
    :raises ValueError: if the version does not start with
        ``<major>.<minor>``
    """
    # distutils' StrictVersion is deprecated and rejects release
    # candidates such as '1.0rc1', so the version is parsed by hand.
    import re
    if version is None:
        version = skl_version
    found = re.match(r'(\d+)\.(\d+)', version)
    if found is None:
        raise ValueError(
            f"Unable to parse scikit-learn version {version!r}.")
    major, minor = int(found.group(1)), int(found.group(2))
    return (major, minor) >= (0, 23)

17 

18 

class PipelineCache(Pipeline):
    """
    Same as :epkg:`sklearn:pipeline:Pipeline` but every intermediate
    step is first looked up in a named :class:`MLCache`. If a step of
    the same class with the same parameters was already fitted on the
    same data, possibly by another pipeline using the same cache name,
    the fitted step is reused instead of being trained again.

    :param steps: list
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.
    :param cache_name: name of the cache, if None, a name is generated
        from the identity of the instance
    :param verbose: boolean, optional
        If True, the time elapsed while fitting each step will be printed as it
        is completed.

    Other attributes:

    :param named_steps: bunch object, a dictionary with attribute access
        Read-only attribute to access any step parameter by user given name.
        Keys are step names and values are steps parameters.
    :param cache_: the :class:`MLCache` retrieved or created by the last
        call to :meth:`_fit`
    """

    def __init__(self, steps, cache_name=None, verbose=False):
        # NOTE(review): presumably assigned before the parent constructor
        # runs so the attribute already exists if it inspects parameters;
        # confirm before removing this first assignment.
        self.cache_name = cache_name
        Pipeline.__init__(self, steps, memory=None, verbose=verbose)
        # Without an explicit name, derive one from id(self), which is
        # unique among live objects.
        self.cache_name = ("Pipeline%d" % id(self)
                           if cache_name is None else cache_name)

    def _get_fit_params_steps(self, fit_params):
        """
        Routes the parameters given to *fit* to the steps they belong to.

        Two formats are accepted: ``stepname__param=value`` sets a single
        parameter of step *stepname*, ``stepname=dict`` merges a whole
        dictionary of parameters into step *stepname*.

        :param fit_params: dictionary of fit parameters
        :return: dictionary ``{step name: dictionary of parameters}``,
            with one entry for every step which is not None
        :raises ValueError: if a key without ``__`` is not mapped to
            a dictionary
        """
        routed = {}
        for step_name, estimator in self.steps:
            if estimator is not None:
                routed[step_name] = {}

        for key, value in fit_params.items():
            step_name, separator, param = key.partition('__')
            if separator:
                routed[step_name][param] = value
                continue
            if not isinstance(value, dict):
                raise ValueError(  # pragma: no cover
                    f"For scikit-learn < 0.23, "
                    f"Pipeline.fit does not accept the {key} parameter. "
                    f"You can pass parameters to specific steps of your "
                    f"pipeline using the stepname__parameter format, e.g. "
                    f"`Pipeline.fit(X, y, logisticregression__sample_weight"
                    f"=sample_weight)`.")
            routed[key].update(value)
        return routed

    def _fit(self, X, y=None, **fit_params):
        """
        Fits every step but the final one, reusing fitted steps found
        in the cache and storing newly fitted ones into it.

        :param X: training data
        :param y: training target, may be None
        :param fit_params: fit parameters, routed to the steps
            with :meth:`_get_fit_params_steps`
        :return: the transformed data if :func:`isskl023` is True,
            otherwise a tuple (transformed data, fit parameters of
            the final step, ``{}`` if it is ``'passthrough'``)
        """
        self.steps = list(self.steps)
        self._validate_steps()
        fit_params_steps = self._get_fit_params_steps(fit_params)

        # Caches are retrieved by name, which is how several pipelines
        # end up sharing fitted steps.
        self.cache_ = (MLCache.get_cache(self.cache_name)
                       if MLCache.has_cache(self.cache_name)
                       else MLCache.create_cache(self.cache_name))

        Xt = X
        for step_idx, name, step in self._iter(
                with_final=False, filter_passthrough=False):
            if step is None or step == 'passthrough':
                # Nothing to fit: only enter the timing context so the
                # step is reported (presumably only when verbose).
                with _print_elapsed_time('Pipeline', self._log_message(step_idx)):
                    continue

            # Cache key: hyperparameters, class name and input data.
            # NOTE(review): the per-step fit parameters given to the fit
            # below (e.g. sample_weight) are not part of the key, so a step
            # fitted with other fit parameters may be reused; confirm.
            key = step.get_params()
            key.update({'__class__': step.__class__.__name__, 'X': Xt})
            # NOTE(review): scikit-learn estimators do not seem to expose
            # is_classifier()/is_regressor() methods (those are functions in
            # sklearn.base), so y is presumably rarely part of the key;
            # verify for supervised transformers.
            supervised = any(
                hasattr(step, method) and getattr(step, method)()
                for method in ('is_classifier', 'is_regressor'))
            if supervised:
                key['y'] = y

            fitted = self.cache_.get(key)
            if fitted is not None:
                Xt = fitted.transform(Xt)
            else:
                # The original step is left untouched: a clone is fitted.
                Xt, fitted = _fit_transform_one(
                    clone(step), Xt, y, None,
                    message_clsname='PipelineCache',
                    message=self._log_message(step_idx),
                    **fit_params_steps[name])
                self.cache_.cache(key, fitted)

            self.steps[step_idx] = (name, fitted)

        if isskl023():
            return Xt
        last_params = ({} if self._final_estimator == 'passthrough'
                       else fit_params_steps[self.steps[-1][0]])
        return Xt, last_params