Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements *LassoRandomForestRegressor*.
4"""
5import numpy
6from sklearn.base import BaseEstimator, RegressorMixin, clone
7from sklearn.ensemble import RandomForestRegressor
8from sklearn.linear_model import Lasso
class LassoRandomForestRegressor(BaseEstimator, RegressorMixin):
    """
    Fits a random forest and then selects trees by using a
    Lasso regression. The training produces the following
    attributes:

    * `rf_estimator_`: trained random forest
    * `lasso_estimator_`: trained Lasso
    * `estimators_`: trained estimators mapped to a non-null coefficient
    * `intercept_`: bias
    * `coef_`: estimators weights

    The prediction is
    ``intercept_ + sum_i coef_[i] * estimators_[i].predict(X)``.
    """

    def __init__(self, rf_estimator=None, lasso_estimator=None):
        """
        @param      rf_estimator        random forest estimator,
                                        :epkg:`sklearn:ensemble:RandomForestRegressor`
                                        by default
        @param      lasso_estimator     Lasso estimator,
                                        :epkg:`sklearn:linear_model:Lasso`
                                        by default
        """
        BaseEstimator.__init__(self)
        RegressorMixin.__init__(self)
        # Defaults are materialized here, so the stored attributes are
        # never None (kept as-is for backward compatibility).
        if rf_estimator is None:
            rf_estimator = RandomForestRegressor()
        if lasso_estimator is None:
            lasso_estimator = Lasso()
        self.rf_estimator = rf_estimator
        self.lasso_estimator = lasso_estimator

    def fit(self, X, y, sample_weight=None):
        """
        Fits the random forest first, then applies a lasso
        and finally removes all trees mapped to a null coefficient.

        @param      X               training features
        @param      y               training labels
        @param      sample_weight   sample weights, forwarded to the random
                                    forest and, when not None, to the Lasso
        @return                     self
        """
        self.rf_estimator_ = clone(self.rf_estimator)
        self.rf_estimator_.fit(X, y, sample_weight=sample_weight)

        # Flatten into a 1-D object array so a boolean mask can select trees.
        estimators = numpy.array(
            self.rf_estimator_.estimators_, dtype=object).ravel()
        # One column per tree: that tree's prediction on the training set.
        X2 = numpy.column_stack([est.predict(X) for est in estimators])

        self.lasso_estimator_ = clone(self.lasso_estimator)
        if sample_weight is None:
            self.lasso_estimator_.fit(X2, y)
        else:
            # Weighted training must weight the tree-selection stage too,
            # not only the forest.
            self.lasso_estimator_.fit(X2, y, sample_weight=sample_weight)

        not_null = self.lasso_estimator_.coef_ != 0
        self.intercept_ = self.lasso_estimator_.intercept_
        self.estimators_ = estimators[not_null]
        self.coef_ = self.lasso_estimator_.coef_[not_null]
        return self

    def decision_function(self, X):
        """
        Computes the predictions: the Lasso-weighted sum of the
        selected trees' predictions plus the intercept.

        @param      X       features
        @return             one prediction per row of *X*

        When the Lasso kept no tree (every coefficient is null),
        every prediction equals `intercept_`.
        """
        n_rows = X.shape[0] if hasattr(X, "shape") else len(X)
        # Start from zeros rather than None so an empty selection still
        # yields an array of the right length.
        prediction = numpy.zeros(n_rows, dtype=numpy.float64)
        for est, coef in zip(self.estimators_, self.coef_):
            prediction += est.predict(X) * coef
        return prediction + self.intercept_

    def predict(self, X):
        """
        Computes the predictions (same as @see me decision_function).

        @param      X       features
        @return             one prediction per row of *X*
        """
        return self.decision_function(X)