Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Implements a *learner* or a *transform* which follows the same API 

5as every :epkg:`scikit-learn` transform. 

6""" 

7import textwrap 

8import warnings 

9from .sklearn_parameters import SkLearnParameters 

10 

11 

class SkBase:
    """
    Pattern of a *learner* or a *transform* which follows the API
    of :epkg:`scikit-learn`.

    The parameters defining the object are held in attribute *P*
    (see @see cl SkLearnParameters). Equality between two objects
    is defined as equality of these parameters.
    """

    def __init__(self, **kwargs):
        """
        Stores the parameters in a @see cl SkLearnParameters,
        which is what *get_params* reads to expose them
        and to clone a model.

        @param      kwargs      parameters which define the object
        """
        self.P = SkLearnParameters(**kwargs)

    def fit(self, X, y=None, sample_weight=None):
        """
        Trains a model. The base class does not implement it,
        it must be overridden.

        @param      X               features
        @param      y               target
        @param      sample_weight   weight
        @return                     self
        @raises     NotImplementedError     always, in this base class
        """
        raise NotImplementedError()  # pragma: no cover

    def get_params(self, deep=True):
        """
        Returns the parameters which define the object,
        all are needed to clone the object.

        @param      deep        unused here
        @return                 dict
        """
        return self.P.to_dict()

    def set_params(self, **values):
        """
        Sets the parameters which define the object: a new
        @see cl SkLearnParameters is built from *values* and
        replaces the one previously stored.

        @param      values      values
        @return                 self
        """
        # NOTE(review): scikit-learn's convention is that set_params only
        # updates the parameters it receives; here the whole container is
        # rebuilt from *values* alone -- confirm this is intended.
        self.P = SkLearnParameters(**values)
        return self

    def __eq__(self, o):
        """
        Compares two objects, more precisely,
        compares the parameters which define the object.
        Never raises on a difference, it returns False instead.
        """
        return self.test_equality(o, False)

    def test_equality(self, o, exc=True):
        """
        Compares two objects and checks parameters have
        the same values.

        @param      o       object to compare to
        @param      exc     raises an exception if there is a difference
                            between the parameters
        @return             boolean, False without any exception
                            if both objects do not share the same class
        """
        if self.__class__ != o.__class__:
            return False
        p1 = self.get_params()
        p2 = o.get_params()
        return SkBase.compare_params(p1, p2, exc=exc)

    @staticmethod
    def _same_item(e1, e2, exc):
        """
        Compares two elements taken from a list or a dictionary
        of parameters.

        @param      e1      first element
        @param      e2      second element
        @param      exc     forwarded to *test_equality* when *e1* has one
        @return             boolean
        """
        if hasattr(e1, 'test_equality'):
            return e1.test_equality(e2, exc=exc)
        return e1 == e2

    @staticmethod
    def compare_params(p1, p2, exc=True):
        """
        Compares two sets of parameters.

        @param      p1      dictionary
        @param      p2      dictionary
        @param      exc     raises an exception if error is met
        @return             boolean
        @raises     KeyError    if *exc* is True and a key is only
                                in one of the dictionaries
        @raises     ValueError  if *exc* is True and the values
                                associated to a key are different
        """
        if p1 == p2:
            return True
        for k in p1:
            if k not in p2:
                if exc:
                    raise KeyError("Key '{0}' was removed.".format(k))
                return False
        for k in p2:
            if k not in p1:
                if exc:
                    raise KeyError("Key '{0}' was added.".format(k))
                return False
        for k in sorted(p1):
            v1, v2 = p1[k], p2[k]
            if hasattr(v1, 'test_equality'):
                b = v1.test_equality(v2, exc=exc)
                if exc and v1 is not v2:
                    warnings.warn(  # pragma: no cover
                        "v2 is a clone of v1 not v1 itself for key '{0}' and class {1}."
                        "".format(k, type(v1)))
            elif isinstance(v1, list) and isinstance(v2, list) and len(v1) == len(v2):
                # Every pair of elements is compared, including plain values
                # which do not implement test_equality.
                b = all(SkBase._same_item(e1, e2, exc)
                        for e1, e2 in zip(v1, v2))
            elif isinstance(v1, dict) and isinstance(v2, dict) and set(v1) == set(v2):
                # Both dictionaries share the same keys, values are paired
                # by key so that keys do not need to be sortable.
                b = all(SkBase._same_item(v1[key], v2[key], exc)
                        for key in v1)
            elif hasattr(v1, "get_params") and hasattr(v2, "get_params"):
                b = SkBase.compare_params(
                    v1.get_params(deep=False), v2.get_params(deep=False), exc=exc)
            else:
                b = v1 == v2
            # Single exit for every kind of value: a difference found inside
            # a list or a dictionary is reported the same way as any other.
            if not b:
                if exc:
                    raise ValueError(
                        "Values for key '{0}' are different.\n---\n{1}\n---\n{2}".format(k, v1, v2))
                return False
        return True

    def __repr__(self):
        """
        Returns the class name followed by the parameters
        between parentheses, wrapped over several lines if needed.
        """
        res = "{0}({1})".format(self.__class__.__name__, str(self.P))
        # NOTE(review): the indent literal is copied as it appears in the
        # listing this file was recovered from, where whitespace looks
        # collapsed -- confirm its width against the original file.
        return "\n".join(textwrap.wrap(res, subsequent_indent=" "))

147 return "\n".join(textwrap.wrap(res, subsequent_indent=" "))