Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief CorrPlot functionalities. 

4 

5It comes from `corrplot.py <https://raw.githubusercontent.com/biokit/biokit/master/biokit/viz/corrplot.py>`_ 

6which I copied here because the module does not properly work on Python 3 (import issues). 

7See also `biokit license <https://github.com/biokit/biokit/blob/master/LICENSE>`_. 

8 

9:author: Thomas Cokelaer 

10:references: http://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html 

11""" 

12import numpy 

13from scipy.cluster.hierarchy import dendrogram, fcluster 

14import pandas 

15from .linkage import Linkage 

16from ._colormap import cmap_builder 

17 

18 

19class Corrplot(Linkage): 

20 """ 

21 An implementation of correlation plotting tools (corrplot). 

22 the class requires `scipy <http://www.scipy.org/>`_. 

23 

24 Here is a simple example with a correlation matrix as an input (stored in 

25 a pandas dataframe): 

26 

27 .. plot:: 

28 :width: 50% 

29 :include-source: 

30 

31 import pandas 

32 import numpy 

33 letters = "ABCDEFGHIJKLM"[0:10] 

34 df = pandas.DataFrame(dict(( (k, numpy.random.random(10)+ord(k)-65) for k in letters))) 

35 

36 import matplotlib.pyplot as plt 

37 plt.close('all') 

38 plt.style.use('ggplot') 

39 

40 from pyensae.graph_helper import Corrplot 

41 c = Corrplot(df) 

42 c.plot() 

43 plt.show() 

44 

45 This class requires module `colormap <https://pypi.python.org/pypi/colormap>`_. 

46 """ 

47 

48 def __init__(self, data, na=0): 

49 """ 

50 Plots the content of square matrix that contains correlation values. 

51 

52 :param data: input can be a dataframe (Pandas), or list of lists (python) or 

53 a numpy matrix. Note, however, that values must be between -1 and 1. If not, 

54 or if the matrix (or list of lists) is not squared, then correlation is 

55 computed. The data or computed correlation is stored in :attr:`df` attribute. 

56 :param na: replace NA values with this value (default 0) 

57 

58 The :attr:`params` contains some tunable parameters for the colorbar in the 

59 :meth:`plot` method. 

60 

61 :: 

62 

63 # can be a list of lists, the correlation matrix is then a 2x2 matrix 

64 c = Corrplot([[1,1], [2,4], [3,3], [4,4]]) 

65 

66 """ 

67 super(Corrplot, self).__init__() # pylint: disable=R1725 

68 

69 # we delay import in case this is not needed 

70 self.cmap_builder = cmap_builder 

71 

72 self.df = pandas.DataFrame(data, copy=True) 

73 

74 compute_correlation = False 

75 

76 w, h = self.df.shape 

77 if self.df.max().max() > 1 or self.df.min().min() < -1: 

78 compute_correlation = True 

79 if w != h: 

80 compute_correlation = True 

81 if list(self.df.index) != list(self.df.columns): 

82 compute_correlation = True 

83 

84 if compute_correlation: 

85 cor = self.df.corr() 

86 self.df = cor 

87 

88 # replace NA with zero 

89 self.df.fillna(na, inplace=True) 

90 

91 #: tunable parameters for the :meth:`plot` method. 

92 self.params = { 

93 'colorbar.N': 100, 

94 'colorbar.shrink': .8, 

95 'colorbar.orientation': 'vertical'} 

96 

97 def _set_default_cmap(self): 

98 self.cm = self.cmap_builder('#AA0000', 'white', 'darkblue') 

99 

100 def order(self, method='complete', metric='euclidean', inplace=False): 

101 """ 

102 Rearranges the order of rows and columns after clustering. 

103 

104 :param method: any scipy method (e.g., single, average, centroid, 

105 median, ward). See scipy.cluster.hierarchy.linkage 

106 :param metric: any scipy distance (euclidean, hamming, jaccard) 

107 See scipy.spatial.distance or scipy.cluster.hieararchy 

108 :param bool inplace: if set to True, the dataframe is replaced 

109 

110 You probably do not need to use that method. Use :meth:`plot` and 

111 the two parameters order_metric and order_method instead. 

