Coverage for onnxcustom/training/data_loader.py: 97%

148 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1""" 

2@file 

3@brief Manipulate data for training. 

4""" 

5import numpy 

6from ..utils.onnxruntime_helper import ( 

7 get_ort_device, numpy_to_ort_value, ort_device_to_string) 

8 

9 

class OrtDataLoader:
    """
    Draws consecutive random observations from a dataset
    by batch. It iterates over the datasets by drawing
    *batch_size* consecutive observations.

    :param X: features, the first dimension is the number of observations
    :param y: labels, a 1D array becomes a column vector, an array of
        higher rank is reshaped into a matrix with one row per
        observation (trailing dimensions are flattened into columns)
    :param sample_weight: weight or None
    :param batch_size: batch size (consecutive observations)
    :param device: :epkg:`C_OrtDevice` or a string such as `'cpu'`
    :param random_iter: random iteration, if False, batches are drawn
        sequentially and the last one is shifted back to stay inside
        the dataset

    See example :ref:`l-orttraining-nn-gpu`.
    """

    def __init__(self, X, y, sample_weight=None,
                 batch_size=20, device='cpu', random_iter=True):
        # Keep one row per observation even for multi-column labels,
        # reshaping to (-1, 1) would mix rows and columns.
        y_np = self._as_label_matrix(y)
        if X.shape[0] != y_np.shape[0]:
            raise ValueError(  # pragma: no cover
                f"Shape mismatch X.shape={X.shape!r}, y.shape={y_np.shape!r}.")

        self.batch_size = batch_size
        self.device = get_ort_device(device)
        self.random_iter = random_iter

        self.X_np = numpy.ascontiguousarray(X)
        self.y_np = y_np

        self.X_ort = numpy_to_ort_value(self.X_np, self.device)
        self.y_ort = numpy_to_ort_value(self.y_np, self.device)

        # (shape, dtype) of X, y and, if any, the weights;
        # read by __len__ and iter_bind.
        self.desc = [(self.X_np.shape, self.X_np.dtype),
                     (self.y_np.shape, self.y_np.dtype)]

        if sample_weight is None:
            self.w_np = None
            self.w_ort = None
        else:
            if X.shape[0] != sample_weight.shape[0]:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch X.shape=%r, sample_weight.shape=%r."
                    "" % (X.shape, sample_weight.shape))
            self.w_np = numpy.ascontiguousarray(
                sample_weight).reshape((-1, ))
            self.w_ort = numpy_to_ort_value(self.w_np, self.device)
            self.desc.append((self.w_np.shape, self.w_np.dtype))

    @staticmethod
    def _as_label_matrix(y):
        """
        Returns *y* as a C-contiguous 2D numpy array with one row per
        observation. A 1D array becomes a column vector, the trailing
        dimensions of a higher rank array are flattened into columns.
        """
        y_np = numpy.ascontiguousarray(y)
        if y_np.ndim == 1:
            return y_np.reshape((-1, 1))
        n_cols = int(numpy.prod(y_np.shape[1:]))
        return y_np.reshape((y_np.shape[0], n_cols))

    def __getstate__(self):
        """
        Returns the picklable state: numpy arrays, description and
        settings. The device is stored as a string, the OrtValue
        attributes are left out and rebuilt by __setstate__.
        """
        state = {att: getattr(self, att)
                 for att in ['X_np', 'y_np', 'w_np',
                             'desc', 'batch_size', 'random_iter']}
        state['device'] = ort_device_to_string(self.device)
        return state

    def __setstate__(self, state):
        """
        Restores the attributes saved by __getstate__, converts the
        device string back into a device and rebuilds the OrtValue
        attributes from the numpy arrays.
        """
        for att, v in state.items():
            setattr(self, att, v)
        self.device = get_ort_device(self.device)
        self.X_ort = numpy_to_ort_value(self.X_np, self.device)
        self.y_ort = numpy_to_ort_value(self.y_np, self.device)
        if self.w_np is None:
            self.w_ort = None
        else:
            self.w_ort = numpy_to_ort_value(
                self.w_np, self.device)
        return self

    def __repr__(self):
        "usual"
        return "%s(..., ..., batch_size=%r, device=%r)" % (
            self.__class__.__name__, self.batch_size,
            ort_device_to_string(self.device))

    def __len__(self):
        "Returns the number of observations (first dimension of X)."
        return self.desc[0][0][0]

    def _next_iter(self, previous):
        """
        Returns the index of the first observation of the next batch.

        :param previous: index returned by the previous call,
            -1 for the first batch
        :return: a start index in ``[0, len(self) - batch_size]``
        """
        last_start = len(self) - self.batch_size
        if self.random_iter:
            # randint excludes its upper bound, +1 makes the last valid
            # start (hence the last observation) reachable.
            return numpy.random.randint(0, last_start + 1)
        if previous == -1:
            return 0
        # The last batch is shifted back so that it ends on the last
        # observation instead of running past it.
        return min(previous + self.batch_size, last_start)

    def _is_single_batch(self):
        """
        Tells if the whole dataset is returned as one batch, which is
        the case when *batch_size* is not positive or not smaller than
        the number of observations.
        """
        return len(self) - self.batch_size <= 0 or self.batch_size <= 0

    def _iter_starts(self):
        """
        Yields the start index of every batch until at least
        ``len(self)`` observations were drawn. It is lazy, so that
        indices are drawn one by one while the caller iterates.
        """
        n_obs = len(self)
        drawn = 0
        start = -1
        while drawn < n_obs:
            start = self._next_iter(start)
            drawn += self.batch_size
            yield start

    def iter_numpy(self):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        The function yields tuples of numpy arrays, ``(X, y)`` or
        ``(X, y, sample_weight)``. Every batch is a slice (a view,
        not a copy) of the stored arrays. The stored arrays are
        yielded once if the whole dataset fits in one batch.

        :raises RuntimeError: if the device is not CPU
        """
        if self.device.device_type() != self.device.cpu():
            raise RuntimeError(  # pragma: no cover
                "Only CPU device is allowed if numpy arrays are requested "
                "not %r." % ort_device_to_string(self.device))
        arrays = self.data_np
        if self._is_single_batch():
            yield arrays
            return
        for i in self._iter_starts():
            yield tuple(a[i:i + self.batch_size] for a in arrays)

    def iter_ortvalue(self):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        This iterator is slow as it converts the data of every
        batch with numpy_to_ort_value. The function yields tuples of
        :epkg:`C_OrtValue`, ``(X, y)`` or ``(X, y, sample_weight)``.
        The stored OrtValue are yielded once if the whole dataset fits
        in one batch.
        """
        if self._is_single_batch():
            yield self.data_ort
            return
        arrays = self.data_np
        for i in self._iter_starts():
            yield tuple(
                numpy_to_ort_value(a[i:i + self.batch_size], self.device)
                for a in arrays)

    def iter_bind(self, bind, names):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        Modifies a bind structure: every batch is bound as a pointer
        offset into the data held by the stored OrtValue.

        :param bind: binding structure exposing
            ``bind_input(name, device, dtype, shape, data_ptr)``
        :param names: (feature name, label name, learning rate) or
            (feature name, label name, sample_weight, learning rate)
        :return: yields the number of observations bound for every batch
        :raises NotImplementedError: if *names* does not hold 3 or 4 names
        """
        if len(names) not in (3, 4):
            # names is wrapped in a tuple, otherwise % would unpack a
            # tuple of names and raise TypeError instead.
            raise NotImplementedError(
                "The dataloader expects three (feature name, label name, "
                "learning rate) or (feature name, label name, sample_weight, "
                "learning rate), not %r." % (names, ))

        n_col_x = self.desc[0][0][1]
        n_col_y = self.desc[1][0][1]
        size_x = self.desc[0][1].itemsize
        size_y = self.desc[1][1].itemsize
        size_w = None if len(self.desc) <= 2 else self.desc[2][1].itemsize

        def local_bind(bind, offset, n):
            # This function assumes the data is contiguous:
            # the byte offset is offset * columns * itemsize.
            shape_X = (n, n_col_x)
            shape_y = (n, n_col_y)

            try:
                bind.bind_input(
                    names[0], self.device, self.desc[0][1], shape_X,
                    self.X_ort.data_ptr() + offset * n_col_x * size_x)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to bind data input (X) %r, device=%r desc=%r "
                    "data_ptr=%r offset=%r n_col_x=%r size_x=%r "
                    "type(bind)=%r" % (
                        names[0], self.device, self.desc[0][1],
                        self.X_ort.data_ptr(), offset, n_col_x, size_x,
                        type(bind))) from e
            try:
                bind.bind_input(
                    names[1], self.device, self.desc[1][1], shape_y,
                    self.y_ort.data_ptr() + offset * n_col_y * size_y)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to bind data input (y) %r, device=%r desc=%r "
                    "data_ptr=%r offset=%r n_col_y=%r size_y=%r "
                    "type(bind)=%r" % (
                        names[1], self.device, self.desc[1][1],
                        self.y_ort.data_ptr(), offset, n_col_y, size_y,
                        type(bind))) from e

        def local_bindw(bind, offset, n):
            # This function assumes the data is contiguous.
            # names[2] is used as the sample_weight input name.
            shape_w = (n, )

            bind.bind_input(
                names[2], self.device, self.desc[2][1], shape_w,
                self.w_ort.data_ptr() + offset * size_w)

        if self._is_single_batch():
            batches = [(0, self.desc[0][0][0])]
        else:
            batches = ((i, self.batch_size) for i in self._iter_starts())
        for offset, n in batches:
            local_bind(bind, offset, n)
            if self.w_ort is not None:
                local_bindw(bind, offset, n)
            yield n

    @property
    def data_np(self):
        "Returns a tuple of the datasets in numpy."
        if self.w_np is None:
            return self.X_np, self.y_np
        return self.X_np, self.y_np, self.w_np

    @property
    def data_ort(self):
        "Returns a tuple of the datasets in onnxruntime C_OrtValue."
        if self.w_ort is None:
            return self.X_ort, self.y_ort
        return self.X_ort, self.y_ort, self.w_ort