Coverage for onnxcustom/training/ortgradient.py: 95%
330 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1# pylint: disable=E1101
2"""
3@file
4@brief Gradient with :epkg:`onnxruntime-training` forward backward.
5"""
6import os
7import logging
8import warnings
9from io import BytesIO
10import onnx
11from onnx.numpy_helper import to_array
12from onnxruntime import InferenceSession, RunOptions
13from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
14 SessionIOBinding, OrtValue as C_OrtValue)
15from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
16 TrainingAgent, OrtValueCache, OrtModuleGraphBuilder,
17 OrtModuleGraphBuilderConfiguration, OrtDevice,
18 TrainingGraphTransformerConfiguration, OrtValueVector,
19 PartialGraphExecutionState)
20from ..utils.orttraining_helper import get_train_initializer
class OrtGradientForwardBackward:
    """
    Implements forward backward mechanism assuming the function
    to train is defined by an ONNX graph.

    :param onnx_model: onnx model
    :param weights_to_train: names of the weights to train,
        if None, all initializer of floats type are included in the list,
        any sequence is accepted and copied into a list
    :param input_names: input names or None for all
    :param output_names: output names or None for all
    :param class_name: name to give the class dynamically created
    :param sess_options: see :epkg:`SessionOptions`
    :param providers: see :epkg:`InferenceSession`
    :param provider_options: see :epkg:`InferenceSession`
    :param run_options: see :epkg:`RunOptions`
    :param graph_builder_config:
        see :epkg:`OrtModuleGraphBuilderConfiguration`
    :param device_index: used for cuda (0 for `cuda:0`,
        `cuda:1`, ...), 0 by default
    :param enable_logging: enables logging while setting up the class
    :param debug: to run extra verification while training

    .. note::
        The current implementation of :epkg:`onnxruntime` forces
        the weights to train to appear in the alphabetical order.
        The constructor checks that condition is verified.

    .. warning::
        This class does not consider subgraphs.
    """

    def __init__(self, onnx_model, weights_to_train=None,
                 input_names=None, output_names=None, class_name=None,
                 sess_options=None, providers=None,
                 provider_options=None, run_options=None,
                 graph_builder_config=None,
                 device_index=0, enable_logging=False, debug=False):

        if weights_to_train is None:
            weights_to_train = (
                OrtGradientForwardBackward._select_initializer_names(
                    onnx_model))
            if len(weights_to_train) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess the weights to train from initializers: "
                    "%r." % [i.name for i in onnx_model.graph.initializer])
        else:
            # The checks below compare against lists: a sorted tuple
            # would otherwise be rejected as "not sorted".
            weights_to_train = list(weights_to_train)

        self.onnx_model = onnx_model
        self.input_names = input_names
        self.output_names = output_names
        self.weights_to_train = weights_to_train
        self.device_index = device_index
        self.enable_logging = enable_logging
        self.class_name = (class_name if class_name is not None else
                           "OrtGradientForwardBackwardFunction_%d" % id(self))

        self.provider_options = provider_options
        self.sess_options = sess_options
        self.providers = providers
        self.run_options = run_options
        self.graph_builder_config = graph_builder_config
        self.debug = debug

        # default
        if self.input_names is None:
            self.input_names = [obj.name
                                for obj in self.onnx_model.graph.input]
        if self.output_names is None:
            self.output_names = [obj.name
                                 for obj in self.onnx_model.graph.output]

        # *providers* may be an object exposing `type` and `index`
        # (a device description), it is reduced to its type name.
        if hasattr(self.providers, 'type'):
            if self.providers.type != 'cpu':
                self.device_index = self.providers.index
            self.providers = self.providers.type

        # Short names are expanded into one provider name per input.
        if self.providers in (None, 'cpu'):
            self.providers = ["CPUExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        elif self.providers in ('cuda', 'cuda:0', 'gpu'):
            self.providers = [
                "CUDAExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        if self.provider_options is None:
            self.provider_options = [{} for i in self.providers]

        # The weights must be sorted, unique and listed in the same
        # order as the initializers of the graph.
        if sorted(self.weights_to_train) != self.weights_to_train:
            raise ValueError(  # pragma: no cover
                "List of weights to train must be sorted but %r is not. "
                "You shoud use function onnx_rename_weights to do that "
                "before calling this class." % self.weights_to_train)
        set_weights = set(self.weights_to_train)
        if len(set_weights) != len(self.weights_to_train):
            raise ValueError(  # pragma: no cover
                f"One weight is not unique in {self.weights_to_train!r}.")
        found = [i.name for i in self.onnx_model.graph.initializer
                 if i.name in set_weights]
        if len(found) != len(self.weights_to_train):
            raise ValueError(
                "One weight name in self.weights_to_train was not found in "
                "the initializers %r found=%r init names=%r." % (
                    self.weights_to_train, found,
                    [i.name for i in self.onnx_model.graph.initializer]))
        if found != self.weights_to_train:
            raise ValueError(
                "List of weights to train must be sorted and follow the "
                "same order as the initializers in the graph. %r != %r. "
                "You shoud use function onnx_rename_weights to do that "
                "before calling this class." % (
                    self.weights_to_train, found))

        known = ('CPUExecutionProvider', 'CUDAExecutionProvider')
        if any(v not in known for v in self.providers):
            raise ValueError(
                f"Unexpected providers {self.providers!r} (providers={providers!r}).")

        # complete initialisation
        self._init_next()

    @staticmethod
    def _select_initializer_names(onnx_model):
        """
        Selects all initializers with float type.

        :param onnx_model: ONNX graph
        :return: list of what function *get_train_initializer* yields
            when iterated
        """
        inits = get_train_initializer(onnx_model)
        return list(inits)

    def _init_next(self):
        """
        Finishes the initialisation: sets up the logger, the default
        run options, the default graph builder configuration and
        creates the dynamic class stored in `self.cls_type_`.
        This method is also called by `__setstate__` to rebuild
        everything which cannot be pickled.
        """
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None  # pragma: no cover
        if self.run_options is None:
            self.run_options = RunOptions()
            self.run_options.training_mode = True

        if self.graph_builder_config is None:
            initializer_names = [
                i.name for i in self.onnx_model.graph.initializer]
            input_names = [i.name for i in self.onnx_model.graph.input]
            trained = set(self.weights_to_train)

            config = OrtModuleGraphBuilderConfiguration()
            # keeps the order of the initializers in the graph
            config.initializer_names = [init for init in initializer_names
                                        if init in trained]
            config.initializer_names_to_train = self.weights_to_train
            config.input_names_require_grad = input_names
            config.build_gradient_graph = True

            if (len(config.initializer_names) !=  # noqa
                    len(config.initializer_names_to_train)):
                raise RuntimeError(  # pragma: no cover
                    "Unable to automatically fill "
                    "OrtModuleGraphBuilderConfiguration, mismatch between "
                    "%r and %r (initializer_names=%r)." % (
                        config.initializer_names,
                        config.initializer_names_to_train,
                        initializer_names))

            p = TrainingGraphTransformerConfiguration()
            config.graph_transformer_config = p

            # config.enable_caching = True
            # config.loglevel =
            # config.use_memory_efficient_gradient = True
            self.graph_builder_config = config

        attributes = self._create_onnx_graphs()
        attributes['__doc__'] = (
            "Inherits from @see cl OrtGradientForwardBackwardFunction.")
        attributes['__module__'] = (
            OrtGradientForwardBackwardFunction.__module__)
        self.cls_type_ = type(
            self.class_name, (OrtGradientForwardBackwardFunction,),
            attributes)

    def new_instance(self):
        """
        Creates an instance of class `self.cls_type_`.
        It implements methods *forward* and *backward*.
        """
        return self.cls_type_()

    def __getstate__(self):
        """
        Removes any non pickable attribute: attributes whose name ends
        with ``'_'`` and the logger are dropped, `run_options` and
        `graph_builder_config` are replaced by None.
        """
        skipped = {'_logger', 'graph_builder_config', 'run_options'}
        state = {att: getattr(self, att) for att in self.__dict__
                 if not att.endswith('_') and att not in skipped}
        state['run_options'] = None
        state['graph_builder_config'] = None
        return state

    def __setstate__(self, state):
        """
        Restores the attributes then calls `_init_next` to rebuild
        the non pickable ones.
        """
        for att, v in state.items():
            setattr(self, att, v)
        self._init_next()
        return self

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}(...)"

    @staticmethod
    def _repr_helper_(obj, indent=0):
        """
        Used to improve logging messages. Builds a string showing
        every public attribute of *obj* (names not starting with
        ``'_'``) as ``name=repr(value)``.

        :param obj: any object or None
        :param indent: 0 to get everything on a single line, otherwise
            one attribute per line prefixed by *indent* spaces
        :return: string, ``'None'`` if *obj* is None
        """
        if obj is None:
            return 'None'
        rows = []
        for c in sorted(dir(obj)):
            if c.startswith('_'):
                continue
            try:
                value = getattr(obj, c)
            except AttributeError:  # pragma: no cover
                continue
            rows.append(f"{c}={value!r}")

        name = obj.__class__.__name__
        if indent == 0:
            return f"{name}({', '.join(rows)})"
        sep = "\n" + " " * indent
        return f"{name}({sep}{sep.join(rows)})"

    @staticmethod
    def _provider_name_to_device_type(provider_name):
        """
        Converts a provider name into the corresponding device type
        (`OrtDevice.cpu()` or `OrtDevice.cuda()`).

        :param provider_name: ``'CPUExecutionProvider'`` or
            ``'CUDAExecutionProvider'``
        :raises ValueError: for any other name
        """
        if provider_name == 'CPUExecutionProvider':
            return OrtDevice.cpu()
        if provider_name == 'CUDAExecutionProvider':  # pragma: no cover
            return OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            f'Unexpected provider name {provider_name!r}.')

    def get_initializer(self, name, exc=True):
        """
        Returns an initializer converted with function *to_array*
        from `onnx.numpy_helper`.

        :param name: initializer name
        :param exc: raises an exception if not found or return None
        :return: the result of *to_array* for that initializer
            or None if not found and *exc* is False
        :raises RuntimeError: if not found and *exc* is True
        """
        for init in self.onnx_model.graph.initializer:
            if name == init.name:
                return to_array(init)
        if exc:
            raise RuntimeError(  # pragma: no cover
                "Unable to find name %r in %r." % (
                    name,
                    [i.name for i in self.onnx_model.graph.initializer]))
        return None

    def _create_onnx_graphs(self):
        """
        Creates forward and backward ONNX graph.
        The method returns a dictionary, its items become the attributes
        of the class dynamically created by `_init_next`:

        * `_run_options`: see :epkg:`RunOptions`
        * `_sess`: :epkg:`InferenceSession` built on the gradient graph
        * `_sess_eval`: :epkg:`InferenceSession` built on the forward
          graph returned by the graph builder
        * `_training_agent`: :epkg:`TrainingAgent`
        * `_cache`: :epkg:`OrtValueCache`
        * `_logger`: logger
        * `_input_names`: input names
        * `_debug`: use debug mode
        * `_grad_input_names`: input names of `_sess`
        * `_output_names`: output names
        * `_weights_to_train`: names of the weights to train
        * `_devices`: one device per name in `_grad_input_names`,
          the last known device is repeated if needed

        Training attributes

        * `_bw_fetches_names`: output names of `_sess`
        * `_fw_outputs_device_info`: one device per provider
        * `_bw_outputs_device_info`: one device per backward fetch
        * `_fw_no_grad_output_device_info`: one device per output
        * `_graph_info`: what `builder.get_graph_info()` returns

        ONNX graphs and builder, always added:

        * `_trained_onnx`: ONNX graph for the gradient
        * `_optimized_pre_grad_model`: forward ONNX graph returned
          by the graph builder
        * `_graph_builder`: :epkg:`OrtModuleGraphBuilder`
        * `_onx_inp`, `_onx_out`: input and output names of
          `_trained_onnx`

        Attributes `__doc__` and `__module__` are added by the caller.

        :raises RuntimeError: if `_onx_inp` and `_onx_out`
            do not have the same length
        """
        logger = self._logger
        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training onnx")
            logger.info("[OrtGradientForwardBackward] input_names=%r",
                        self.input_names)
            logger.info("[OrtGradientForwardBackward] output_names=%r",
                        self.output_names)
            logger.info("[OrtGradientForwardBackward] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info(
                "[OrtGradientForwardBackward] "
                "OrtModuleGraphBuilder.initialize")
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config=%s",
                OrtGradientForwardBackward._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config=%s",
                OrtGradientForwardBackward._repr_helper_(cf, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                OrtGradientForwardBackward._repr_helper_(cfp, indent=4))

        builder.initialize(
            self.onnx_model.SerializeToString(),
            self.graph_builder_config)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.get_model")

        # The builder API was renamed, the older names are tried
        # when the newer ones are missing.
        try:
            train_onnx_model_serialized = builder.get_gradient_model()
        except AttributeError:
            # older version
            train_onnx_model_serialized = builder.get_model()
        try:
            optimized_pre_grad_model = builder.get_forward_model()
        except AttributeError:
            # older version
            optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] graph_info=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            graph_info, indent=4))
            logger.info("[OrtGradientForwardBackward] create TrainSession")
            logger.info("[OrtGradientForwardBackward] sess_options=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            self.sess_options, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] providers=%r", self.providers)

        sess = InferenceSession(
            train_onnx_model_serialized, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create InferenceSession")

        sess_eval = InferenceSession(
            optimized_pre_grad_model, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training agent")

        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        def _device(provider_name):
            "builds a new OrtDevice for a provider name"
            return OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    provider_name),
                OrtDevice.default_memory(), self.device_index)

        fw_outputs_device_info = [_device(p) for p in self.providers]
        bw_outputs_device_info = [
            _device(self.providers[0]) for _ in bw_fetches_names]
        fw_no_grad_output_device_info = [
            _device(self.providers[0]) for _ in self.output_names]

        try:
            # onnxruntime>=1.12
            training_agent = TrainingAgent(
                sess._sess,
                grad_input_names,
                fw_outputs_device_info,
                bw_fetches_names,
                bw_outputs_device_info,
                0)
        except TypeError:
            # onnxruntime<=1.11
            training_agent = TrainingAgent(
                sess._sess,
                grad_input_names,
                fw_outputs_device_info,
                bw_fetches_names,
                bw_outputs_device_info)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] instantiate dynamic class %r",
                self.class_name)
            logger.info(
                "[OrtGradientForwardBackward] weights_to_train=%r",
                self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackward] grad_input_names=%r",
                grad_input_names)
            logger.info(
                "[OrtGradientForwardBackward] bw_fetches_names=%r",
                bw_fetches_names)
            logger.info(
                "[OrtGradientForwardBackward] device_index=%r",
                self.device_index)

        # one device per gradient input, the last one is repeated
        devices = list(fw_outputs_device_info)
        while len(devices) < len(grad_input_names):
            devices.append(devices[-1])

        trained_onnx = onnx.load(BytesIO(train_onnx_model_serialized))
        onnx_loss = onnx.load(BytesIO(optimized_pre_grad_model))
        # gives a name to every anonymous node
        for model in (trained_onnx, onnx_loss):
            for i, node in enumerate(model.graph.node):
                if node.name == '':
                    node.name = "N%d" % i

        kwargs = {
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_logger': logger,
            '_input_names': self.input_names,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(self.weights_to_train)),
            '_graph_info': graph_info,
            #
            '_trained_onnx': trained_onnx,
            '_optimized_pre_grad_model': onnx_loss,
            '_graph_builder': builder,
            '_devices': devices,
            '_debug': self.debug
        }
        graph = kwargs['_trained_onnx'].graph
        kwargs.update({
            '_onx_inp': [o.name for o in graph.input],
            '_onx_out': [o.name for o in graph.output]
        })

        if len(kwargs['_onx_inp']) != len(kwargs['_onx_out']):
            raise RuntimeError(  # pragma: no cover
                "Gradient input and output are inconsistant: "
                "%r != %r" % (kwargs['_onx_inp'], kwargs['_onx_out']))
        return kwargs