112 """ 

113 if None is method or None is metric: 

114 return self.df 

115 Y = self.linkage(self.df, method=method, metric=metric) 

116 ind1 = fcluster(Y, 0.7 * max(Y[:, 2]), 'distance') 

117 Z = dendrogram(Y, no_plot=True) 

118 idx1 = Z['leaves'] 

119 cor2 = self.df.iloc[idx1].T.iloc[idx1].T 

120 if inplace is True: 

121 self.df = cor2 

122 else: 

123 return cor2 

124 self.Y = Y 

125 self.Z = Z 

126 self.idx1 = idx1 

127 self.ind1 = ind1 

128 return cor2 

129 

130 def plot(self, fig=None, grid=True, 

131 rotation=30, lower=None, upper=None, 

132 shrink=0.9, axisbg='white', colorbar=True, label_color='black', 

133 fontsize='small', edgecolor='black', method='ellipse', 

134 order_method='complete', order_metric='euclidean', cmap=None, 

135 ax=None, binarise_color=False, figsize=None): 

136 """ 

137 Plots the correlation matrix from the content of :attr:`df` 

138 (dataframe). 

139 

140 By default, the correlation is shown on the upper and lower triangle and is 

141 symmetric wrt to the diagonal. The symbols are ellipses. The symbols can 

142 be changed to e.g. rectangle. The symbols are shown on upper and lower sides but 

143 you could choose a symbol for the upper side and another for the lower side using 

144 the **lower** and **upper** parameters. 

145 

146 :param fig: Create a new figure by default. If an instance of an existing 

147 figure is provided, the corrplot is overlayed on the figure provided. 

148 Can also be the number of the figure. 

149 :param grid: add grid (Defaults to grey color). You can set it to False or a color. 

150 :param rotation: rotate labels on y-axis 

151 :param lower: if set to a valid method, plots the data on the lower 

152 left triangle 

153 :param upper: if set to a valid method, plots the data on the upper 

154 left triangle 

155 :param float shrink: maximum space used (in percent) by a symbol. 

156 If negative values are provided, the absolute value is taken. 

157 If greater than 1, the symbols wiill overlap. 

158 :param axisbg: color of the background (defaults to white). 

159 :param colorbar: add the colorbar (defaults to True). 

160 :param str label_color: (defaults to black). 

161 :param fontsize: size of the fonts defaults to 'small'. 

162 :param method: shape to be used in 'ellipse', 'square', 'rectangle', 

163 'color', 'text', 'circle', 'number', 'pie'. 

164 

165 :param order_method: see :meth:`order`. 

166 :param order_metric: see : meth:`order`. 

167 :param cmap: a valid cmap from matplotlib or colormap package (e.g., 

168 'jet', or 'copper'). Default is red/white/blue colors. 

169 :param binarise_color: two colors only, negative, positive 

170 :param ax: a matplotlib axes. 

171 :param figsize: gives that parameter to the new created figure 

172 :return: ax (matplotlib axes) 

173 

174 The colorbar can be tuned with the parameters stored in :attr:`params`. 

175 Here is an example. See notebook for other examples: 

176 

177 :: 

178 

179 c = corrplot.Corrplot(dataframe) 

180 c.plot(cmap=('Orange', 'white', 'green')) 

181 c.plot(method='circle') 

182 c.plot(colorbar=False, shrink=.8, upper='circle' ) 

