Coverage for onnxcustom/training/optimizers.py: 98%

141 statements  

coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1""" 

2@file 

3@brief Optimizer with :epkg:`onnxruntime-training`. 

4""" 

5import numpy 

6from onnxruntime import ( # pylint: disable=E0611 

7 TrainingParameters, SessionOptions, TrainingSession) 

8from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

9 OrtValue as C_OrtValue, SessionIOBinding as C_IOBinding) 

10from ..utils.onnxruntime_helper import ( 

11 numpy_to_ort_value, device_to_providers) 

12from .data_loader import OrtDataLoader 

13from .excs import ConvergenceError, EvaluationError 

14from ._base_estimator import BaseEstimator 

15 

16 

17class OrtGradientOptimizer(BaseEstimator): 

18 """ 

19 Implements a simple :epkg:`Stochastic Gradient Descent` 

20 with :epkg:`onnxruntime-training`. 

21 

22 :param model_onnx: onnx graph to train 

23 :param weights_to_train: names of initializers to be optimized 

24 :param loss_output_name: name of the loss output 

25 :param max_iter: number of training iterations 

26 :param training_optimizer_name: optimizing algorithm 

27 :param batch_size: batch size (see class *DataLoader*) 

28 :param learning_rate: a name or a learning rate instance or a float, 

29 see module :mod:`onnxcustom.training.sgd_learning_rate` 

30 :param device: device as :epkg:`C_OrtDevice` or a string 

31 representing this device 

32 :param warm_start: when set to True, reuse the solution of the previous 

33 call to fit as initialization, otherwise, just erase the previous 

34 solution. 

35 :param verbose: use :epkg:`tqdm` to display the training progress 

36 :param validation_every: validation with a test set every 

37 *validation_every* iterations 

38 :param saved_gradient: if not None, a filename, 

39 the optimizer saves the gradient into it 

40 :param sample_weight_name: name of the sample weight input 

41 

42 Once initialized, the class creates the attribute 

43 `train_session_` which holds an instance of :ref:`l-ort-training-session`. 

44 

45 See example :ref:`l-orttraining-nn-gpu`. 
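
    A minimal usage sketch (hypothetical names: ``onx`` is an ONNX graph
    whose loss output is named ``'loss'`` and whose trainable initializer
    is ``'coef'``; ``X`` and ``y`` are numpy float32 arrays)::

        from onnxcustom.training.optimizers import OrtGradientOptimizer

        trainer = OrtGradientOptimizer(
            onx, ['coef'], loss_output_name='loss',
            max_iter=20, batch_size=10, learning_rate='SGD')
        trainer.fit(X, y)
        print(trainer.train_losses_)  # one averaged loss per iteration
        print(trainer.get_state())    # trained initializers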

46 """ 

47 

48 def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', 

49 max_iter=100, training_optimizer_name='SGDOptimizer', 

50 batch_size=10, learning_rate='SGD', 

51 device='cpu', warm_start=False, verbose=0, 

52 validation_every=0.1, saved_gradient=None, 

53 sample_weight_name="weight"): 

54 BaseEstimator.__init__(self, model_onnx, learning_rate, device) 

55 self.batch_size = batch_size 

56 self.weights_to_train = weights_to_train 

57 self.loss_output_name = loss_output_name 

58 self.training_optimizer_name = training_optimizer_name 

59 self.verbose = verbose 

60 self.max_iter = max_iter 

61 self.warm_start = warm_start 

62 self.saved_gradient = saved_gradient 

63 self.sample_weight_name = sample_weight_name 

64 if validation_every < 1: 

65 self.validation_every = int(self.max_iter * validation_every) 

66 else: 

67 self.validation_every = validation_every # pragma: no cover 

68 if self.learning_rate.needs_grad: 

69 raise NotImplementedError( 

70 "Any weight update involving past gradient is " 

71 "not implemented (class %r)." 

72 "" % self.learning_rate.__class__.__name__) 

73 

74 def fit(self, X, y, sample_weight=None, X_val=None, y_val=None, 

75 use_numpy=False): 

76 """ 

77 Trains the model. 

78 

79 :param X: features 

80 :param y: expected output 

81 :param sample_weight: sample weight if any 

82 :param X_val: evaluation dataset 

83 :param y_val: evaluation dataset 

84 :param use_numpy: if True, slow iterator using numpy, 

85 otherwise, minimizes copy 

86 :return: self 
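
        A sketch of a call with a validation set (``trainer``, ``X``, ``y``,
        ``X_val`` and ``y_val`` are assumed to exist)::

            trainer.fit(X, y, X_val=X_val, y_val=y_val)
            print(trainer.train_losses_[-1])   # last averaged training loss
            print(trainer.validation_losses_)  # one loss per validation run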

87 """ 

88 input_names = [i.name for i in self.model_onnx.graph.input] 

89 if ((len(input_names) == 2 and sample_weight is not None) or 

90 (len(input_names) == 3 and sample_weight is None)): 

91 raise RuntimeError( # pragma: no cover 

92 "Number of inputs should be 2 if sample_weight is None " 

93 "or 3 if not None but it is %d." % len(input_names)) 

94 self.train_session_ = self._create_training_session( 

95 self.model_onnx, self.weights_to_train, 

96 loss_output_name=self.loss_output_name, 

97 training_optimizer_name=self.training_optimizer_name, 

98 device=self.device) 

99 
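        # When warm_start is False, all trained tensors are re-initialized
        # with random normal values before training starts.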

        if not self.warm_start:
            state = self.get_state()
            new_state = {}
            for k, v in state.items():
                if len(v.shape) > 0:
                    new_state[k] = numpy.random.randn(*v.shape).astype(
                        v.dtype)
                else:
                    f = numpy.random.randn(1)
                    f = f.astype(v.dtype)
                    new_state[k] = f
            self.set_state(new_state)

        data_loader = OrtDataLoader(
            X, y, sample_weight=sample_weight,
            batch_size=self.batch_size, device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()
        self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
        self.output_names_ = [
            o.name for o in self.train_session_.get_outputs()]
        self.loss_index_ = self.output_names_.index(self.loss_output_name)

        bind = self.train_session_.io_binding()._iobinding

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        self.validation_losses_ = []
        lr = self.learning_rate.value
        for it in loop:
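            # The learning rate passed to the session is divided by the
            # batch size, presumably because the loss node sums over the
            # batch; losses below are normalized the same way.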

            lr_alive = numpy.array([lr / self.batch_size], dtype=numpy.float32)
            ort_lr = numpy_to_ort_value(lr_alive, self.device)
            loss = self._iteration(data_loader, ort_lr,
                                   bind, use_numpy=use_numpy,
                                   sample_weight=sample_weight is not None)
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g "  # pylint: disable=E1307
                    "lrn=%1.3g" % (
                        loss, lr, lr_alive[0]))
            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                self.validation_losses_.append(
                    self._evaluation(data_loader_val, bind))
        self.trained_coef_ = self.train_session_.get_state()
        return self

    def _bind_input_ortvalue(self, name, bind, c_ortvalue):
        """
        Binds a :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can also be a numpy array
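
        A sketch of the two binding paths (``bind`` is the
        ``SessionIOBinding`` of the training session, names are
        hypothetical)::

            # zero-copy binding of an existing C_OrtValue
            self._bind_input_ortvalue('X', bind, ort_value)
            # binds the raw pointer of a numpy array; the array
            # must stay alive until the session has run
            self._bind_input_ortvalue('X', bind, numpy_array)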

168 """ 

169 if not isinstance(bind, C_IOBinding): 

170 raise TypeError( # pragma: no cover 

171 f"Unexpected type {type(bind)!r}.") 

172 if isinstance(c_ortvalue, C_OrtValue): 

173 bind.bind_ortvalue_input(name, c_ortvalue) 

174 elif isinstance(c_ortvalue, numpy.ndarray): 

175 # This fails on linux with int64. 

176 bind.bind_input( 

177 name, self.device, c_ortvalue.dtype, c_ortvalue.shape, 

178 c_ortvalue.__array_interface__['data'][0]) 

179 else: 

180 raise TypeError( # pragma: no cover 

181 f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.") 

182 

183 def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight): 

184 actual_losses = [] 

185 

186 bind.bind_output('loss', self.device) 

187 idx = 3 if sample_weight else 2 

188 

189 if use_numpy: 

190 # onnxruntime does not copy the data, so the numpy 

191 # array must remain alive all along the iteration 

192 lr_alive = ort_lr.numpy() 

193 self._bind_input_ortvalue( 

194 self.input_names_[idx], bind, lr_alive) 

195 

196 # Slow iterations. 

197 for it in data_loader.iter_numpy(): 

198 if len(it) == 2: 

199 data, target = it 

200 self._bind_input_ortvalue( 

201 self.input_names_[0], bind, data) 

202 self._bind_input_ortvalue( 

203 self.input_names_[1], bind, target) 

204 else: 

205 data, target, weight = it 

206 self._bind_input_ortvalue( 

207 self.input_names_[0], bind, data) 

208 self._bind_input_ortvalue( 

209 self.input_names_[1], bind, target) 

210 self._bind_input_ortvalue( 

211 self.input_names_[2], bind, weight) 

212 

213 self.train_session_._sess.run_with_iobinding(bind, None) 

214 loss = bind.get_outputs()[0].numpy() 

215 if numpy.isinf(loss) or numpy.isnan(loss): 

216 raise ConvergenceError( 

217 "Loss is nan, learning_rate=%r, " 

218 "the gradient descent has failed " 

219 "(past losses=%r)." % ( 

220 ort_lr.numpy(), 

221 [float(v[0]) for v in ( 

222 actual_losses if len(actual_losses) < 5 

223 else actual_losses[-5:])])) 

224 actual_losses.append(loss / data.shape[0]) 

225 else: 

226 self._bind_input_ortvalue(self.input_names_[idx], bind, ort_lr) 

227 

228 # Fast iterations 

229 # Slow iterations. 

230 for batch_size in data_loader.iter_bind(bind, self.input_names_): 

231 self.train_session_._sess.run_with_iobinding(bind, None) 

232 # We copy the predicted output as well which is not needed. 

233 loss = bind.get_outputs()[0].numpy() 

234 if numpy.isinf(loss) or numpy.isnan(loss): 

235 raise ConvergenceError( 

236 "Loss is nan or infinite, learning_rate=%r, " 

237 "the gradient descent has failed " 

238 "(past losses=%r)." % ( 

239 ort_lr.numpy(), 

240 [float(v[0]) for v in ( 

241 actual_losses if len(actual_losses) < 5 

242 else actual_losses[-5:])])) 

243 actual_losses.append(loss / batch_size) 

244 

245 return numpy.array(actual_losses).mean() 

246 

247 def _evaluation(self, data_loader, bind): 

248 lr_alive = numpy.array([0], dtype=numpy.float32) 

249 self._bind_input_ortvalue(self.input_names_[2], bind, lr_alive) 

250 bind.bind_output('loss', self.device) 

251 

252 actual_losses = [] 

253 total = 0 

254 for batch_size in data_loader.iter_bind(bind, self.input_names_): 

255 self.train_session_._sess.run_with_iobinding(bind, None) 

256 outputs = bind.copy_outputs_to_cpu() 

257 if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]): 

258 raise EvaluationError( # pragma: no cover 

259 f"Loss is nan or infinite ({outputs[0]!r}), evaluation has failed.") 

260 actual_losses.append(outputs[0]) 

261 total += batch_size 

262 return numpy.array(actual_losses).sum() / max(total, 1) 

263 

264 def _create_training_session( 

265 self, training_onnx, weights_to_train, 

266 loss_output_name='loss', 

267 training_optimizer_name='SGDOptimizer', 

268 device='cpu'): 

269 """ 

270 Creates an instance of :epkg:`TrainingSession`. 

271 

272 :param training_onnx: an ONNX graph with a loss function 

273 :param weights_to_train: list of initializer names to optimize 

274 :param loss_output_name: output name for the loss 

275 :param training_optimizer_name: optimizer name 

276 :param device: one :epkg:`C_OrtDevice` or a string 

277 :return: an instance of :epkg:`TrainingSession` 

278 """ 

279 if training_optimizer_name != 'SGDOptimizer': 

280 raise NotImplementedError( 

281 "Only the SGDOptimizer is implemented not %r." 

282 "" % training_optimizer_name) 

283 ort_parameters = TrainingParameters() 

284 ort_parameters.loss_output_name = loss_output_name 

285 ort_parameters.use_mixed_precision = False 

286 # ort_parameters.world_rank = -1 

287 # ort_parameters.world_size = 1 

288 # ort_parameters.gradient_accumulation_steps = 1 

289 # ort_parameters.allreduce_post_accumulation = False 

290 # ort_parameters.deepspeed_zero_stage = 0 

291 # ort_parameters.enable_grad_norm_clip = False 

292 # ort_parameters.set_gradients_as_graph_outputs = False 

293 # ort_parameters.use_memory_efficient_gradient = False 

294 # ort_parameters.enable_adasum = False 

295 if self.saved_gradient is not None: 

296 name = self.saved_gradient 

297 name2 = name + ".training.onnx" 

298 ort_parameters.model_with_gradient_graph_path = name 

299 ort_parameters.model_with_training_graph_path = name2 

300 

301 output_types = {} 

302 for output in training_onnx.graph.output: 

303 output_types[output.name] = output.type.tensor_type 

304 

305 ort_parameters.weights_to_train = set(weights_to_train) 

306 ort_parameters.training_optimizer_name = training_optimizer_name 

307 # ort_parameters.lr_params_feed_name = lr_params_feed_name 

308 

309 ort_parameters.optimizer_attributes_map = { 

310 name: {} for name in weights_to_train} 

311 ort_parameters.optimizer_int_attributes_map = { 

312 name: {} for name in weights_to_train} 

313 
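        # onnxruntime logging is nearly silenced: severity levels run from
        # 0 (verbose) to 4 (fatal), so 4 only lets fatal messages through.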

        session_options = SessionOptions()
        session_options.log_severity_level = 4
        session_options.log_verbosity_level = 4
        # session_options.use_deterministic_compute = True

        providers = device_to_providers(self.device)
        session = TrainingSession(
            training_onnx.SerializeToString(), ort_parameters, session_options,
            providers=providers)

        return session

    def get_state(self):
        """
        Returns the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            if hasattr(self, 'trained_coef_'):
                return self.trained_coef_
            raise AttributeError("Method fit must be called before.")
        return self.train_session_.get_state()

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights. If *model* is not specified, the method uses the model
        given as an argument to this class. That graph outputs
        the loss and not the predictions. Parameter *model* can be set
        to the graph before the loss was added, in which case the
        returned graph produces the predictions.

        :param model: another graph than the training graph
            in which to replace the weights
        :return: onnx graph
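
        A sketch (``trainer`` is a fitted ``OrtGradientOptimizer`` and
        ``original_onnx`` is the hypothetical graph before the loss was
        added)::

            graph_with_loss = trainer.get_trained_onnx()
            predicting_graph = trainer.get_trained_onnx(model=original_onnx)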

349 """ 

350 return self._get_trained_onnx(self.get_state(), model=model) 

351 

352 def set_state(self, state): 

353 """ 

354 Changes the trained weights. 

355 """ 

356 if not hasattr(self, 'train_session_'): 

357 raise AttributeError( # pragma: no cover 

358 "Method fit must be called before.") 

359 return self.train_session_.load_state(state)