class OrtGradientForwardBackwardFunction:
    """
    Ancestor for a class implementing forward and backward
    and dynamically created by @see cl OrtGradientForwardBackward.
    The dynamically created class holds the sessions, the training
    agent, the names and the devices as class attributes
    (`_sess_eval`, `_training_agent`, `_devices`, ...).

    Attributes stored in *forward* method:

    * `saved_tensors_`: list of tensors to save during forward
      and to retrieve during backward
    * `states_`: list of :epkg:`PartialGraphExecutionState`, one is
      appended by every call to *forward* with ``training=True``
      and removed by the next call to *backward*
    """

    def __init__(self):
        self.states_ = []
        self.saved_tensors_ = None

    @classmethod
    def save_onnx_graph(cls, folder, prefix=None, suffix=None):
        """
        Saves onnx graph stored in this class. The method looks into
        `cls.__dict__` for every attribute implementing method
        *SerializeToString* or method *save_onnx_graph*.

        :param folder: if it is a string, the graphs are written in
            this existing folder, otherwise nothing is written and
            the serialized graphs are returned
        :param prefix: prefix added to the file names
        :param suffix: suffix added to the file names
        :return: dictionary `{ attribute name: filename or bytes }`,
            the value is whatever *save_onnx_graph* returns for an
            attribute implementing it
        :raises FileNotFoundError: if *folder* is a string and
            does not exist
        """
        if prefix is None:
            prefix = ''  # pragma: no cover
        if suffix is None:
            suffix = ''  # pragma: no cover
        to_disk = isinstance(folder, str)
        if to_disk and not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                f"Folder {folder!r} does not exist.")
        saved = {}
        for k, v in cls.__dict__.items():
            if hasattr(v, "SerializeToString"):
                if to_disk:
                    name = f"{prefix}{cls.__name__}{suffix}.{k}.onnx"
                    filename = os.path.join(folder, name)
                    if os.path.exists(filename):
                        # the existing file is overwritten
                        warnings.warn(  # pragma: no cover
                            f"Filename {filename!r} already exists.")
                    with open(filename, "wb") as f:
                        f.write(v.SerializeToString())
                    saved[k] = filename
                else:
                    saved[k] = v.SerializeToString()
            elif hasattr(v, "save_onnx_graph"):
                saved[k] = v.save_onnx_graph(
                    folder, prefix=prefix, suffix=f"{suffix}.{k}")
        return saved

    @staticmethod
    def device_name(device):
        """
        Returns the device name of a device.

        :param device: OrtDevice
        :return: string, ``'Cpu'`` or ``'Gpu'``
        :raises RuntimeError: if the device type is neither
            `OrtDevice.cpu()` nor `OrtDevice.cuda()`
        """
        if device.device_type() == OrtDevice.cpu():
            return 'Cpu'
        if device.device_type() == OrtDevice.cuda():  # pragma: no cover
            return 'Gpu'
        raise RuntimeError(  # pragma: no cover
            f"Unexpected value for device type {device.device_type()!r}.")

    @staticmethod
    def input_to_ort(tensors, devices, debug):
        """
        Converts a list of tensors into an :epkg:`OrtValueVector`.

        :param tensors: an :epkg:`OrtValueVector` (returned as is),
            a list of :epkg:`C_OrtValue` or a list of arrays given
            to `C_OrtValue.ortvalue_from_numpy`
        :param devices: list of OrtDevice, one per tensor
        :param debug: if True, checks every tensor is on the
            expected device
        :return: :epkg:`OrtValueVector`
        :raises RuntimeError: in debug mode, if a tensor is not on
            the expected device or if the length is unexpected
        :raises NotImplementedError: if a tensor is None
        """
        this = OrtGradientForwardBackwardFunction

        def _validate_(tensors):
            if any(t.device_name() != this.device_name(d)
                   for t, d in zip(tensors, devices)):
                # device_name is a method of this class,
                # not of OrtGradientForwardBackward
                raise RuntimeError(  # pragma: no cover
                    "Not all inputs are on the same device %r != %r." % (
                        [this.device_name(d) for d in devices],
                        [x.device_name() for x in tensors]))

        if isinstance(tensors, OrtValueVector):
            if debug:
                _validate_(tensors)
            return tensors
        if all(isinstance(t, C_OrtValue) for t in tensors):
            if debug:
                _validate_(tensors)
            vect = OrtValueVector()
            vect.reserve(len(tensors))
            for t in tensors:
                vect.push_back(t)
            return vect

        # generic case
        # numpy is only needed to copy a non contiguous array,
        # the module does not import it at the top level.
        import numpy  # pylint: disable=C0415
        vect = OrtValueVector()
        vect.reserve(len(tensors))
        for t, dev in zip(tensors, devices):
            if t is None:
                # if gradient then
                # grad_output = torch.zeros(shape, device=device, dtype=dtype)
                raise NotImplementedError(  # pragma: no cover
                    "Empty vector found.")
            if not t.data.contiguous:
                t = numpy.ascontiguousarray(t)  # pragma: no cover
            vect.push_back(C_OrtValue.ortvalue_from_numpy(t, dev))
        if debug:
            # zip stops at the shortest sequence,
            # a missing device would drop a tensor
            if len(vect) != len(tensors):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected array length %d != %d (len(devices)=%d)." % (
                        len(vect), len(tensors), len(devices)))
            _validate_(vect)
        return vect

    def save_for_backward(self, inputs):
        """
        Saves inputs during forward steps. The list inputs
        is copied (simple copy, no deep copy).

        :param inputs: list of tensors to save.
        """
        self.saved_tensors_ = list(inputs)

    @property
    def saved_tensors(self):
        """
        Returns saved tensors during forward step.

        :raises RuntimeError: if *save_for_backward* was not called
        """
        if self.saved_tensors_ is None:
            raise RuntimeError(  # pragma: no cover
                "No tensors was saved with save_for_backward.")
        return self.saved_tensors_

    def forward(self, inputs, training=False, forward_outputs_cache=None):
        """
        Implements forward function.

        :param inputs: inputs
        :param training: only inference or training as well
        :param forward_outputs_cache: container receiving the outputs
            when *training* is True, a new :epkg:`OrtValueVector` is
            created if None, an empty container is used as well
        :return: output as :epkg:`OrtValueVector` if *training* is True,
            otherwise what `iobinding.get_outputs()` returns
        :raises RuntimeError: if *training* is False and the number
            of inputs is not the expected one
        """
        logger = self._logger
        cls = self.__class__

        def _log(msg, *args):
            logger.debug("[%s.forward] (%dI) " + msg,
                         cls.__name__, len(inputs), *args)

        if logger is not None:
            _log("begin with gradient" if training else "begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("create OrtValueVector (through dlpack)")

        forward_inputs = cls.input_to_ort(inputs, cls._devices, cls._debug)

        if training:
            # `is None` and not `or`: an empty container given by
            # the caller is falsy and would be silently replaced.
            forward_outputs = (
                OrtValueVector() if forward_outputs_cache is None
                else forward_outputs_cache)
            state = PartialGraphExecutionState()
            self.states_.append(state)
            if logger is not None:
                _log("run_forward")
            cls._training_agent.run_forward(
                forward_inputs, forward_outputs, state, cls._cache)

            self.save_for_backward(inputs)
            if logger is not None:
                _log("end")
            return forward_outputs

        # what about bind_input (+ data_ptr)
        if len(forward_inputs) != len(cls._grad_input_names):
            raise RuntimeError(  # pragma: no cover
                "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % (
                    len(forward_inputs), len(cls._grad_input_names)))
        iobinding = SessionIOBinding(cls._sess_eval._sess)
        if logger is not None:
            _log("bind inputs %r", cls._grad_input_names)
        for name, inp in zip(
                cls._grad_input_names, forward_inputs):
            iobinding.bind_ortvalue_input(name, inp)

        # bind output
        if logger is not None:
            _log("bind outputs %r", cls._output_names)
        for name, dev in zip(
                cls._output_names, cls._fw_no_grad_output_device_info):
            iobinding.bind_output(name, dev)

        # if the shape is known in advance
        # iobinding.bind_output(
        #    output_desc.name, torch_tensor.device.type,
        #    _utils.get_device_index(target_device),
        #    _utils.dtype_torch_to_numpy(torch_tensor.dtype),
        #    list(torch_tensor.size()), torch_tensor.data_ptr())

        if logger is not None:
            _log("grad_enabled=False (run_with_iobinding)")
        cls._sess_eval._sess.run_with_iobinding(
            iobinding, cls._run_options)
        if logger is not None:
            _log("get_outputs")
        ortvalues = iobinding.get_outputs()
        if logger is not None:
            _log("to torck.tensor (%d)", len(ortvalues))
            _log("end")
        return ortvalues

    def backward(self, grad_outputs, backward_outputs_cache=None):
        """
        Implements backward function. The function returns
        an :epkg:`OrtValueVector`.

        :param grad_outputs: gradients of the outputs
        :param backward_outputs_cache: container receiving the
            gradients, a new :epkg:`OrtValueVector` is created if None,
            an empty container is used as well
        :return: gradients as :epkg:`OrtValueVector`
        :raises RuntimeError: if *forward* did not save any tensor
        """
        cls = self.__class__
        logger = cls._logger

        def _log(msg, *args):
            logger.debug("[%s.backward] (%dI) " + msg,
                         cls.__name__, len(grad_outputs), *args)

        if logger is not None:
            _log("begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("saved_tensors")

        # only read to check forward was called and for logging
        inputs = self.saved_tensors
        if logger is not None:
            _log("DEBUG: saved_tensors %r", type(inputs))
            _log("self.state_.pop()")
        # state pushed by the latest training forward
        state = self.states_.pop()

        if logger is not None:
            _log("create OrtValueVector")

        backward_inputs = cls.input_to_ort(
            grad_outputs, cls._bw_outputs_device_info, cls._debug)

        if logger is not None:
            _log("len(grad_outputs)=%d type(grad_outputs)=%r",
                 len(grad_outputs), type(grad_outputs))
            _log("len(backward_inputs)=%d type(backward_inputs)=%r",
                 len(backward_inputs), type(backward_inputs))
            for i, ov in enumerate(backward_inputs):
                _log("backward_inputs[%d].shape=%r", i, ov.shape())
            _log("run_backward")
        # `is None` and not `or`: an empty container given by
        # the caller is falsy and would be silently replaced.
        backward_outputs = (
            OrtValueVector() if backward_outputs_cache is None
            else backward_outputs_cache)
        cls._training_agent.run_backward(
            backward_inputs, backward_outputs, state)
        if logger is not None:  # pragma: no cover
            _log("DEBUG")
            for i, ov in enumerate(backward_outputs):
                _log("BCK-RET: i=%d - shape=%r - ptr=%r",
                     i, ov.shape(), ov.data_ptr())
            _log("got %r gradients", len(backward_outputs))
            _log("end")
        return backward_outputs