Coverage for src/pyensae/mlhelper/table_formula.py: 89%
80 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-03 02:16 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-03 02:16 +0200
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Adds functionalities to a dataframe.
5"""
6import datetime
7import pandas
class TableFormula(pandas.DataFrame):  # pylint: disable=W0223
    """
    Extends class :epkg:`pandas:DataFrame` or proposes extensions
    to existing functions using lambda functions.
    See `Extending Pandas
    <https://pandas.pydata.org/pandas-docs/
    stable/development/extending.html>`_.
    """

    @property
    def _constructor(self):
        # pandas uses this to build new objects from an existing one,
        # so derived dataframes keep the TableFormula type.
        return TableFormula

    def _apply_rows(self, function):
        """
        Calls *function* on every row and returns the results.

        @param      function        function taking one row
        @return                     results aligned on the index

        ``DataFrame.apply`` may return a copy of the whole dataframe
        when the dataframe has no rows, and that copy cannot be stored
        in a single column. An empty Series sharing the (empty) index
        is returned instead in that case.
        """
        if len(self) == 0:
            return pandas.Series([], index=self.index, dtype=object)
        # The wrapper makes sure *function* is always called once per row,
        # as a plain Python call.
        return self.apply(lambda row: function(row), axis=1)

    def sort(self, function_sort, reverse=False):
        """
        Sorts rows based on the values returned by *function_sort*.

        @param      function_sort   lambda function
        @param      reverse         reverse order

        The function creates a column ``__key__`` and removes it later,
        even if sorting fails. The changes happen inplace.
        """
        if "__key__" in self.columns:
            raise ValueError(
                "__key__ cannot be used in the original dataframe.")
        self["__key__"] = self._apply_rows(function_sort)
        try:
            self.sort_values("__key__", inplace=True, ascending=not reverse)
        finally:
            # Drop the temporary column even when sort_values raises,
            # otherwise every later call would fail on the leftover column.
            self.drop("__key__", inplace=True, axis=1)

    def fgroupby(self, function_key, function_values, columns=None,
                 function_agg=None, function_weight=None):
        """
        Groups information based on columns defined by lambda functions.

        @param      function_key        defines the key
        @param      function_values     defines the values
        @param      columns             name of the columns, if None,
                                        ``fv0``, ``fv1``, ... are used
        @param      function_agg        how to aggregate the data, if None,
                                        the values of every group are summed
        @param      function_weight     defines weights, can be None
        @return                         @see cl TableFormula, one row per key

        When *function_weight* is specified, every value is multiplied
        by its row weight, and every aggregation function is called with
        ``(values, total_weight)`` where *total_weight* is the sum of the
        weights over the whole table (not only the group).

        The function uses columns ``__key__``, ``__weight__``.
        You should not use these names.
        Other columns ``__value_{0}__`` are created in an internal
        dataframe and do not appear in the result.

        Example:

        ::

            group = table.fgroupby(lambda v: v["name"],
                                   [lambda v: v["d_a"]],
                                   ["sum_d_a"],
                                   [lambda vec, w: sum(vec) / w],
                                   lambda v: v["d_b"])
        """
        if "__key__" in self.columns:
            raise ValueError(
                "__key__ cannot be used in the original dataframe.")
        if "__weight__" in self.columns:
            raise ValueError(
                "__weight__ cannot be used in the original dataframe.")

        def _sum_values(vec, *_ignored):
            # Default aggregation; in weighted mode the total weight is
            # passed as a second argument and ignored.
            return vec.sum()

        if columns is None:
            columns = ["fv{0}".format(i) for i in range(len(function_values))]
        if len(columns) != len(function_values):
            raise ValueError(
                "Parameters function_values and columns must have the same size.")
        if function_agg is None:
            function_agg = [_sum_values] * len(columns)
        if len(function_agg) != len(function_values):
            raise ValueError(
                "Parameters function_values and function_agg must have the same size.")

        # Only the key and the computed values are needed to aggregate,
        # user functions see the original rows (without internal columns).
        data = pandas.DataFrame(index=self.index)
        data["__key__"] = self._apply_rows(function_key)
        weights = (None if function_weight is None
                   else self._apply_rows(function_weight))

        values = []
        rep = {}
        for fct, cnew in zip(function_values, columns):
            name = "__value_{0}__".format(cnew)
            values.append(name)
            rep[name] = cnew
            computed = self._apply_rows(fct)
            data[name] = computed if weights is None else computed * weights

        if weights is None:
            aggs = dict(zip(values, function_agg))
        else:
            total_weight = weights.sum()
            # fct is bound as a default argument to avoid late binding.
            aggs = {k: (lambda vec, fct=fct: fct(vec, total_weight))
                    for k, fct in zip(values, function_agg)}
        gr = data.groupby("__key__", as_index=False).agg(aggs)
        gr.columns = [rep.get(c, c) for c in gr.columns]
        gr = gr.drop("__key__", axis=1)
        return TableFormula(gr)

    def add_column_index(self, index, name=None):
        """
        Changes the index.

        @param      index       new_index
        @param      name        name of the index

        The function uses a temporary column ``__key__``, it raises
        a ValueError if the dataframe already has such a column.
        The changes happen inplace.
        """
        if "__key__" in self.columns:
            raise ValueError(
                "__key__ cannot be used in the original dataframe.")
        self["__key__"] = index
        self.set_index("__key__", inplace=True)
        self.index.rename(name, inplace=True)

    def add_column_vector(self, name, values):
        """
        Adds a column knowing its name and a vector of values.

        @param      name        name of the column
        @param      values      values

        The changes happen inplace.
        """
        self[name] = values

    def addc(self, name, function_value):
        """
        Adds a column knowing its name and a lambda function.

        @param      name            name of the column
        @param      function_value  function applied to every row

        The changes happen inplace. On a dataframe without rows,
        an empty column is added and *function_value* is not called.
        """
        self[name] = self._apply_rows(function_value)

    def graph_XY(self, curves, xlabel=None, ylabel=None, marker=True,
                 link_point=False, title=None, format_date="%Y-%m-%d",
                 legend_loc=0, figsize=None, ax=None):
        """
        Draws one or several curves on the same graph.

        @param      curves          list of 3-uples
                                    (function for X, function for Y, label),
                                    both functions are applied to every row
        @param      xlabel          label for X axis, not set if None
        @param      ylabel          label for Y axis, not set if None
        @param      marker          add a marker for each point
        @param      link_point      link points between them
        @param      title           graph title, not set if None
        @param      format_date     if X axis holds datetime objects, the function
                                    uses this format to print dates
        @param      legend_loc      location of the legend
        @param      figsize         size of the figure, only used when *ax* is None
        @param      ax              :epkg:`matplotlib:Axis` or None to create a new one
        @return                     :epkg:`matplotlib:Axis`

        *marker* and *link_point* cannot both be False: the style lookup
        raises a KeyError in that case.

        For the legend position, see `matplotlib <http://matplotlib.org/api/legend_api.html>`_.

        Example:

        ::

            table.graph_XY([[lambda v: v["sum_a"], lambda v: v["sum_b"], "xy label 1"],
                            [lambda v: v["sum_b"], lambda v: v["sum_c"], "xy label 2"],
                            ])
        """
        if ax is None:
            import matplotlib.pyplot as plt  # pylint: disable=C0415
            _, ax = plt.subplots(1, 1, figsize=figsize)

        styles = {(True, True): 'o-', (True, False): 'o', (False, True): '-'}
        smarker = styles[marker, link_point]

        has_date = False
        for xf, yf, label in curves:
            x = self._apply_rows(xf)
            y = self._apply_rows(yf)
            # iloc: positional access, the index may not contain label 0.
            if len(x) > 0 and isinstance(x.iloc[0], datetime.datetime):
                import matplotlib.dates  # pylint: disable=C0415
                x = [matplotlib.dates.date2num(d) for d in x]
                has_date = True
            ax.plot(x, y, smarker, label=label)

        if has_date:
            import matplotlib.dates  # pylint: disable=C0415
            hfmt = matplotlib.dates.DateFormatter(format_date)
            if "%H" in format_date or "%M" in format_date:
                ax.xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
            ax.xaxis.set_major_formatter(hfmt)
            ax.get_figure().autofmt_xdate()

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if title is not None:
            ax.set_title(title)
        ax.legend(loc=legend_loc)
        return ax