Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief CorrPlot functionalities.
5It comes from `corrplot.py <https://raw.githubusercontent.com/biokit/biokit/master/biokit/viz/corrplot.py>`_
6which I copied here because the module does not properly work on Python 3 (import issues).
7See also `biokit license <https://github.com/biokit/biokit/blob/master/LICENSE>`_.
9:author: Thomas Cokelaer
10:references: http://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html
11"""
12import numpy
13from scipy.cluster.hierarchy import dendrogram, fcluster
14import pandas
15from .linkage import Linkage
16from ._colormap import cmap_builder
class Corrplot(Linkage):
    """
    An implementation of correlation plotting tools (corrplot).
    The class requires `scipy <http://www.scipy.org/>`_.

    Here is a simple example with a correlation matrix as an input (stored in
    a pandas dataframe):

    .. plot::
        :width: 50%
        :include-source:

        import pandas
        import numpy
        letters = "ABCDEFGHIJKLM"[0:10]
        df = pandas.DataFrame(dict(( (k, numpy.random.random(10)+ord(k)-65) for k in letters)))

        import matplotlib.pyplot as plt
        plt.close('all')
        plt.style.use('ggplot')

        from pyensae.graph_helper import Corrplot
        c = Corrplot(df)
        c.plot()
        plt.show()

    This class requires module `colormap <https://pypi.python.org/pypi/colormap>`_.
    """

    def __init__(self, data, na=0):
        """
        Plots the content of square matrix that contains correlation values.

        :param data: input can be a dataframe (Pandas), or list of lists (python) or
            a numpy matrix. Note, however, that values must be between -1 and 1. If not,
            or if the matrix (or list of lists) is not squared, then correlation is
            computed. The data or computed correlation is stored in :attr:`df` attribute.
        :param na: replace NA values with this value (default 0)

        The :attr:`params` contains some tunable parameters for the colorbar in the
        :meth:`plot` method.

        ::

            # can be a list of lists, the correlation matrix is then a 2x2 matrix
            c = Corrplot([[1,1], [2,4], [3,3], [4,4]])

        """
        super(Corrplot, self).__init__()  # pylint: disable=R1725

        # kept as an attribute so that the colormap factory can be swapped
        self.cmap_builder = cmap_builder

        # work on a copy: the caller's data is never modified
        self.df = pandas.DataFrame(data, copy=True)

        # The input is taken as a correlation matrix only if it is square,
        # labelled identically on both axes and bounded by [-1, 1];
        # otherwise the correlation between columns is computed.
        n_rows, n_cols = self.df.shape
        compute_correlation = (
            self.df.max().max() > 1 or
            self.df.min().min() < -1 or
            n_rows != n_cols or
            list(self.df.index) != list(self.df.columns))

        if compute_correlation:
            self.df = self.df.corr()

        # replace NA with the requested value (zero by default)
        self.df.fillna(na, inplace=True)

        #: tunable parameters for the :meth:`plot` method.
        self.params = {
            'colorbar.N': 100,
            'colorbar.shrink': .8,
            'colorbar.orientation': 'vertical'}

    def _set_default_cmap(self):
        """
        Sets :attr:`cm` to the default red/white/blue colormap.
        """
        self.cm = self.cmap_builder('#AA0000', 'white', 'darkblue')

    def order(self, method='complete', metric='euclidean', inplace=False):
        """
        Rearranges the order of rows and columns after clustering.

        :param method: any scipy method (e.g., single, average, centroid,
            median, ward). See scipy.cluster.hierarchy.linkage
        :param metric: any scipy distance (euclidean, hamming, jaccard)
            See scipy.spatial.distance or scipy.cluster.hierarchy
        :param bool inplace: if set to True, the dataframe is replaced
        :return: the reordered dataframe, or :attr:`df` unchanged if
            *method* or *metric* is None

        The clustering results (attributes *Y*, *Z*, *idx1*, *ind1*) are
        only stored on the instance when *inplace* is True.

        You probably do not need to use that method. Use :meth:`plot` and
        the two parameters order_metric and order_method instead.
        """
        if method is None or metric is None:
            # no clustering requested: keep the current order
            return self.df

        Y = self.linkage(self.df, method=method, metric=metric)
        # flat clusters cut at 70% of the largest merge distance
        ind1 = fcluster(Y, 0.7 * max(Y[:, 2]), 'distance')
        Z = dendrogram(Y, no_plot=True)
        idx1 = Z['leaves']
        # apply the leaf order to the rows, then to the columns
        cor2 = self.df.iloc[idx1].T.iloc[idx1].T

        if inplace is not True:
            return cor2

        self.df = cor2
        self.Y = Y
        self.Z = Z
        self.idx1 = idx1
        self.ind1 = ind1
        return cor2

    def plot(self, fig=None, grid=True,
             rotation=30, lower=None, upper=None,
             shrink=0.9, axisbg='white', colorbar=True, label_color='black',
             fontsize='small', edgecolor='black', method='ellipse',
             order_method='complete', order_metric='euclidean', cmap=None,
             ax=None, binarise_color=False, figsize=None):
        """
        Plots the correlation matrix from the content of :attr:`df`
        (dataframe).

        By default, the correlation is shown on the upper and lower triangle and is
        symmetric wrt to the diagonal. The symbols are ellipses. The symbols can
        be changed to e.g. rectangle. The symbols are shown on upper and lower sides but
        you could choose a symbol for the upper side and another for the lower side using
        the **lower** and **upper** parameters.

        :param fig: Create a new figure by default. If an instance of an existing
            figure is provided, the corrplot is overlayed on the figure provided.
            Can also be the number of the figure.
        :param grid: add grid (Defaults to grey color). You can set it to False or a color.
        :param rotation: rotate labels on y-axis
        :param lower: if set to a valid method, plots the data on the lower
            left triangle
        :param upper: if set to a valid method, plots the data on the upper
            right triangle
        :param float shrink: maximum space used (in percent) by a symbol.
            If negative values are provided, the absolute value is taken.
            If greater than 1, the symbols will overlap.
        :param axisbg: color of the background (defaults to white).
        :param colorbar: add the colorbar (defaults to True). The colorbar
            is skipped when the chosen method draws no patch
            (e.g. 'number' or 'text' only).
        :param str label_color: (defaults to black).
        :param fontsize: size of the fonts defaults to 'small'.
        :param method: shape to be used in 'ellipse', 'square', 'rectangle',
            'color', 'text', 'circle', 'number', 'pie'.

        :param order_method: see :meth:`order`.
        :param order_metric: see :meth:`order`.
        :param cmap: a valid cmap from matplotlib or colormap package (e.g.,
            'jet', or 'copper'). Default is red/white/blue colors.
        :param binarise_color: two colors only, negative, positive
        :param ax: a matplotlib axes.
        :param figsize: gives that parameter to the new created figure
        :return: ax (matplotlib axes)

        The colorbar can be tuned with the parameters stored in :attr:`params`.
        Here is an example. See notebook for other examples:

        ::

            c = corrplot.Corrplot(dataframe)
            c.plot(cmap=('Orange', 'white', 'green'))
            c.plot(method='circle')
            c.plot(colorbar=False, shrink=.8, upper='circle' )
        """
        import matplotlib.pyplot as plt  # pylint: disable=C0415

        # colormap: a name, a sequence of colors, or the default one
        if cmap is not None:
            try:
                if isinstance(cmap, str):
                    self.cm = self.cmap_builder(cmap)
                else:
                    self.cm = self.cmap_builder(*cmap)
            except Exception:
                # best effort: an invalid colormap falls back to the default
                self._set_default_cmap()
        else:
            self._set_default_cmap()

        # drawing options read by _add_patches
        self.shrink = abs(shrink)
        self.fontsize = fontsize
        self.edgecolor = edgecolor
        self.binarise_color = binarise_color

        df = self.order(method=order_method, metric=order_metric)

        # figure can be a number or an instance; otherwise creates it
        params = dict(facecolor=axisbg)
        if fig is None:
            params["num"] = None
        elif isinstance(fig, int):
            # an integer is already the figure number (it has no .number)
            params["num"] = fig
        else:
            params["num"] = fig.number
        if figsize is not None:
            params["figsize"] = figsize
        fig = plt.figure(**params)

        # do we have an axes to plot the data in ?
        if ax is None:
            ax = plt.subplot(1, 1, 1, aspect='equal', facecolor=axisbg)
        else:
            # if so, clear the axes. Colorbar cannot be removed easily.
            plt.sca(ax)
            ax.clear()

        # subplot resets the bg color, let us set it again
        fig.set_facecolor(axisbg)

        width, height = df.shape
        labels = df.columns

        if upper is None and lower is None:
            mode = 'method'
        elif upper and lower:
            mode = 'both'
        elif lower is not None:
            mode = 'lower'
        elif upper is not None:
            mode = 'upper'

        # forget the collection of a previous call: the colorbar below must
        # only be attached to patches drawn by this call
        self.collection = None
        if mode == 'upper':
            self._add_patches(df, upper, 'upper', ax, diagonal=True)
        elif mode == 'lower':
            self._add_patches(df, lower, 'lower', ax, diagonal=True)
        elif mode == 'method':
            self._add_patches(df, method, 'both', ax, diagonal=True)
        elif mode == 'both':
            self._add_patches(df, upper, 'upper', ax, diagonal=False)
            self._add_patches(df, lower, 'lower', ax, diagonal=False)

        # shift the limits to englobe the patches correctly
        ax.set_xlim(-0.5, width - .5)
        ax.set_ylim(-0.5, height - .5)

        # set xticks/xlabels on top
        ax.xaxis.tick_top()
        xtickslocs = numpy.arange(len(labels))
        ax.set_xticks(xtickslocs)
        ax.set_xticklabels(labels, rotation=rotation, color=label_color,
                           fontsize=fontsize, ha='left')

        ax.invert_yaxis()
        ytickslocs = numpy.arange(len(labels))
        ax.set_yticks(ytickslocs)
        ax.set_yticklabels(labels, fontsize=fontsize, color=label_color)
        plt.tight_layout()

        if grid is not False:
            if grid is True:
                grid = 'grey'
            for i in range(0, width):
                ratio1 = float(i) / width
                ratio2 = float(i + 2) / width
                # set axis off
                # 2 - set xlabels along the diagonal
                # set colorbar either on left or bottom
                if mode == 'lower':
                    ax.axvline(i + .5, ymin=1 - ratio1, ymax=0., color=grid)
                    ax.axhline(i + .5, xmin=0, xmax=ratio2, color=grid)
                if mode == 'upper':
                    ax.axvline(i + .5, ymin=1 - ratio2, ymax=1, color=grid)
                    ax.axhline(i + .5, xmin=ratio1, xmax=1, color=grid)
                if mode in ['method', 'both']:
                    ax.axvline(i + .5, color=grid)
                    ax.axhline(i + .5, color=grid)

            # can probably be simplified
            if mode == 'lower':
                ax.axvline(-.5, ymin=0, ymax=1, color='grey')
                ax.axvline(width - .5, ymin=0, ymax=1. / width,
                           color='grey', lw=2)
                ax.axhline(width - .5, xmin=0, xmax=1, color='grey', lw=2)
                ax.axhline(-.5, xmin=0, xmax=1. / width, color='grey', lw=2)
                # Axes has no xticks(); ticks are removed with set_xticks
                ax.set_xticks([])
                # labels are written along the diagonal and on the left
                for i in range(0, width):
                    ax.text(i, i - .6, labels[i], fontsize=fontsize,
                            color=label_color,
                            rotation=rotation, verticalalignment='bottom')
                    ax.text(-.6, i, labels[i], fontsize=fontsize,
                            color=label_color,
                            rotation=0, horizontalalignment='right')
                ax.set_axis_off()
            # can probably be simplified
            elif mode == 'upper':
                ax.axvline(width - .5, ymin=0, ymax=1, color='grey', lw=2)
                ax.axvline(-.5, ymin=1 - 1. / width,
                           ymax=1, color='grey', lw=2)
                ax.axhline(-.5, xmin=0, xmax=1, color='grey', lw=2)
                ax.axhline(width - .5, xmin=1 - 1. / width,
                           xmax=1, color='grey', lw=2)
                # Axes has no yticks(); ticks are removed with set_yticks
                ax.set_yticks([])
                # labels are written along the diagonal and on the top
                for i in range(0, width):
                    ax.text(-.6 + i, i, labels[i], fontsize=fontsize,
                            color=label_color, horizontalalignment='right',
                            rotation=0)
                    ax.text(i, -.5, labels[i], fontsize=fontsize,
                            color=label_color, rotation=rotation,
                            verticalalignment='bottom')
                ax.set_axis_off()

        # set all ticks length to zero
        ax = plt.gca()
        ax.tick_params(axis='both', which='both', length=0)

        if colorbar:
            N = self.params['colorbar.N'] + 1
            if N < 2:
                raise RuntimeError("No colorbar to draw.")
            # methods 'number' and 'text' draw no patch: nothing to map
            if self.collection is not None:
                cb = plt.gcf().colorbar(
                    self.collection,
                    orientation=self.params['colorbar.orientation'],
                    shrink=self.params['colorbar.shrink'],
                    boundaries=numpy.linspace(0, 1, N),
                    ticks=[0, .25, 0.5, 0.75, 1])
                cb.ax.set_yticklabels([-1, -.5, 0, .5, 1])
                # make sure it goes from -1 to 1 even though actual values may not
                # reach that range
                # cb.set_clim(0, 1)
                # not working in matplotlib 3.3.0

        return ax

    def _add_patches(self, df, method, fill, ax, diagonal=True):
        """
        Draws one symbol per cell of *df* on *ax*.

        :param df: square dataframe of correlation values
        :param method: symbol, one of 'ellipse', 'square', 'rectangle',
            'color', 'circle', 'number', 'text', 'pie'
        :param fill: 'lower' or 'upper' restricts the drawing to that
            triangle, any other value draws every cell
        :param ax: matplotlib axes receiving the symbols
        :param diagonal: if False, cells on the diagonal are skipped
        :raises ValueError: if *method* is not a known symbol

        The method reads attributes *shrink*, *edgecolor*, *fontsize*,
        *cm* and *binarise_color* set by :meth:`plot`. When at least one
        patch is drawn, the resulting collection is stored in
        attribute *collection*.
        """
        from matplotlib.patches import Ellipse, Circle, Rectangle, Wedge  # pylint: disable=C0415
        from matplotlib.collections import PatchCollection  # pylint: disable=C0415

        width, height = df.shape

        patches = []
        colors = []
        for x in range(width):
            for y in range(height):
                if fill == 'lower' and x > y:
                    continue
                if fill == 'upper' and x < y:
                    continue
                if diagonal is False and x == y:
                    continue
                d = df.iloc[x, y]
                d_abs = numpy.abs(d)
                # correlation in [-1, 1] mapped to [0, 1] for the colormap
                datum = (d + 1.) / 2.
                rotate = -45 if d > 0 else +45
                if method in ['ellipse', 'square', 'rectangle', 'color']:
                    if method == 'ellipse':
                        # the stronger the correlation, the thinner the ellipse
                        patch = Ellipse(
                            (x, y), width=1 * self.shrink,
                            height=(self.shrink - d_abs * self.shrink),
                            angle=rotate)
                    else:
                        w = h = d_abs * self.shrink
                        offset = (1 - w) / 2.
                        if method == 'color':
                            # the whole cell is filled
                            w = 1
                            h = 1
                            offset = 0
                        patch = Rectangle(
                            (x + offset - .5, y + offset - .5),
                            width=w, height=h, angle=0)
                    if self.edgecolor:
                        patch.set_edgecolor(self.edgecolor)
                    colors.append(datum)
                    if d_abs > 0.05:
                        patch.set_linestyle('dotted')
                    patches.append(patch)
                elif method == 'circle':
                    patch = Circle((x, y), radius=d_abs * self.shrink / 2.)
                    if self.edgecolor:
                        patch.set_edgecolor(self.edgecolor)
                    colors.append(datum)
                    if d_abs > 0.05:
                        patch.set_linestyle('dotted')
                    patches.append(patch)
                elif method in ['number', 'text']:
                    # plain two-way choice so the color is always defined
                    edgecolor = self.cm(-1.0) if d < 0 else self.cm(1.0)
                    d_str = "{:.2f}".format(d).replace(
                        "0.", ".").replace(".00", "")
                    # 'withdash' was removed from matplotlib text in 3.3;
                    # False was its default, so it is simply not passed
                    ax.text(x, y, d_str, color=edgecolor,
                            fontsize=self.fontsize,
                            horizontalalignment='center',
                            weight='bold', alpha=max(0.5, d_abs))
                elif method == 'pie':
                    S = 360 * d_abs
                    # first wedge covers |d| of the disk, second the rest
                    patch = [
                        Wedge((x, y), 1 * self.shrink / 2., -90, S - 90),
                        Wedge((x, y), 1 * self.shrink / 2., S - 90, 360 - 90),
                    ]
                    colors.append(datum)
                    # 0.5 is the middle of the colormap (zero correlation)
                    colors.append(0.5)
                    if self.edgecolor:
                        patch[0].set_edgecolor(self.edgecolor)
                        patch[1].set_edgecolor(self.edgecolor)
                    patches.append(patch[0])
                    patches.append(patch[1])
                else:
                    raise ValueError(
                        'Method for the symbols is not known. Use e.g, square, circle')

        if self.binarise_color:
            colors = [1 if color > 0.5 else -1 for color in colors]

        if patches:
            col1 = PatchCollection(
                patches, array=numpy.array(colors), cmap=self.cm)
            ax.add_collection(col1)

            self.collection = col1