Coverage for src/code_beatrix/ai/image_segmentation.py: 82%

101 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-04-29 13:45 +0200

1""" 

2@file 

3@brief Extracts objects from an image based on deep learning. 

4""" 

5from contextlib import redirect_stdout 

6import io 

7import os 

8import numpy 

9from PIL import Image 

10import skimage 

11from skimage.io._plugins.pil_plugin import pil_to_ndarray 

12import chainer # pylint: disable=E0401 

13import fcn # pylint: disable=E0401 

14from .dlbase import DeepLearningImage 

15 

16 

class DLImageSegmentation(DeepLearningImage):
    """
    Segments an image.
    Inspired from
    `infer.py <https://github.com/wkentaro/fcn/blob/master/examples/voc/infer.py>`_.
    See notebook :ref:`imagesegmentationrst`.
    """

    def __init__(self, model="FCN8s", n_class=21, gpu=False, class_name=None, fLOG=None):
        """
        @param      model       model name
        @param      n_class     number of classes
        @param      gpu         use gpu
        @param      class_name  class names, if None, it defaults to
                                ``fcn.datasets.VOC2012ClassSeg.class_names``
        @param      fLOG        logging function

        List of known models:

        * ``'FCN8s'``: image segmentation

        Any other model name raises *NotImplementedError*.
        """
        # Assigned before the base constructor runs because ``self.log`` is
        # called right below (presumably ``log`` relies on ``_fLOG`` -- it is
        # defined outside this class, to be confirmed).
        self._fLOG = fLOG
        if model != "FCN8s":
            raise NotImplementedError(
                "Unable to interpret '{0}'".format(model))

        self.log(
            "[DLImageSegmentation] download model '{0}'".format(model))
        # Whatever is printed on stdout during the download is captured
        # and forwarded to the logging function instead.
        captured = io.StringIO()
        with redirect_stdout(captured):
            model_file = fcn.models.FCN8s.download()
        self.log('[DLImageSegmentation] {0}'.format(captured.getvalue()))
        self._model_file = model_file

        network = fcn.models.FCN8s(n_class=n_class)
        self.log("[DLImageSegmentation] load_npz '{0}'".format(model_file))
        chainer.serializers.load_npz(  # pylint: disable=E1101
            model_file, network)  # pylint: disable=E1101

        DeepLearningImage.__init__(self, network, gpu=gpu, fLOG=fLOG)
        self._n_class = n_class
        if class_name is None:
            class_name = fcn.datasets.VOC2012ClassSeg.class_names
        self._class_name = class_name
        self.log("[DLImageSegmentation] class_name '{0}'".format(class_name))

        if gpu:
            self.log("[DLImageSegmentation] gpu")
            # ``self._gpu`` is not set in this class, presumably the base
            # constructor called above defines it.
            chainer.cuda.get_device(self._gpu).use()  # pylint: disable=E1101
            network.to_gpu()
        else:
            self.log("[DLImageSegmentation] cpu")

    @property
    def ModelFile(self):
        """
        Returns the model file name (the value returned by the download
        step in the constructor).
        """
        return self._model_file

    @staticmethod
    def _new_size(old_size, new_size):
        """
        Computes a new size.

        @param      old_size    current size, tuple of two values
        @param      new_size    new desired size, tuple of two values
        @return                 new size

        *new_size* can be of:

        * (int, int): this is the new size, returned unchanged
        * ('max2', int): this size is divided by 2 until the greater dimension
          is below a threshold (or equal to it), the threshold must be
          positive

        Raises *TypeError* if one of the sizes is not a tuple or if
        ``new_size[0]`` is neither a string nor an integer,
        *ValueError* if a tuple does not hold two values, if the string
        is not ``'max2'`` or if the threshold is not positive.
        """
        if not isinstance(new_size, tuple):
            raise TypeError("new_size must be a tuple")
        if not isinstance(old_size, tuple):
            raise TypeError("old_size must be a tuple")
        if len(old_size) != 2:
            raise ValueError("old_size must have two values")
        if len(new_size) != 2:
            raise ValueError("new_size must have two values")

        mode = new_size[0]
        if isinstance(mode, str):
            if mode != 'max2':
                raise ValueError(
                    "Unable to interpret '{0}'".format(mode))
            threshold = new_size[1]
            if threshold <= 0:
                # A negative threshold can never be reached by successive
                # halvings (the loop below would not terminate) and a null
                # one produces an empty image.
                raise ValueError(
                    "new_size[1] must be positive, not {0}".format(threshold))
            largest = max(old_size)
            divisor = 1
            while largest > threshold:
                largest //= 2
                divisor *= 2
            return (old_size[0] // divisor, old_size[1] // divisor)
        if isinstance(mode, int):
            return new_size
        raise TypeError("new_size[0] must be an int")

    def _load_image(self, img, resize=None):
        """
        Loads an image as a :epkg:`numpy:array`.

        @param      img         image, a filename or a :epkg:`numpy:array`
        @param      resize      resize the image before predicting,
                                see @see me _new_size
        @return                 :epkg:`numpy:array`

        Raises *FileNotFoundError* if *img* is a filename which does not
        exist, *NotImplementedError* if *img* is neither a string nor an
        array. Without *resize*, an array is returned as is.
        """
        if isinstance(img, str):
            if not os.path.exists(img):
                raise FileNotFoundError(img)
            if resize is None:
                return skimage.io.imread(img)
            # The context manager releases the file handle once the
            # resized copy is built.
            with Image.open(img) as pilimg:
                size = DLImageSegmentation._new_size(pilimg.size, resize)
                resized = pilimg.resize(size)
            return pil_to_ndarray(resized)

        if isinstance(img, numpy.ndarray):
            if resize is None:
                return img
            # skimage.transform.resize was reported as not working here,
            # the resizing goes through PIL instead.
            pilimg = Image.fromarray(img).convert('RGB')
            # Same interpretation of *resize* as for a filename, so that
            # ('max2', int) is supported for arrays too.
            size = DLImageSegmentation._new_size(pilimg.size, resize)
            # The resized image must be converted, not the original one.
            return pil_to_ndarray(pilimg.resize(size))

        raise NotImplementedError(
            "Not implemented for type '{0}'".format(type(img)))

    def _preprocess(self, feat, preprocess=True):
        """
        Preprocesses the image before prediction.

        @param      feat        image (output of @see me _load_image)
        @param      preprocess  applies some preprocessing or not
        @return                 preprocessed image

        When *preprocess* is True, the image goes through
        ``fcn.datasets.transform_lsvrc2012_vgg16`` and a new first axis is
        added to the result, otherwise *feat* is returned unchanged.
        """
        if not preprocess:
            return feat
        # The transform takes a tuple and returns a sequence with one item.
        transformed, = fcn.datasets.transform_lsvrc2012_vgg16(  # pylint: disable=W0632
            (feat,))  # pylint: disable=W0632
        return transformed[numpy.newaxis, :, :,
                           :]  # pylint: disable=E0401,E1126

    def predict(self, img, resize=None):  # pylint: disable=W0237
        """
        Applies the model on an image.

        @param      img         image, a filename or a :epkg:`numpy:array`
        @param      resize      resize the image before predicting,
                                see @see me _new_size
        @return                 (image, prediction), the image is the loaded
                                (and maybe resized) image before any
                                preprocessing, the prediction is the argmax
                                of the model score along axis 1 for the
                                first item, moved back to the CPU
        """
        feat = self._load_image(img, resize=resize)
        batch = self._preprocess(feat, preprocess=True)
        # ``self._gpu`` and ``self._model`` are not set in this class,
        # presumably the base constructor defines them.
        if self._gpu:
            batch = chainer.cuda.to_gpu(batch)  # pylint: disable=E1101

        with chainer.no_backprop_mode():  # pylint: disable=E1101
            batch = chainer.Variable(batch)  # pylint: disable=E1101
            with chainer.using_config('train', False):  # pylint: disable=E1101
                # The call fills ``self._model.score`` which is read below.
                self._model(batch)
                lbl_pred = chainer.functions.argmax(  # pylint: disable=E1101
                    self._model.score, axis=1)[0]
                lbl_pred = chainer.cuda.to_cpu(  # pylint: disable=E1101
                    lbl_pred.data)  # pylint: disable=E1101

        return feat, lbl_pred

    def plot(self, img, pred):
        """
        Displays the segmentation.

        @param      img         initial image, a filename or
                                a :epkg:`numpy:array`
        @param      pred        prediction, second output of
                                @see me predict
        @return                 new image built by
                                ``fcn.utils.visualize_segmentation``
        """
        img = self._load_image(img)
        viz = fcn.utils.visualize_segmentation(
            lbl_pred=pred, img=img, n_class=self._n_class, label_names=self._class_name)
        return viz
202 return viz