Coverage for onnxcustom/training/optimizers.py: 98%
141 statements
coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1"""
2@file
3@brief Optimizer with :epkg:`onnxruntime-training`.
4"""
5import numpy
6from onnxruntime import ( # pylint: disable=E0611
7 TrainingParameters, SessionOptions, TrainingSession)
8from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
9 OrtValue as C_OrtValue, SessionIOBinding as C_IOBinding)
10from ..utils.onnxruntime_helper import (
11 numpy_to_ort_value, device_to_providers)
12from .data_loader import OrtDataLoader
13from .excs import ConvergenceError, EvaluationError
14from ._base_estimator import BaseEstimator


class OrtGradientOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of initializers to be optimized
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a float, a name, or a learning rate instance,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuses the solution of the
        previous call to fit as initialization, otherwise, erases the
        previous solution
    :param verbose: use :epkg:`tqdm` to display the training progress
    :param validation_every: validates against a test set every
        *validation_every* iterations
    :param saved_gradient: if not None, a filename,
        the optimizer saves the gradient into it
    :param sample_weight_name: name of the sample weight input

    Once initialized, the class creates the attribute
    `train_session_` which holds an instance of :ref:`l-ort-training-session`.

    See example :ref:`l-orttraining-nn-gpu`.
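
    A minimal usage sketch; ``model_with_loss``, ``X_train``, ``y_train``
    and the initializer names are placeholders for an ONNX graph whose
    outputs include the loss (such a graph can be built, for instance,
    with :func:`add_loss_output
    <onnxcustom.utils.orttraining_helper.add_loss_output>`)::

        train_model = OrtGradientOptimizer(
            model_with_loss, weights_to_train=['coef', 'intercept'],
            max_iter=100, batch_size=10, learning_rate='SGD')
        train_model.fit(X_train, y_train)
        trained_weights = train_model.get_state()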
46 """

    def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
                 max_iter=100, training_optimizer_name='SGDOptimizer',
                 batch_size=10, learning_rate='SGD',
                 device='cpu', warm_start=False, verbose=0,
                 validation_every=0.1, saved_gradient=None,
                 sample_weight_name="weight"):
        BaseEstimator.__init__(self, model_onnx, learning_rate, device)
        self.batch_size = batch_size
        self.weights_to_train = weights_to_train
        self.loss_output_name = loss_output_name
        self.training_optimizer_name = training_optimizer_name
        self.verbose = verbose
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.saved_gradient = saved_gradient
        self.sample_weight_name = sample_weight_name
        if validation_every < 1:
            # a value below 1 is interpreted as a fraction of max_iter
            self.validation_every = int(self.max_iter * validation_every)
        else:
            self.validation_every = validation_every  # pragma: no cover
        if self.learning_rate.needs_grad:
            raise NotImplementedError(
                "Any weight update involving past gradients is "
                "not implemented (class %r)."
                "" % self.learning_rate.__class__.__name__)

    def fit(self, X, y, sample_weight=None, X_val=None, y_val=None,
            use_numpy=False):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: sample weight if any
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :param use_numpy: if True, uses a slow iterator based on numpy,
            otherwise, minimizes copies
        :return: self
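
        Sketch of a call with a validation set (``X``, ``y`` and the
        split are placeholders)::

            from sklearn.model_selection import train_test_split
            X_train, X_val, y_train, y_val = train_test_split(X, y)
            model.fit(X_train, y_train, X_val=X_val, y_val=y_val)
            print(model.train_losses_[-1], model.validation_losses_)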
87 """
        input_names = [i.name for i in self.model_onnx.graph.input]
        if ((len(input_names) == 2 and sample_weight is not None) or
                (len(input_names) == 3 and sample_weight is None)):
            raise RuntimeError(  # pragma: no cover
                "Number of inputs should be 2 if sample_weight is None "
                "or 3 otherwise, but it is %d." % len(input_names))
        self.train_session_ = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            loss_output_name=self.loss_output_name,
            training_optimizer_name=self.training_optimizer_name,
            device=self.device)

        if not self.warm_start:
            # reinitializes the weights with random values
            state = self.get_state()
            new_state = {}
            for k, v in state.items():
                if len(v.shape) > 0:
                    new_state[k] = numpy.random.randn(*v.shape).astype(v.dtype)
                else:
                    f = numpy.random.randn(1)
                    f = f.astype(v.dtype)
                    new_state[k] = f
            self.set_state(new_state)

        data_loader = OrtDataLoader(
            X, y, sample_weight=sample_weight,
            batch_size=self.batch_size, device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()
        self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
        self.output_names_ = [
            o.name for o in self.train_session_.get_outputs()]
        self.loss_index_ = self.output_names_.index(self.loss_output_name)

        bind = self.train_session_.io_binding()._iobinding

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        self.validation_losses_ = []
        lr = self.learning_rate.value
        for it in loop:
            # the learning rate is divided by the batch size, consistent
            # with a loss summed (not averaged) over the batch
            lr_alive = numpy.array([lr / self.batch_size], dtype=numpy.float32)
            ort_lr = numpy_to_ort_value(lr_alive, self.device)
            loss = self._iteration(data_loader, ort_lr,
                                   bind, use_numpy=use_numpy,
                                   sample_weight=sample_weight is not None)
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g "  # pylint: disable=E1307
                    "lrn=%1.3g" % (
                        loss, lr, lr_alive[0]))
            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                self.validation_losses_.append(
                    self._evaluation(data_loader_val, bind))
        self.trained_coef_ = self.train_session_.get_state()
        return self

    def _bind_input_ortvalue(self, name, bind, c_ortvalue):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: IO binding structure (:epkg:`C_IOBinding`)
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can also be a numpy array
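
        Sketch of the two accepted input types (``bind`` comes from the
        training session, the input name is a placeholder)::

            ort_value = numpy_to_ort_value(array, self.device)
            self._bind_input_ortvalue('X1', bind, ort_value)  # C_OrtValue
            self._bind_input_ortvalue('X1', bind, array)      # numpy array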
168 """
169 if not isinstance(bind, C_IOBinding):
170 raise TypeError( # pragma: no cover
171 f"Unexpected type {type(bind)!r}.")
172 if isinstance(c_ortvalue, C_OrtValue):
173 bind.bind_ortvalue_input(name, c_ortvalue)
174 elif isinstance(c_ortvalue, numpy.ndarray):
175 # This fails on linux with int64.
176 bind.bind_input(
177 name, self.device, c_ortvalue.dtype, c_ortvalue.shape,
178 c_ortvalue.__array_interface__['data'][0])
179 else:
180 raise TypeError( # pragma: no cover
181 f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.")

    def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight):
        """
        Runs one training iteration over all batches of *data_loader*
        and returns the mean loss per batch.
        """
        actual_losses = []

        bind.bind_output('loss', self.device)
        # the learning rate input comes after data, target (and weight)
        idx = 3 if sample_weight else 2

        if use_numpy:
            # onnxruntime does not copy the data, so the numpy
            # array must remain alive all along the iteration
            lr_alive = ort_lr.numpy()
            self._bind_input_ortvalue(
                self.input_names_[idx], bind, lr_alive)

            # Slow iterations.
            for it in data_loader.iter_numpy():
                if len(it) == 2:
                    data, target = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                else:
                    data, target, weight = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                    self._bind_input_ortvalue(
                        self.input_names_[2], bind, weight)

                self.train_session_._sess.run_with_iobinding(bind, None)
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / data.shape[0])
        else:
            self._bind_input_ortvalue(self.input_names_[idx], bind, ort_lr)

            # Fast iterations.
            for batch_size in data_loader.iter_bind(bind, self.input_names_):
                self.train_session_._sess.run_with_iobinding(bind, None)
                # The predicted output is copied as well even though
                # only the loss is needed.
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / batch_size)

        return numpy.array(actual_losses).mean()

    def _evaluation(self, data_loader, bind):
        """
        Computes the loss over the validation set; the learning rate
        input is bound to zero so the run does not update the weights.
        """
        lr_alive = numpy.array([0], dtype=numpy.float32)
        self._bind_input_ortvalue(self.input_names_[2], bind, lr_alive)
        bind.bind_output('loss', self.device)

        actual_losses = []
        total = 0
        for batch_size in data_loader.iter_bind(bind, self.input_names_):
            self.train_session_._sess.run_with_iobinding(bind, None)
            outputs = bind.copy_outputs_to_cpu()
            if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
                raise EvaluationError(  # pragma: no cover
                    f"Loss is nan or infinite ({outputs[0]!r}), "
                    f"evaluation has failed.")
            actual_losses.append(outputs[0])
            total += batch_size
        return numpy.array(actual_losses).sum() / max(total, 1)

    def _create_training_session(
            self, training_onnx, weights_to_train,
            loss_output_name='loss',
            training_optimizer_name='SGDOptimizer',
            device='cpu'):
        """
        Creates an instance of :epkg:`TrainingSession`.

        :param training_onnx: an ONNX graph with a loss function
        :param weights_to_train: list of initializer names to optimize
        :param loss_output_name: output name for the loss
        :param training_optimizer_name: optimizer name
        :param device: a :epkg:`C_OrtDevice` or a string
        :return: an instance of :epkg:`TrainingSession`
        """
        if training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only SGDOptimizer is implemented, not %r."
                "" % training_optimizer_name)
        ort_parameters = TrainingParameters()
        ort_parameters.loss_output_name = loss_output_name
        ort_parameters.use_mixed_precision = False
        # ort_parameters.world_rank = -1
        # ort_parameters.world_size = 1
        # ort_parameters.gradient_accumulation_steps = 1
        # ort_parameters.allreduce_post_accumulation = False
        # ort_parameters.deepspeed_zero_stage = 0
        # ort_parameters.enable_grad_norm_clip = False
        # ort_parameters.set_gradients_as_graph_outputs = False
        # ort_parameters.use_memory_efficient_gradient = False
        # ort_parameters.enable_adasum = False
        if self.saved_gradient is not None:
            name = self.saved_gradient
            name2 = name + ".training.onnx"
            ort_parameters.model_with_gradient_graph_path = name
            ort_parameters.model_with_training_graph_path = name2

        output_types = {}
        for output in training_onnx.graph.output:
            output_types[output.name] = output.type.tensor_type

        ort_parameters.weights_to_train = set(weights_to_train)
        ort_parameters.training_optimizer_name = training_optimizer_name
        # ort_parameters.lr_params_feed_name = lr_params_feed_name

        ort_parameters.optimizer_attributes_map = {
            name: {} for name in weights_to_train}
        ort_parameters.optimizer_int_attributes_map = {
            name: {} for name in weights_to_train}

        session_options = SessionOptions()
        session_options.log_severity_level = 4
        session_options.log_verbosity_level = 4
        # session_options.use_deterministic_compute = True

        providers = device_to_providers(self.device)
        session = TrainingSession(
            training_onnx.SerializeToString(), ort_parameters, session_options,
            providers=providers)

        return session

    def get_state(self):
        """
        Returns the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            if hasattr(self, 'trained_coef_'):
                return self.trained_coef_
            raise AttributeError("Method fit must be called before.")
        return self.train_session_.get_state()

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, that is, the initial graph
        with its initializers replaced by the trained weights.
        If *model* is not specified, the method uses the model given
        as an argument to this class, and the returned graph outputs
        the loss, not the predictions. Parameter *model* can be set
        to the graph as it was before the loss was added; the returned
        graph then produces the predictions.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
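
        Sketch of recovering a predicting model; ``original_model``
        stands for the graph before the loss was appended, the input
        name ``'X'`` and ``X_test`` are placeholders::

            import onnxruntime
            onx = trained.get_trained_onnx(model=original_model)
            sess = onnxruntime.InferenceSession(
                onx.SerializeToString(),
                providers=['CPUExecutionProvider'])
            predictions = sess.run(None, {'X': X_test})[0]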
349 """
350 return self._get_trained_onnx(self.get_state(), model=model)

    def set_state(self, state):
        """
        Changes the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return self.train_session_.load_state(state)