1"""
2@file
3@brief Caches training.
4"""
5from distutils.version import StrictVersion
6from sklearn import __version__ as skl_version
7from sklearn.base import clone
8from sklearn.pipeline import Pipeline, _fit_transform_one
9from sklearn.utils import _print_elapsed_time
10from .cache_model import MLCache


def isskl023():
    "Tells if :epkg:`scikit-learn` is version 0.23 or later."
    v1 = ".".join(skl_version.split('.')[:2])
    return StrictVersion(v1) >= StrictVersion('0.23')
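
# Example (illustrative): with skl_version == '0.24.2', v1 == '0.24' and
# StrictVersion('0.24') >= StrictVersion('0.23') is True.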


class PipelineCache(Pipeline):
    """
    Same as :epkg:`sklearn:pipeline:Pipeline` but it can
    skip training when it detects that a step was already
    trained, even in a different pipeline.

    :param steps: list
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.
    :param cache_name: name of the cache, if None, a new name is created
    :param verbose: boolean, optional
        If True, the time elapsed while fitting each step will be printed as it
        is completed.

    Other attributes:

    :param named_steps: bunch object, a dictionary with attribute access
        Read-only attribute to access any step parameter by user given name.
        Keys are step names and values are step parameters.
    """

    def __init__(self, steps, cache_name=None, verbose=False):
        self.cache_name = cache_name
        Pipeline.__init__(self, steps, memory=None, verbose=verbose)
        if cache_name is None:
            cache_name = "Pipeline%d" % id(self)
        self.cache_name = cache_name

    def _get_fit_params_steps(self, fit_params):
        fit_params_steps = {name: {} for name, step in self.steps
                            if step is not None}

        for pname, pval in fit_params.items():
            if '__' not in pname:
                if not isinstance(pval, dict):
                    raise ValueError(  # pragma: no cover
                        "For scikit-learn < 0.23, "
                        "Pipeline.fit does not accept the {} parameter. "
                        "You can pass parameters to specific steps of your "
                        "pipeline using the stepname__parameter format, e.g. "
                        "`Pipeline.fit(X, y, logisticregression__sample_weight"
                        "=sample_weight)`.".format(pname))
                else:
                    fit_params_steps[pname].update(pval)
            else:
                step, param = pname.split('__', 1)
                fit_params_steps[step][param] = pval
        return fit_params_steps
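
    # Routing sketch (illustrative): a call such as
    # `fit(X, y, clf__sample_weight=w)` is split on the first '__' and
    # mapped to {'clf': {'sample_weight': w}}, so every step receives only
    # the keyword arguments addressed to it.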

    def _fit(self, X, y=None, **fit_params):
        self.steps = list(self.steps)
        self._validate_steps()
        fit_params_steps = self._get_fit_params_steps(fit_params)
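        # Reuse the named cache if it already exists so that fitted steps can
        # be shared across pipelines created with the same cache_name.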
        if not MLCache.has_cache(self.cache_name):
            self.cache_ = MLCache.create_cache(self.cache_name)
        else:
            self.cache_ = MLCache.get_cache(self.cache_name)
        Xt = X
        for (step_idx, name, transformer) in self._iter(
                with_final=False, filter_passthrough=False):
            if transformer is None or transformer == 'passthrough':
                with _print_elapsed_time('Pipeline', self._log_message(step_idx)):
                    continue
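
            # The cache key combines the step's hyperparameters, its class
            # name and the current input (plus the target for predictors):
            # the step is refitted only when one of these changes.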
            params = transformer.get_params()
            params['__class__'] = transformer.__class__.__name__
            params['X'] = Xt
            if ((hasattr(transformer, 'is_classifier') and transformer.is_classifier()) or
                    (hasattr(transformer, 'is_regressor') and transformer.is_regressor())):
                params['y'] = y
            cached = self.cache_.get(params)
            if cached is None:
                cloned_transformer = clone(transformer)
                Xt, fitted_transformer = _fit_transform_one(
                    cloned_transformer, Xt, y, None,
                    message_clsname='PipelineCache',
                    message=self._log_message(step_idx),
                    **fit_params_steps[name])
                self.cache_.cache(params, fitted_transformer)
            else:
                fitted_transformer = cached
                Xt = fitted_transformer.transform(Xt)
            self.steps[step_idx] = (name, fitted_transformer)
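
        # scikit-learn >= 0.23 expects _fit to return only Xt; older
        # versions also expect the fit parameters of the final estimator.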
        if isskl023():
            return Xt
        if self._final_estimator == 'passthrough':
            return Xt, {}
        return Xt, fit_params_steps[self.steps[-1][0]]