Coverage for mlprodict/plotting/plotting_benchmark.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Useful plots. 

4""" 

5import numpy 

6 

7 

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel=None, **kwargs):
    """
    Creates a heatmap from a numpy array and two lists of labels.
    See @see fn plot_benchmark_metrics for an example.

    @param      data        a 2D numpy array of shape (N, M).
    @param      row_labels  a list or array of length N with the labels for the rows.
    @param      col_labels  a list or array of length M with the labels for the columns.
    @param      ax          a `matplotlib.axes.Axes` instance to which the heatmap is plotted,
                            if None, the current axes are used (``plt.gca()``). Optional.
    @param      cbar_kw     a dictionary with arguments to `matplotlib.Figure.colorbar
                            <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
                            Optional.
    @param      cbarlabel   the label for the colorbar, no label is set if None. Optional.
    @param      kwargs      all other arguments are forwarded to `imshow
                            <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_
    @return                 ax, image, color bar
    """
    import matplotlib.pyplot as plt  # delayed

    # Only an explicit None means "no axes supplied"; a truthiness test
    # would also replace any object that merely evaluates to False.
    if ax is None:
        ax = plt.gca()  # pragma: no cover

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # Create the colorbar attached to the same axes.
    if cbar_kw is None:
        cbar_kw = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    if cbarlabel is not None:
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per cell, labelled with the respective list entries.
    ax.set_xticks(numpy.arange(n_cols))
    ax.set_yticks(numpy.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off ...
    for spine in ax.spines.values():
        spine.set_visible(False)

    # ... and draw a white grid on minor ticks placed at the cell borders.
    ax.set_xticks(numpy.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(numpy.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    return ax, im, cbar

66 

67 

def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "black"),
                     threshold=None, **textkw):
    """
    Annotates a heatmap.
    See @see fn plot_benchmark_metrics for an example.

    @param      im          the *AxesImage* to be labeled.
    @param      data        data used to annotate, a 2D list or numpy array.
                            If it is neither, the image's data
                            (``im.get_array()``) is used. Optional.
    @param      valfmt      the format of the annotations inside the heatmap. This should either
                            use the string format method, e.g. `"$ {x:.2f}"`, or be a
                            `matplotlib.ticker.Formatter
                            <https://matplotlib.org/api/ticker_api.html>`_
                            (any callable invoked as ``valfmt(value, None)``). Optional.
    @param      textcolors  a list or array of two color specifications. The first is used for
                            values below a threshold, the second for those above. Optional.
    @param      threshold   value in data units according to which the colors from textcolors are
                            applied. If None (the default), half of the normalized
                            maximum of the data is used as separation. Optional.
    @param      textkw      all other arguments are forwarded to each call to `text` used to create
                            the text labels.
    @return                 annotated objects
    """
    if isinstance(data, list):
        # Lists are accepted but everything below needs ``.max()``,
        # ``.shape`` and 2D indexing, which only arrays provide.
        data = numpy.asarray(data)
    elif not isinstance(data, numpy.ndarray):
        data = im.get_array()

    # The threshold is compared against normalized values.
    if threshold is not None:
        threshold = im.norm(threshold)  # pragma: no cover
    else:
        threshold = im.norm(data.max()) / 2.

    # Centered text by default, caller's keywords take precedence.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        # Import the class explicitly: ``import matplotlib`` alone does not
        # guarantee that the ``matplotlib.ticker`` submodule is loaded.
        from matplotlib.ticker import StrMethodFormatter  # delayed
        valfmt = StrMethodFormatter(valfmt)

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            # Index 1 of textcolors when the normalized value exceeds the threshold.
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts

113 

114 

def plot_benchmark_metrics(metric, xlabel=None, ylabel=None,
                           middle=1., transpose=False, ax=None,
                           cbar_kw=None, cbarlabel=None,
                           valfmt="{x:.2f}x"):
    """
    Plots a heatmap which represents a benchmark.
    See example below.

    @param      metric      dictionary ``{ (x,y): value }``, pairs missing from
                            the dictionary are left at 0 in the plotted matrix
    @param      xlabel      x label
    @param      ylabel      y label
    @param      middle      force the white color to be this value: the range of
                            the logarithmic color scale is widened so that
                            *middle* is its geometric center, None to use
                            the range of the values as is
    @param      transpose   switches *x* and *y*
    @param      ax          axis to borrow
    @param      cbar_kw     a dictionary with arguments to `matplotlib.Figure.colorbar
                            <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
                            Optional.
    @param      cbarlabel   the label for the colorbar. Optional.
    @param      valfmt      format for the annotations
    @return                 ax, colorbar

    .. exref::
        :title: Plot benchmark improvements
        :lid: plot-2d-benchmark-metric

        .. plot::

            import matplotlib.pyplot as plt
            from mlprodict.plotting.plotting_benchmark import plot_benchmark_metrics

            data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2,
                    (10, 10): 100, (100, 1): 100, (100, 10): 1000}

            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_benchmark_metrics(data, ax=ax[0], cbar_kw={'shrink': 0.6})
            plot_benchmark_metrics(data, ax=ax[1], transpose=True,
                                   xlabel='X', ylabel='Y',
                                   cbarlabel="ratio")
            plt.show()
    """
    if transpose:
        # Swap the coordinates and the labels, then plot as usual.
        # valfmt must be forwarded, otherwise a custom annotation format
        # is silently replaced by the default one.
        metric = {(k[1], k[0]): v for k, v in metric.items()}
        return plot_benchmark_metrics(metric, ax=ax, xlabel=ylabel, ylabel=xlabel,
                                      middle=middle, transpose=False,
                                      cbar_kw=cbar_kw, cbarlabel=cbarlabel,
                                      valfmt=valfmt)

    from matplotlib.colors import LogNorm  # delayed

    # Sorted distinct coordinates and their position in the matrix.
    x = numpy.array(sorted({k[0] for k in metric}))
    y = numpy.array(sorted({k[1] for k in metric}))
    rx = {v: i for i, v in enumerate(x)}
    ry = {v: i for i, v in enumerate(y)}

    # Rows follow y, columns follow x.
    zm = numpy.zeros((len(y), len(x)), dtype=numpy.float64)
    for k, v in metric.items():
        zm[ry[k[1]], rx[k[0]]] = v

    xs = [str(_) for _ in x]
    ys = [str(_) for _ in y]
    vmin = min(metric.values())
    vmax = max(metric.values())
    if middle is not None:
        # On a logarithmic scale the center of the colormap (white for
        # "bwr") is sqrt(vmin * vmax). Mirroring each bound around *middle*
        # (middle ** 2 / bound) makes that center equal to *middle*.
        mirrored_min = middle ** 2 / vmin
        mirrored_max = middle ** 2 / vmax
        vmin = min(vmin, mirrored_max)
        vmax = max(vmax, mirrored_min)
    ax, im, cbar = heatmap(zm, ys, xs, ax=ax, cmap="bwr",
                           norm=LogNorm(vmin=vmin, vmax=vmax),
                           cbarlabel=cbarlabel, cbar_kw=cbar_kw)
    annotate_heatmap(im, valfmt=valfmt)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax, cbar