Coverage for src/ensae_projects/hackathon/image_knn.py: 92%
144 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-07-20 04:37 +0200
« prev ^ index » next coverage.py v7.1.0, created at 2023-07-20 04:37 +0200
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Builds a knn classifier for image in order to find close images.
5"""
6import os
7import numpy
8from PIL.Image import Image
9from sklearn.neighbors import NearestNeighbors
10from .image_helper import img2gray, enumerate_image_class, read_image, image_zoom
class ImageNearestNeighbors(NearestNeighbors):
    """
    Builds a model on the top of :epkg:`NearestNeighbors`
    in order to find close images.

    :param transform: name of the preprocessing applied to every image,
        ``'gray'`` (calls ``img2gray``) or ``None`` (no preprocessing)
    :param image_size: every image is zoomed to this size so that all
        images produce feature vectors of the same dimension
    :param kwargs: see :epkg:`NearestNeighbors`
    """

    def __init__(self, transform='gray', image_size=(10, 10), **kwargs):
        NearestNeighbors.__init__(self, **kwargs)
        self.image_size = image_size
        self.transform = transform
        # Called only for its validation side effect: an unsupported
        # *transform* raises ValueError as soon as the model is built.
        self._get_transform()

    def _get_transform(self):
        """
        Returns the transform function associated with ``self.transform``.

        :return: function taking an image and returning the image
            preprocessed (if requested) and zoomed to ``self.image_size``
        :raises ValueError: if ``self.transform`` is neither ``'gray'``
            nor ``None``
        """
        if self.transform == "gray":
            pre = img2gray
        elif self.transform is None:
            pre = None
        else:
            raise ValueError(
                "No transform assicated with value '{0}'.".format(self.transform))
        if pre is None:
            return lambda img: image_zoom(img, new_size=self.image_size)
        return lambda img: image_zoom(pre(img), new_size=self.image_size)

    def _folder2matrix(self, folder, fLOG):
        """
        Converts images stored in a folder into a matrix of features.

        :param folder: folder enumerated with ``enumerate_image_class``
        :param fLOG: logging function or None, called every 1000 images
        :return: tuple *(X, names, classes)*, *X* has one row per image,
            *names* are the image paths relative to *folder*
            (with ``/`` as separator), *classes* the associated classes
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, (name, sub) in enumerate(enumerate_image_class(folder, abspath=False)):
            if fLOG is not None and i % 1000 == 0:
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            imgs.append(name.replace("\\", "/"))
            subs.append(sub)
            img = read_image(os.path.join(folder, name))
            trimg = transform(img)
            # every image becomes one flat row of features
            stack.append(numpy.array(trimg).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def _imglist2matrix(self, list_of_images, fLOG):
        """
        Converts a list of images into a matrix of features.

        :param list_of_images: every item is a filename, an *Image*
            or a tuple *(filename or Image, class)*
        :param fLOG: logging function or None, called every 1000 images
        :return: tuple *(X, names, classes)*, *X* has one row per image,
            a name is None when the item is an *Image*,
            a class is None when the item is not a tuple
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, name in enumerate(list_of_images):
            if isinstance(name, tuple):
                name, sub = name
            else:
                sub = None
            if fLOG is not None and i % 1000 == 0:
                # logs *name*: the loaded image is not available yet
                # at this point of the iteration
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            if isinstance(name, Image):
                imgs.append(None)
                img = name
            else:
                imgs.append(name.replace("\\", "/"))
                img = read_image(name)
            subs.append(sub)
            trimg = transform(img)
            # every image becomes one flat row of features
            stack.append(numpy.array(trimg).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def fit(self, X, y=None, fLOG=None):  # pylint: disable=W0221
        """
        Fits the model. *X* can be a folder.

        :param X: matrix, str for a subfolder of images, list of
            *Image*, list of filenames or list of tuples
            *(filename, class)*
        :param y: unused
        :param fLOG: logging function
        :return: self
        :raises FileNotFoundError: if *X* is a string and does not exist
        :raises TypeError: if *X* is a list whose first item is not
            an *Image*, a string or a tuple

        If *X* is a folder, the method relies on function
        @see fct enumerate_image_class. In that case (and also when
        *X* is a list of filenames or tuples), the method
        also creates attributes:

        * ``image_names_``: all image names
        * ``image_classes_``: subfolder the image belongs to
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            X, imgs, subs = self._folder2matrix(X, fLOG)
            self.image_names_ = imgs  # pylint: disable=W0201
            self.image_classes_ = subs  # pylint: disable=W0201

        elif isinstance(X, list):
            # the type of the first item decides how the list is read
            if isinstance(X[0], Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(X[0], str):
                # image names
                X, imgs, subs = self._imglist2matrix(X, fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
                self.image_classes_ = subs  # pylint: disable=W0201
            elif isinstance(X[0], tuple):
                # (image name, class)
                self.image_classes_ = [  # pylint: disable=W0201
                    t[1] for t in X]
                X, imgs, _ = self._imglist2matrix([t[0] for t in X], fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
            else:
                raise TypeError(
                    "X should be a list of PIL.Image not {0}".format(type(X[0])))

        super(ImageNearestNeighbors, self).fit(X, y)
        return self

    def _private_kn(self, method, X, *args, fLOG=None, **kwargs):
        """
        Common private function to handle the same kind of
        inputs in all neighbors functions.

        :param method: name of the :epkg:`NearestNeighbors` method to run
        :param X: inputs, matrix, folder, file, *Image* or list of images
        :param args: additional positional arguments
        :param fLOG: logging function, only used while images are
            converted, it is not given to *method*
        :param kwargs: additional named arguments
        :return: depends on *method*
        :raises FileNotFoundError: if *X* is a string and does not exist
        :raises TypeError: if *X* is a list whose first item is not
            an *Image*, a string or a tuple
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            if os.path.isfile(X):
                # a single file is processed as a list of one filename
                return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)
            X = self._folder2matrix(X, fLOG=fLOG)[0]

        elif isinstance(X, list):
            # the type of the first item decides how the list is read
            if isinstance(X[0], Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(X[0], str):
                # image names
                X = self._imglist2matrix(X, fLOG=fLOG)[0]
            elif isinstance(X[0], tuple):
                # (image name, class), the class is ignored
                X = self._imglist2matrix([t[0] for t in X], fLOG=fLOG)[0]
            else:
                raise TypeError("X should be a list of Image")
        elif isinstance(X, Image):
            # a single image is processed as a list of one image
            return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)

        # Any other value (matrix, None) goes unchanged to the parent
        # implementation, called explicitly to avoid looping back on
        # the overloaded methods of this class.
        method = getattr(NearestNeighbors, method)
        return method(self, X, *args, **kwargs)

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors", X=X, n_neighbors=n_neighbors,
                                return_distance=return_distance, fLOG=fLOG)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity', fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors_graph`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors_graph", X=X, n_neighbors=n_neighbors,
                                mode=mode, fLOG=fLOG)

    def radius_neighbors(self, X=None, radius=None, return_distance=True, fLOG=None):  # pylint: disable=W0221,W0237
        """
        See :epkg:`NearestNeighbors`, method :epkg:`radius_neighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("radius_neighbors", X=X, radius=radius,
                                return_distance=return_distance, fLOG=fLOG)

    def get_image_names(self, indices):
        """
        Returns images names for the given list of indices.

        :param indices: indices can be a single array or a matrix.
        :return: array of names, same shape as *indices*
        :raises RuntimeError: if the model was trained without image names
        """
        if not hasattr(self, "image_names_"):
            raise RuntimeError("No image names were stored during training.")
        new_indices = indices.ravel()
        res = numpy.array([self.image_names_[i] for i in new_indices])
        return res.reshape(indices.shape)

    def get_image_classes(self, indices):
        """
        Returns images classes for the given list of indices.

        :param indices: indices can be a single array or a matrix.
        :return: array of classes, same shape as *indices*
        :raises RuntimeError: if the model was trained without image classes
        """
        if not hasattr(self, "image_classes_"):
            raise RuntimeError("No image classes were stored during training.")
        new_indices = indices.ravel()
        res = numpy.array([self.image_classes_[i] for i in new_indices])
        return res.reshape(indices.shape)

    def plot_neighbors(self, neighbors, distances=None, obs=None, return_figure=False,
                       format_distance="%1.2f", folder_or_images=None):
        """
        Calls :epkg:`plot_gallery_images`
        with information on the neighbors.

        :param neighbors: matrix of indices
        :param distances: distances to display
        :param obs: original image, if not None, will be placed
            on the first row
        :param return_figure: returns ``fig, ax`` instead of ``ax``
        :param format_distance: used to format distances
        :param folder_or_images: image paths may be relative
            to some folder, in that case, they should be relative
            to this folder, it can also be a list of images
        :return: *ax* or *fig, ax* if *return_figure* is True
        """
        from mlinsights.plotting import plot_gallery_images
        names = self.get_image_names(neighbors)
        if hasattr(self, "image_classes_"):
            subs = self.get_image_classes(neighbors)
        else:
            subs = numpy.array([["" for i in range(names.shape[1])]
                                for j in range(names.shape[0])])

        # one label per neighbor: its class followed by the distance
        # if available, by its index otherwise
        labels = []
        if distances is not None:
            for i in range(names.shape[0]):
                for j in range(names.shape[1]):
                    labels.append("{0} d={1}".format(
                        subs[i, j], format_distance % distances[i, j]))
        else:
            for i in range(names.shape[0]):
                for j in range(names.shape[1]):
                    # %-formatting instead of ``+``: the class is None
                    # when the model was trained on a list of filenames
                    labels.append("%s i=%d" % (subs[i, j], neighbors[i, j]))
        subs = numpy.array(labels).reshape(subs.shape)

        if obs is not None:
            if isinstance(obs, str):
                obs = read_image(obs)
            # extra first row: the observation followed by placeholders
            row = numpy.array([object() for i in range(names.shape[1])])
            row[0] = obs
            names = numpy.vstack([row, names])
            text = numpy.array(["" for i in range(names.shape[1])])
            text[0] = "-"
            subs = numpy.vstack([text, subs])

        # only a folder is forwarded, a list of images is not
        fi = None if isinstance(folder_or_images, list) else folder_or_images
        return plot_gallery_images(names, subs, return_figure=return_figure,
                                   folder_image=fi)