Coverage for src/ensae_projects/hackathon/image_knn.py: 92%

144 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-07-20 04:37 +0200

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Builds a knn classifier for images in order to find close images. 

5""" 

6import os 

7import numpy 

8from PIL.Image import Image 

9from sklearn.neighbors import NearestNeighbors 

10from .image_helper import img2gray, enumerate_image_class, read_image, image_zoom 

11 

12 

class ImageNearestNeighbors(NearestNeighbors):
    """
    Builds a model on the top of :epkg:`NearestNeighbors`
    in order to find close images.

    :param transform: transform applied to every image before it is
        zoomed: ``'gray'`` (uses ``img2gray``), None (no transform)
        or any callable taking and returning an image
    :param image_size: every image is zoomed to this size so that
        all feature vectors share the same dimension
    :param kwargs: see :epkg:`NearestNeighbors`

    When trained on a folder, a list of file names or a list of tuples
    ``(file name, class)``, :meth:`fit` also stores ``image_names_``
    and ``image_classes_``.
    """

    def __init__(self, transform='gray', image_size=(10, 10), **kwargs):
        NearestNeighbors.__init__(self, **kwargs)
        self.image_size = image_size
        self.transform = transform
        # Fails early (ValueError) if *transform* is not supported.
        self._get_transform()

    def _get_transform(self):
        """
        Returns the function associated with ``self.transform``.
        The returned function applies the transform (if any) and then
        zooms the result to ``self.image_size``.

        :raises ValueError: if ``self.transform`` is not supported
        """
        if self.transform == "gray":
            pre = img2gray
        elif self.transform is None:
            pre = None
        elif callable(self.transform):
            # The class documentation describes *transform* as a function.
            pre = self.transform
        else:
            raise ValueError(
                "No transform associated with value '{0}'.".format(self.transform))
        if pre is None:
            return lambda img: image_zoom(img, new_size=self.image_size)
        return lambda img: image_zoom(pre(img), new_size=self.image_size)

    def _folder2matrix(self, folder, fLOG):
        """
        Converts images stored in a folder into a matrix of features.

        :param folder: folder explored with ``enumerate_image_class``
        :param fLOG: logging function or None
        :return: ``(X, names, classes)``, one row of *X* per image,
            names are relative to *folder* and use ``/`` as separator
        :raises ValueError: if no image was found in *folder*
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, (name, sub) in enumerate(enumerate_image_class(folder, abspath=False)):
            if fLOG is not None and i % 1000 == 0:
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            imgs.append(name.replace("\\", "/"))
            subs.append(sub)
            img = read_image(os.path.join(folder, name))
            stack.append(numpy.array(transform(img)).ravel())
        if not stack:
            # numpy.vstack would fail with an obscure message.
            raise ValueError("No image found in folder '{0}'.".format(folder))
        return numpy.vstack(stack), imgs, subs

    def _imglist2matrix(self, list_of_images, fLOG):
        """
        Converts a list of images into a matrix of features.

        :param list_of_images: list of file names, :epkg:`PIL` images
            or tuples ``(file name or image, class)``
        :param fLOG: logging function or None
        :return: ``(X, names, classes)``; the name is None for an image
            given as a PIL object, the class is None when not given
        :raises ValueError: if the list is empty
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, item in enumerate(list_of_images):
            if isinstance(item, tuple):
                name, sub = item
            else:
                name, sub = item, None
            if fLOG is not None and i % 1000 == 0:
                # Logs the input itself: the image is not loaded yet.
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            if isinstance(name, Image):
                imgs.append(None)
                img = name
            else:
                imgs.append(name.replace("\\", "/"))
                img = read_image(name)
            subs.append(sub)
            stack.append(numpy.array(transform(img)).ravel())
        if not stack:
            raise ValueError("The list of images is empty.")
        return numpy.vstack(stack), imgs, subs

    def fit(self, X, y=None, fLOG=None):  # pylint: disable=W0221
        """
        Fits the model. *X* can be a folder.

        :param X: matrix, folder name, or non-empty list of
            :epkg:`PIL` images, file names or tuples ``(file name, class)``
        :param y: unused
        :param fLOG: logging function
        :return: self

        If *X* is a folder, the method relies on function
        ``enumerate_image_class``. If *X* is a folder, a list of file
        names or a list of tuples, the method also creates attributes:

        * ``image_names_``: all image names
        * ``image_classes_``: subfolder (or class) the image belongs to
        """
        # Metadata from a previous fit must not survive a refit
        # on data which does not provide it (e.g. PIL images).
        for attr in ("image_names_", "image_classes_"):
            if hasattr(self, attr):
                delattr(self, attr)

        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            X, imgs, subs = self._folder2matrix(X, fLOG)
            self.image_names_ = imgs  # pylint: disable=W0201
            self.image_classes_ = subs  # pylint: disable=W0201

        elif isinstance(X, list):
            if not X:
                raise ValueError("X cannot be an empty list.")
            first = X[0]
            if isinstance(first, Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(first, str):
                # image names
                X, imgs, subs = self._imglist2matrix(X, fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
                self.image_classes_ = subs  # pylint: disable=W0201
            elif isinstance(first, tuple):
                classes = [t[1] for t in X]
                X, imgs, _ = self._imglist2matrix([t[0] for t in X], fLOG)
                self.image_classes_ = classes  # pylint: disable=W0201
                self.image_names_ = imgs  # pylint: disable=W0201
            else:
                raise TypeError(
                    "X should be a list of PIL.Image not {0}".format(type(first)))

        super().fit(X, y)
        return self

    def _private_kn(self, method, X, *args, fLOG=None, **kwargs):
        """
        Common private function to handle the same kind of
        inputs in all neighbors functions.

        :param method: name of the :epkg:`NearestNeighbors` method to run
        :param X: inputs: None, matrix, folder, image file,
            :epkg:`PIL` image or non-empty list of images
        :param args: additional positional arguments
        :param fLOG: logging function
        :param kwargs: additional named arguments
        :return: depends on *method*
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            if os.path.isfile(X):
                # A single image file is handled as a one-element list.
                return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)
            X = self._folder2matrix(X, fLOG=fLOG)[0]

        elif isinstance(X, list):
            if not X:
                raise ValueError("X cannot be an empty list.")
            first = X[0]
            if isinstance(first, Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(first, str):
                # image names
                X = self._imglist2matrix(X, fLOG)[0]
            elif isinstance(first, tuple):
                # (image name, class): the class is not needed here
                X = self._imglist2matrix([t[0] for t in X], fLOG)[0]
            else:
                raise TypeError("X should be a list of Image")

        elif isinstance(X, Image):
            return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)

        # Calls the parent implementation explicitly: the overridden
        # methods of this class would otherwise recurse.
        parent_method = getattr(NearestNeighbors, method)
        return parent_method(self, X, *args, **kwargs)

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors", X=X, n_neighbors=n_neighbors,
                                return_distance=return_distance, fLOG=fLOG)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity', fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors_graph`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors_graph", X=X, n_neighbors=n_neighbors,
                                mode=mode, fLOG=fLOG)

    def radius_neighbors(self, X=None, radius=None, return_distance=True, fLOG=None):  # pylint: disable=W0221,W0237
        """
        See :epkg:`NearestNeighbors`, method :epkg:`radius_neighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("radius_neighbors", X=X, radius=radius,
                                return_distance=return_distance, fLOG=fLOG)

    @staticmethod
    def _take(values, indices):
        """
        Maps every index of *indices* (any array-like of integers)
        to ``values[index]`` and returns an array of the same shape.
        """
        indices = numpy.asarray(indices)
        res = numpy.array([values[i] for i in indices.ravel()])
        return res.reshape(indices.shape)

    def get_image_names(self, indices):
        """
        Returns images names for the given list of indices.

        :param indices: indices, a single array or a matrix
        :return: array with the same shape
        :raises RuntimeError: if no names were stored by :meth:`fit`
        """
        if not hasattr(self, "image_names_"):
            raise RuntimeError("No image names were stored during training.")
        return self._take(self.image_names_, indices)

    def get_image_classes(self, indices):
        """
        Returns images classes for the given list of indices.

        :param indices: indices, a single array or a matrix
        :return: array with the same shape
        :raises RuntimeError: if no classes were stored by :meth:`fit`
        """
        if not hasattr(self, "image_classes_"):
            raise RuntimeError("No image classes were stored during training.")
        return self._take(self.image_classes_, indices)

    def plot_neighbors(self, neighbors, distances=None, obs=None, return_figure=False,
                       format_distance="%1.2f", folder_or_images=None):
        """
        Calls :epkg:`plot_gallery_images`
        with information on the neighbors.

        :param neighbors: matrix of indices
        :param distances: distances to display
        :param obs: original image, if not None, will be placed
            on the first row
        :param return_figure: returns ``fig, ax`` instead of ``ax``
        :param format_distance: used to format distances
        :param folder_or_images: image paths may be relative
            to some folder, in that case, they should be relative
            to this folder, it can also be a list of images
        :return: *ax* or *fig, ax* if *return_figure* is True
        """
        from mlinsights.plotting import plot_gallery_images
        neighbors = numpy.asarray(neighbors)
        if distances is not None:
            distances = numpy.asarray(distances)
        names = self.get_image_names(neighbors)
        if hasattr(self, "image_classes_"):
            classes = self.get_image_classes(neighbors)
        else:
            classes = None

        labels = []
        for i in range(names.shape[0]):
            for j in range(names.shape[1]):
                cls = "" if classes is None else classes[i, j]
                # Classes are None when the model was trained on plain
                # file names: show them as empty strings.
                cls = "" if cls is None else str(cls)
                if distances is None:
                    labels.append(cls + " i=%d" % neighbors[i, j])
                else:
                    labels.append("{0} d={1}".format(
                        cls, format_distance % distances[i, j]))
        subs = numpy.array(labels).reshape(names.shape)

        if obs is not None:
            if isinstance(obs, str):
                obs = read_image(obs)
            row = numpy.array([object() for i in range(names.shape[1])])
            row[0] = obs
            names = numpy.vstack([row, names])
            text = numpy.array(["" for i in range(names.shape[1])])
            text[0] = "-"
            subs = numpy.vstack([text, subs])

        # NOTE(review): a list given as *folder_or_images* is only used to
        # disable *folder_image*; its images are not forwarded - confirm intent.
        fi = None if isinstance(folder_or_images, list) else folder_or_images
        return plot_gallery_images(names, subs, return_figure=return_figure,
                                   folder_image=fi)