Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Extracts objects from an image based on deep learning.
4"""
5from contextlib import redirect_stdout
6import io
7import os
8import numpy
9from PIL import Image
10import skimage
11from skimage.io._plugins.pil_plugin import pil_to_ndarray
12import chainer # pylint: disable=E0401
13import fcn # pylint: disable=E0401
14from .dlbase import DeepLearningImage
class DLImageSegmentation(DeepLearningImage):
    """
    Segments an image.
    Inspired from
    `infer.py <https://github.com/wkentaro/fcn/blob/master/examples/voc/infer.py>`_.
    See notebook :ref:`imagesegmentationrst`.
    """

    def __init__(self, model="FCN8s", n_class=21, gpu=False, class_name=None, fLOG=None):
        """
        @param      model       model name
        @param      n_class     number of classes
        @param      gpu         use gpu
        @param      class_name  class names, defaults to
                                ``fcn.datasets.VOC2012ClassSeg.class_names``
        @param      fLOG        logging function
        @raises                 NotImplementedError if *model* is not a known
                                model name

        List of known models:

        * ``'FCN8s'``: image segmentation
        """
        # Stored before the first call to self.log, which happens before the
        # base class constructor runs (NOTE(review): assumes
        # DeepLearningImage.log relies on self._fLOG — confirm in dlbase).
        self._fLOG = fLOG
        if model != "FCN8s":
            raise NotImplementedError(
                "Unable to interpret '{0}'".format(model))

        self.log(
            "[DLImageSegmentation] download model '{0}'".format(model))
        # The download prints on stdout: capture it and forward it to the
        # logging function instead.
        captured = io.StringIO()
        with redirect_stdout(captured):
            model_file = fcn.models.FCN8s.download()
        self.log('[DLImageSegmentation] {0}'.format(captured.getvalue()))
        self._model_file = model_file

        # 'net' holds the network object; 'model' keeps the model name.
        net = fcn.models.FCN8s(n_class=n_class)
        self.log("[DLImageSegmentation] load_npz '{0}'".format(model_file))
        chainer.serializers.load_npz(  # pylint: disable=E1101
            model_file, net)  # pylint: disable=E1101

        DeepLearningImage.__init__(self, net, gpu=gpu, fLOG=fLOG)
        self._n_class = n_class
        if class_name is None:
            class_name = fcn.datasets.VOC2012ClassSeg.class_names
        self._class_name = class_name
        self.log("[DLImageSegmentation] class_name '{0}'".format(class_name))

        if gpu:
            self.log("[DLImageSegmentation] gpu")
            # NOTE(review): self._gpu is presumably set by the base class
            # constructor — confirm in dlbase.
            chainer.cuda.get_device(self._gpu).use()  # pylint: disable=E1101
            net.to_gpu()
        else:
            self.log("[DLImageSegmentation] cpu")

    @property
    def ModelFile(self):
        """
        Returns the model file name.
        """
        return self._model_file

    @staticmethod
    def _new_size(old_size, new_size):
        """
        Computes a new size.

        @param      old_size    current size, a tuple of two ints
        @param      new_size    new desired size
        @return                 new size
        @raises                 TypeError, ValueError if the sizes cannot
                                be interpreted

        *new_size* can be of:

        * (int, int): this is the new size
        * ('max2', threshold): the size is divided by 2 (integer division)
          until the greater dimension is less than or equal to *threshold*,
          which must be strictly positive; every dimension of the result
          is at least 1
        """
        if not isinstance(new_size, tuple):
            raise TypeError("new_size must be a tuple")
        if not isinstance(old_size, tuple):
            raise TypeError("old_size must be a tuple")
        if len(old_size) != 2:
            raise ValueError("old_size must have two values")
        if len(new_size) != 2:
            raise ValueError("new_size must have two values")
        if isinstance(new_size[0], str):
            if new_size[0] != 'max2':
                raise ValueError(
                    "Unable to interpret '{0}'".format(new_size[0]))
            threshold = new_size[1]
            # A non-positive threshold would loop forever (negative) or
            # produce an empty image (zero).
            if threshold <= 0:
                raise ValueError(
                    "threshold must be strictly positive, got {0}".format(threshold))
            mx = max(old_size)
            p = 1
            while mx > threshold:
                mx //= 2
                p *= 2
            # Clamped so a very thin image never gets a zero dimension.
            return (max(1, old_size[0] // p), max(1, old_size[1] // p))
        if isinstance(new_size[0], int):
            return new_size
        raise TypeError("new_size[0] must be an int")

    def _load_image(self, img, resize=None):
        """
        Loads an image as a :epkg:`numpy:array`.

        @param      img         image, a filename or a :epkg:`numpy:array`
        @param      resize      resize the image before predicting,
                                see @see me _new_size
        @return                 :epkg:`numpy:array`
        @raises                 FileNotFoundError if *img* is a filename
                                which does not exist,
                                NotImplementedError for any other type
        """
        if isinstance(img, str):
            if not os.path.exists(img):
                raise FileNotFoundError(img)
            if resize is None:
                return skimage.io.imread(img)
            # The context manager releases the file handle PIL keeps open;
            # resize returns a new image which does not depend on it.
            with Image.open(img) as pilimg:
                size = DLImageSegmentation._new_size(pilimg.size, resize)
                return pil_to_ndarray(pilimg.resize(size))
        if isinstance(img, numpy.ndarray):
            if resize is None:
                return img
            # The original author reports skimage.transform.resize(img, resize)
            # did not work here, PIL is used instead.
            pilimg = Image.fromarray(img).convert('RGB')
            size = DLImageSegmentation._new_size(pilimg.size, resize)
            # Converts the resized image (the original converted the
            # unresized one, so *resize* had no effect for arrays).
            return pil_to_ndarray(pilimg.resize(size))
        raise NotImplementedError(
            "Not implemented for type '{0}'".format(type(img)))

    def _preprocess(self, feat, preprocess=True):
        """
        Preprocesses the image before prediction.

        @param      feat        image (output of @see me _load_image)
        @param      preprocess  applies some preprocessing or not
        @return                 preprocessed image, *feat* unchanged
                                if *preprocess* is False
        """
        if not preprocess:
            return feat
        transformed, = fcn.datasets.transform_lsvrc2012_vgg16(  # pylint: disable=W0632
            (feat,))  # pylint: disable=W0632
        # Prepends a leading axis: the model receives a batch of one image
        # (predict takes element [0] of the result).
        return transformed[numpy.newaxis, :, :, :]  # pylint: disable=E0401,E1126

    def predict(self, img, resize=None):
        """
        Applies the segmentation model on an image.

        @param      img         image, a filename or a :epkg:`numpy:array`
        @param      resize      resize the image before predicting,
                                see @see me _new_size
        @return                 (image, prediction): the image as given to
                                the preprocessing step (after resizing) and
                                the predicted label of every pixel
        """
        feat = self._load_image(img, resize=resize)
        batch = self._preprocess(feat, preprocess=True)
        if self._gpu:
            batch = chainer.cuda.to_gpu(batch)  # pylint: disable=E1101

        with chainer.no_backprop_mode():  # pylint: disable=E1101
            var = chainer.Variable(batch)  # pylint: disable=E1101
            with chainer.using_config('train', False):  # pylint: disable=E1101
                self._model(var)
                # The scores are read from the model's ``score`` attribute
                # after the forward pass; argmax over axis 1 picks the best
                # class (presumably score is (batch, n_class, H, W)).
                lbl_pred = chainer.functions.argmax(  # pylint: disable=E1101
                    self._model.score, axis=1)[0]
                lbl_pred = chainer.cuda.to_cpu(  # pylint: disable=E1101
                    lbl_pred.data)  # pylint: disable=E1101

        return feat, lbl_pred

    def plot(self, img, pred):
        """
        Displays the segmentation.

        @param      img         initial image, a filename or a
                                :epkg:`numpy:array`; when *resize* was used
                                in @see me predict, give the image it
                                returned so that sizes match (not checked here)
        @param      pred        predicted labels, second element returned
                                by @see me predict
        @return                 new image
        """
        img = self._load_image(img)
        viz = fcn.utils.visualize_segmentation(
            lbl_pred=pred, img=img, n_class=self._n_class, label_names=self._class_name)
        return viz