Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# -*- coding: utf-8 -*-
"""
@file
@brief Implements a *learner* or a *transform* which follows the same API
as every :epkg:`scikit-learn` transform.
"""
7import textwrap
8import warnings
9from .sklearn_parameters import SkLearnParameters
class SkBase:
    """
    Pattern of a *learner* or a *transform* which follows the API
    of :epkg:`scikit-learn`.

    The parameters given to the constructor are kept in attribute *P*
    (a @see cl SkLearnParameters) and are the only state used by
    *get_params*, *set_params* and the equality helpers.
    """

    def __init__(self, **kwargs):
        """
        Stores the parameters, see
        @see cl SkLearnParameters, it keeps a copy of
        the parameters to easily implement method *get_params*
        and to clone a model.

        @param      kwargs      parameters defining the model
        """
        self.P = SkLearnParameters(**kwargs)

    def fit(self, X, y=None, sample_weight=None):
        """
        Trains a model, must be overridden by the subclass.

        @param      X               features
        @param      y               target
        @param      sample_weight   weight
        @return                     self
        @raises     NotImplementedError     always, in this base class
        """
        raise NotImplementedError()  # pragma: no cover

    def get_params(self, deep=True):
        """
        Returns the parameters which define the object,
        all are needed to clone the object.

        @param      deep        unused here
        @return                 dict
        """
        return self.P.to_dict()

    def set_params(self, **values):
        """
        Replaces the parameters which define the object
        by a new @see cl SkLearnParameters built from *values*.

        @param      values      values
        @return                 self
        """
        self.P = SkLearnParameters(**values)
        return self

    def __eq__(self, o):
        """
        Compares two objects, more precisely,
        compares the parameters which define the object.
        It never raises on a difference, it returns False.

        @param      o           object to compare to
        @return                 boolean
        """
        return self.test_equality(o, False)

    def test_equality(self, o, exc=True):
        """
        Compares two objects and checks parameters have
        the same values.

        @param      o       object to compare to
        @param      exc     raises an exception if there is a difference
                            between the parameters
        @return             boolean, False if both objects
                            do not share the same class
        """
        if self.__class__ != o.__class__:
            return False
        return SkBase.compare_params(
            self.get_params(), o.get_params(), exc=exc)

    @staticmethod
    def _same_elements(pairs, exc):
        """
        Tells if every pair of values holds two equal elements.

        @param      pairs   iterable of couples *(e1, e2)*
        @param      exc     forwarded to *test_equality* when
                            *e1* implements it
        @return             boolean
        """
        for e1, e2 in pairs:
            if hasattr(e1, 'test_equality'):
                if not e1.test_equality(e2, exc=exc):
                    return False
            elif e1 != e2:
                # Plain values must be compared too, otherwise two
                # containers of the same size always look equal.
                return False
        return True

    @staticmethod
    def compare_params(p1, p2, exc=True):
        """
        Compares two sets of parameters.

        @param      p1      dictionary
        @param      p2      dictionary
        @param      exc     raises an exception if error is met
        @return             boolean
        @raises     KeyError    if *exc* is True and a key is only
                                in one dictionary
        @raises     ValueError  if *exc* is True and two values
                                are different
        """
        if p1 == p2:
            return True

        # Both dictionaries must define the same keys.
        for k in p1:
            if k not in p2:
                if exc:
                    raise KeyError("Key '{0}' was removed.".format(k))
                return False
        for k in p2:
            if k not in p1:
                if exc:
                    raise KeyError("Key '{0}' was added.".format(k))
                return False

        for k in sorted(p1):
            v1, v2 = p1[k], p2[k]
            if hasattr(v1, 'test_equality'):
                b = v1.test_equality(v2, exc=exc)
                if exc and v1 is not v2:
                    warnings.warn(  # pragma: no cover
                        "v2 is a clone of v1 not v1 itself for key '{0}' and class {1}."
                        "".format(k, type(v1)))
            elif (isinstance(v1, list) and isinstance(v2, list) and
                    len(v1) == len(v2)):
                b = SkBase._same_elements(zip(v1, v2), exc)
            elif (isinstance(v1, dict) and isinstance(v2, dict) and
                    set(v1) == set(v2)):
                # Keys are identical, values are paired by key,
                # this does not require the keys to be sortable.
                b = SkBase._same_elements(
                    ((v1[key], v2[key]) for key in v1), exc)
            elif hasattr(v1, "get_params") and hasattr(v2, "get_params"):
                b = SkBase.compare_params(
                    v1.get_params(deep=False), v2.get_params(deep=False),
                    exc=exc)
            else:
                b = v1 == v2

            # Single exit for every kind of difference so that
            # *exc* is always honoured.
            if not b:
                if exc:
                    raise ValueError(
                        "Values for key '{0}' are different.\n---\n{1}\n---\n{2}".format(k, v1, v2))
                return False
        return True

    def __repr__(self):
        """
        Returns the class name followed by the parameters,
        wrapped on several lines by :epkg:`*py:textwrap`.
        """
        res = "{0}({1})".format(self.__class__.__name__, str(self.P))
        # NOTE(review): the continuation indent is reproduced as it
        # appears in the source, confirm its width.
        return "\n".join(textwrap.wrap(res, subsequent_indent=" "))