183 """ 

184 import matplotlib.pyplot as plt # pylint: disable=C0415 

185 

186 # default 

187 if cmap is not None: 

188 try: 

189 if isinstance(cmap, str): 

190 self.cm = self.cmap_builder(cmap) 

191 else: 

192 self.cm = self.cmap_builder(*cmap) 

193 except Exception: 

194 self._set_default_cmap() 

195 else: 

196 self._set_default_cmap() 

197 

198 self.shrink = abs(shrink) 

199 self.fontsize = fontsize 

200 self.edgecolor = edgecolor 

201 

202 df = self.order(method=order_method, metric=order_metric) 

203 

204 # figure can be a number or an instance; otherwise creates it 

205 params = dict(facecolor=axisbg) 

206 if isinstance(fig, int): 

207 params["num"] = fig.number 

208 elif fig is not None: 

209 params["num"] = fig.number 

210 else: 

211 params["num"] = None 

212 if figsize is not None: 

213 params["figsize"] = figsize 

214 fig = plt.figure(**params) 

215 

216 # do we have an axes to plot the data in ? 

217 if ax is None: 

218 ax = plt.subplot(1, 1, 1, aspect='equal', facecolor=axisbg) 

219 else: 

220 # if so, clear the axes. Colorbar cannot be removed easily. 

221 plt.sca(ax) 

222 ax.clear() 

223 

224 # subplot resets the bg color, let us set it again 

225 fig.set_facecolor(axisbg) 

226 

227 width, height = df.shape 

228 labels = (df.columns) 

229 

230 if upper is None and lower is None: 

231 mode = 'method' 

232 elif upper and lower: 

233 mode = 'both' 

234 elif lower is not None: 

235 mode = 'lower' 

236 elif upper is not None: 

237 mode = 'upper' 

238 

239 self.binarise_color = binarise_color 

240 if mode == 'upper': 

241 self._add_patches(df, upper, 'upper', ax, diagonal=True) 

242 elif mode == 'lower': 

243 self._add_patches(df, lower, 'lower', ax, diagonal=True) 

244 elif mode == 'method': 

245 self._add_patches(df, method, 'both', ax, diagonal=True) 

246 elif mode == 'both': 

247 self._add_patches(df, upper, 'upper', ax, diagonal=False) 

248 self._add_patches(df, lower, 'lower', ax, diagonal=False) 

249 

250 # shift the limits to englobe the patches correctly 

251 ax.set_xlim(-0.5, width - .5) 

252 ax.set_ylim(-0.5, height - .5) 

253 

254 # set xticks/xlabels on top 

255 ax.xaxis.tick_top() 

256 xtickslocs = numpy.arange(len(labels)) 

257 ax.set_xticks(xtickslocs) 

258 ax.set_xticklabels(labels, rotation=rotation, color=label_color, 

259 fontsize=fontsize, ha='left') 

260 

261 ax.invert_yaxis() 

262 ytickslocs = numpy.arange(len(labels)) 

263 ax.set_yticks(ytickslocs) 

264 ax.set_yticklabels(labels, fontsize=fontsize, color=label_color) 

265 plt.tight_layout() 

266 

267 if grid is not False: 

268 if grid is True: 

269 grid = 'grey' 

270 for i in range(0, width): 

271 ratio1 = float(i) / width 

272 ratio2 = float(i + 2) / width 

273 # set axis off 

274 # 2 - set xlabels along the diagonal 

275 # set colorbar either on left or bottom 

276 if mode == 'lower': 

277 ax.axvline(i + .5, ymin=1 - ratio1, ymax=0., color=grid) 

278 ax.axhline(i + .5, xmin=0, xmax=ratio2, color=grid) 

279 if mode == 'upper': 

280 ax.axvline(i + .5, ymin=1 - ratio2, ymax=1, color=grid) 

281 ax.axhline(i + .5, xmin=ratio1, xmax=1, color=grid) 

282 if mode in ['method', 'both']: 

283 ax.axvline(i + .5, color=grid) 

284 ax.axhline(i + .5, color=grid) 

285 

286 # can probably be simplified 

287 if mode == 'lower': 

288 ax.axvline(-.5, ymin=0, ymax=1, color='grey') 

289 ax.axvline(width - .5, ymin=0, ymax=1. / 

290 width, color='grey', lw=2) 

291 ax.axhline(width - .5, xmin=0, xmax=1, color='grey', lw=2) 

292 ax.axhline(-.5, xmin=0, xmax=1. / width, color='grey', lw=2) 

293 ax.xticks([]) 

294 for i in range(0, width): 

295 ax.text(i, i - .6, labels[i], fontsize=fontsize, 

296 color=label_color, 

297 rotation=rotation, verticalalignment='bottom') 

298 ax.text(-.6, i, labels[i], fontsize=fontsize, 

299 color=label_color, 

300 rotation=0, horizontalalignment='right') 

301 ax.set_axis_off() 

302 # can probably be simplified 

303 elif mode == 'upper': 

304 ax.axvline(width - .5, ymin=0, ymax=1, color='grey', lw=2) 

305 ax.axvline(-.5, ymin=1 - 1. / width, 

306 ymax=1, color='grey', lw=2) 

307 ax.axhline(-.5, xmin=0, xmax=1, color='grey', lw=2) 

308 ax.axhline(width - .5, xmin=1 - 1. / width, 

309 xmax=1, color='grey', lw=2) 

310 ax.yticks([]) 

311 for i in range(0, width): 

312 ax.text(-.6 + i, i, labels[i], fontsize=fontsize, 

313 color=label_color, horizontalalignment='right', 

314 rotation=0) 

315 ax.text(i, -.5, labels[i], fontsize=fontsize, 

316 color=label_color, rotation=rotation, verticalalignment='bottom') 

317 ax.set_axis_off() 

318 

319 # set all ticks length to zero 

320 ax = plt.gca() 

321 ax.tick_params(axis='both', which='both', length=0) 

322 

323 if colorbar: 

324 N = self.params['colorbar.N'] + 1 

325 if N < 2: 

326 raise RuntimeError("No colorbar to draw.") 

327 cb = plt.gcf().colorbar( 

328 self.collection, orientation=self.params['colorbar.orientation'], 

329 shrink=self.params['colorbar.shrink'], 

330 boundaries=numpy.linspace(0, 1, N), 

331 ticks=[0, .25, 0.5, 0.75, 1]) 

332 cb.ax.set_yticklabels([-1, -.5, 0, .5, 1]) 

333 # make sure it goes from -1 to 1 even though actual values may not 

334 # reach that range 

335 # cb.set_clim(0, 1) 

336 # not working in matplotlib 3.3.0 

337 

338 return ax 

339 

340 def _add_patches(self, df, method, fill, ax, diagonal=True): 

341 

342 from matplotlib.patches import Ellipse, Circle, Rectangle, Wedge 

343 from matplotlib.collections import PatchCollection 

344 

345 width, height = df.shape 

346 

347 patches = [] 

348 colors = [] 

349 for x in range(width): 

350 for y in range(height): 

351 if fill == 'lower' and x > y: 

352 continue 

353 if fill == 'upper' and x < y: 

354 continue 

355 if diagonal is False and x == y: 

356 continue 

357 datum = (df.iloc[x, y] + 1.) / 2. 

358 d = df.iloc[x, y] 

359 d_abs = numpy.abs(d) 

360 #c = self.pvalues[x, y] 

361 rotate = -45 if d > 0 else +45 

362 #cmap = self.poscm if d >= 0 else self.negcm 

363 if method in ['ellipse', 'square', 'rectangle', 'color']: 

364 if method == 'ellipse': 

365 func = Ellipse 

366 patch = func((x, y), width=1 * self.shrink, 

367 height=(self.shrink - d_abs * self.shrink), angle=rotate) 

368 else: 

369 func = Rectangle 

370 w = h = d_abs * self.shrink 

371 offset = (1 - w) / 2. 

372 if method == 'color': 

373 w = 1 

374 h = 1 

375 offset = 0 

376 patch = func((x + offset - .5, y + offset - .5), width=w, 

377 height=h, angle=0) 

378 if self.edgecolor: 

379 patch.set_edgecolor(self.edgecolor) 

380 # patch.set_facecolor(cmap(d_abs)) 

381 colors.append(datum) 

382 if d_abs > 0.05: 

383 patch.set_linestyle('dotted') 

384 # ax.add_artist(patch) 

385 patches.append(patch) 

386 elif method == 'circle': 

387 patch = Circle((x, y), radius=d_abs * self.shrink / 2.) 

388 if self.edgecolor: 

389 patch.set_edgecolor(self.edgecolor) 

390 # patch.set_facecolor(cmap(d_abs)) 

391 colors.append(datum) 

392 if d_abs > 0.05: 

393 patch.set_linestyle('dotted') 

394 # ax.add_artist(patch) 

395 patches.append(patch) 

396 elif method in ['number', 'text']: 

397 if d < 0: 

398 edgecolor = self.cm(-1.0) 

399 elif d >= 0: 

400 edgecolor = self.cm(1.0) 

401 d_str = "{:.2f}".format(d).replace( 

402 "0.", ".").replace(".00", "") 

403 ax.text(x, y, d_str, color=edgecolor, 

404 fontsize=self.fontsize, horizontalalignment='center', 

405 weight='bold', alpha=max(0.5, d_abs), 

406 withdash=False) 

407 elif method == 'pie': 

408 S = 360 * d_abs 

409 patch = [ 

410 Wedge((x, y), 1 * self.shrink / 2., -90, S - 90), 

411 Wedge((x, y), 1 * self.shrink / 2., S - 90, 360 - 90), 

412 ] 

413 # patch[0].set_facecolor(cmap(d_abs)) 

414 # patch[1].set_facecolor('white') 

415 colors.append(datum) 

416 colors.append(0.5) 

417 if self.edgecolor: 

418 patch[0].set_edgecolor(self.edgecolor) 

419 patch[1].set_edgecolor(self.edgecolor) 

420 

421 # ax.add_artist(patch[0]) 

422 # ax.add_artist(patch[1]) 

423 patches.append(patch[0]) 

424 patches.append(patch[1]) 

425 else: 

426 raise ValueError( 

427 'Method for the symbols is not known. Use e.g, square, circle') 

428 

429 if self.binarise_color: 

430 colors = [1 if color > 0.5 else -1 for color in colors] 

431 

432 if len(patches): 

433 col1 = PatchCollection( 

434 patches, array=numpy.array(colors), cmap=self.cm) 

435 ax.add_collection(col1) 

436 

437 self.collection = col1