Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Implements *LassoRandomForestRegressor*.
"""

5import numpy 

6from sklearn.base import BaseEstimator, RegressorMixin, clone 

7from sklearn.ensemble import RandomForestRegressor 

8from sklearn.linear_model import Lasso 

9 

10 

class LassoRandomForestRegressor(BaseEstimator, RegressorMixin):
    """
    Fits a random forest and then selects trees by using a
    Lasso regression. The training produces the following
    attributes:

    * `rf_estimator_`: trained random forest
    * `lasso_estimator_`: trained Lasso
    * `estimators_`: trained estimators mapped to a non-null coefficient
    * `intercept_`: bias
    * `coef_`: estimators weights

    The prediction is ``intercept_ + sum_i coef_[i] * estimators_[i].predict(X)``.
    If the Lasso removes every tree, the prediction is ``intercept_``
    for every sample.
    """

    def __init__(self, rf_estimator=None, lasso_estimator=None):
        """
        @param      rf_estimator        random forest estimator,
                                        :epkg:`sklearn:ensemble:RandomForestRegressor`
                                        is used when None
        @param      lasso_estimator     Lasso estimator,
                                        :epkg:`sklearn:linear_model:Lasso`
                                        is used when None

        Parameters are stored unchanged, as :epkg:`scikit-learn` requires
        for `get_params` / `clone` to work. The defaults are resolved
        in `fit`.
        """
        self.rf_estimator = rf_estimator
        self.lasso_estimator = lasso_estimator

    @staticmethod
    def _tree_predictions(X, estimators):
        """
        Stacks the predictions of every estimator as the columns of a
        matrix of shape ``(n_samples, len(estimators))``.
        """
        return numpy.column_stack([est.predict(X) for est in estimators])

    def fit(self, X, y, sample_weight=None):
        """
        Fits the random forest first, then applies a lasso
        and finally removes all trees mapped to a null coefficient.

        @param      X               training features
        @param      y               training labels
        @param      sample_weight   sample weights, used by both the
                                    random forest and the Lasso
        @return                     self
        """
        rf_estimator = (RandomForestRegressor() if self.rf_estimator is None
                        else self.rf_estimator)
        self.rf_estimator_ = clone(rf_estimator)
        self.rf_estimator_.fit(X, y, sample_weight=sample_weight)

        # ravel: estimators_ is a list for random forests but may be a
        # 2D array of estimators for other ensembles.
        estimators = numpy.array(
            self.rf_estimator_.estimators_, dtype=object).ravel()

        # Each tree's prediction becomes one feature of the Lasso.
        X2 = self._tree_predictions(X, estimators)

        lasso_estimator = (Lasso() if self.lasso_estimator is None
                           else self.lasso_estimator)
        self.lasso_estimator_ = clone(lasso_estimator)
        if sample_weight is None:
            self.lasso_estimator_.fit(X2, y)
        else:
            self.lasso_estimator_.fit(X2, y, sample_weight=sample_weight)

        coef = self.lasso_estimator_.coef_
        not_null = coef != 0
        self.intercept_ = self.lasso_estimator_.intercept_
        self.estimators_ = estimators[not_null]
        self.coef_ = coef[not_null]
        return self

    def decision_function(self, X):
        """
        Computes the predictions: the weighted sum of the selected
        trees' predictions plus the intercept.

        @param      X       features
        @return             predictions, one per sample
        """
        # Start from zeros so that a model whose Lasso removed every tree
        # still returns the intercept instead of failing on None.
        prediction = numpy.zeros(numpy.shape(X)[0], dtype=numpy.float64)
        for est, weight in zip(self.estimators_, self.coef_):
            prediction += est.predict(X) * weight
        return prediction + self.intercept_

    def predict(self, X):
        """
        Computes the predictions.

        @param      X       features
        @return             predictions, one per sample
        """
        return self.decision_function(X)