Coverage for deeponnxcustom/onnxtorch/torchort.py: 95% (248 statements)

1"""
2@file
3@brief Experimental.
4"""
5import warnings
6import logging
7from textwrap import dedent
8from io import BytesIO
9import onnx
10from onnxruntime import InferenceSession
11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
12 SessionIOBinding)
13try:
14 from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
15 TrainingAgent, OrtValueCache, OrtModuleGraphBuilder,
16 OrtModuleGraphBuilderConfiguration, OrtDevice,
17 TrainingGraphTransformerConfiguration, OrtValueVector,
18 PartialGraphExecutionState)
19except ImportError: # pragma: no cover
20 # onnxruntime-training is not installed.
21 warnings.warn(
22 "TorchOrtFactory cannot work without onnxruntime-training.")
23from onnxruntime import RunOptions
24from torch import is_grad_enabled # pylint: disable=E0611
25from torch.autograd import Function
26from torch.utils.dlpack import from_dlpack, to_dlpack
27from torch._C import _from_dlpack


class TorchOrtFunction(Function):
    """
    Ancestor to all classes created by @see cl TorchOrtFactory.
    It implements simple functions to move the ownership of tensors
    from *onnxruntime* to *pytorch* (or the other way around)
    through :epkg:`DLPack` structures.
    This class requires :epkg:`onnxruntime_training`.

    .. faqref::
        :title: Differences between onnxruntime and onnxruntime-training

        onnxruntime-training is an extension of onnxruntime
        that supports training. Version 1.10 is obtained by compiling
        onnxruntime from the sources with different flags.
        One example:

        ::

            python ./tools/ci_build/build.py --build_dir ./build/debian \\
                --config Release --build_wheel --numpy_version= \\
                --skip_tests --build_shared_lib --enable_training \\
                --enable_training_ops --enable_training_torch_interop \\
                --parallel
    """

    @staticmethod
    def from_torch_to_ort(tensors):
        "Converts a list of pytorch tensors into an OrtValueVector."
        vect = OrtValueVector()
        vect.reserve(len(tensors))
        for t in tensors:
            if t is None:
                # if gradient then
                # grad_output = torch.zeros(shape, device=device, dtype=dtype)
                raise NotImplementedError(  # pragma: no cover
                    "Empty vector found.")
            if not t.is_contiguous():
                # grad = grad.contiguous()
                raise NotImplementedError(  # pragma: no cover
                    "Non contiguous gradient found.")
            vect.push_back(to_dlpack(t), False)
        return vect

    @staticmethod
    def from_ort_to_torch(ort_values):
        "Converts an OrtValueVector into a tuple of pytorch tensors."
        # return tuple(_from_dlpack(ov.to_dlpack()) for ov in ort_values)
        if hasattr(ort_values, 'to_dlpack'):
            return tuple(ort_values.to_dlpack(_from_dlpack))
        if len(ort_values) == 0:
            raise RuntimeError(  # pragma: no cover
                "The conversion fails on an empty vector.")
        if hasattr(ort_values[0], '__dlpack__'):
            return tuple(  # pragma: no cover
                from_dlpack(ov) for ov in ort_values)
        return tuple(_from_dlpack(ov.to_dlpack()) for ov in ort_values)
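

# A minimal round-trip sketch for the two helpers above (illustrative,
# assuming onnxruntime-training is installed); it moves a tensor to
# onnxruntime and back through DLPack without copying the data:
#
#   import torch
#   t = torch.randn(2, 3)
#   vect = TorchOrtFunction.from_torch_to_ort([t])
#   back = TorchOrtFunction.from_ort_to_torch(vect)
#   assert back[0].shape == t.shape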


def ort_forward(ctx, *inputs):
    """
    Implements forward function.
    See :epkg:`autograd functions`.
    """
    cls = ctx._forward_cls
    logger = cls._logger
    training = is_grad_enabled() or any(ctx.needs_input_grad)

    def _log(msg):
        logger.debug("[%s.forward] (%dI) %s" % (
            cls.__name__, len(inputs), msg))

    if logger is not None:
        if training:
            _log("begin with gradient")
        else:
            _log("begin")
        _log("torch function %r" % type(ctx))
        _log("ort class %r" % cls)
        _log("create OrtValueVector (through dlpack)")

    forward_inputs = cls.from_torch_to_ort(inputs)

    if training:
        forward_outputs = OrtValueVector()
        state = PartialGraphExecutionState()
        cls._states.append(state)
        if logger is not None:
            _log("run_forward")
        cls._training_agent.run_forward(
            forward_inputs, forward_outputs, state, cls._cache)

        ctx.save_for_backward(*inputs)

        if cls._update_cache:
            if logger is not None:
                _log("update_cache")
            raise NotImplementedError("Cache is not implemented.")

            # for i in range(self._cache_start, len(forward_outputs)):
            #     self.cache.insert(
            #         self._cached_node_arg_names[i - cls._cache_start],
            #         forward_outputs[i])
            # self._update_cache = False
            # if logger is not None:
            #     _log("to torch.tensor")
            # return tuple(_utils._ortvalue_to_torch_tensor
            #     (forward_outputs[i], device) for i in range(self._cache_start))

        else:
            if logger is not None:
                _log("to torch.tensor")
            res = cls.from_ort_to_torch(forward_outputs)
            if len(res) == 1:
                res = res[0]
            if logger is not None:
                _log("end")
            return res
    else:
        # what about bind_input (+ data_ptr)
        if len(forward_inputs) != len(cls._grad_input_names):
            raise RuntimeError(  # pragma: no cover
                "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % (
                    len(forward_inputs), len(cls._grad_input_names)))
        iobinding = SessionIOBinding(cls._sess_eval._sess)
        if logger is not None:
            _log("bind inputs %r" % cls._grad_input_names)
        for name, inp in zip(
                cls._grad_input_names, forward_inputs):
            iobinding.bind_ortvalue_input(name, inp)

        # bind output
        if logger is not None:
            _log("bind outputs %r" % cls._output_names)
        for name, dev in zip(
                cls._output_names, cls._fw_no_grad_output_device_info):
            iobinding.bind_output(name, dev)

        # if the shape is known in advance
        # iobinding.bind_output(
        #     output_desc.name, torch_tensor.device.type,
        #     _utils.get_device_index(target_device),
        #     _utils.dtype_torch_to_numpy(torch_tensor.dtype),
        #     list(torch_tensor.size()), torch_tensor.data_ptr())

        if logger is not None:
            _log("grad_enabled=False (run_with_iobinding)")
        cls._sess_eval._sess.run_with_iobinding(iobinding, cls._run_options)
        if logger is not None:
            _log("get_outputs")
        ortvalues = iobinding.get_outputs()
        if logger is not None:
            _log("to torch.tensor")
        res = cls.from_ort_to_torch(ortvalues)
        if len(res) == 1:
            res = res[0]
        if logger is not None:
            _log("end")
        return res
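

# Note on the no-gradient branch above: the evaluation session is driven
# through SessionIOBinding so that inputs and outputs stay as OrtValue
# objects on their device, avoiding intermediate numpy copies.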


def ort_backward(ctx, *grad_outputs):
    """
    Implements backward function.
    See :epkg:`autograd functions`.
    """
    cls = ctx._forward_cls
    logger = cls._logger

    def _log(msg):
        logger.debug("[%s.backward] (%dI) %s" % (
            cls.__name__, len(grad_outputs), msg))

    if logger is not None:
        _log("begin")
        _log("torch function %r" % type(ctx))
        _log("ort class %r" % cls)
        _log("saved_tensors")

    inputs = ctx.saved_tensors
    if cls._debug:
        print(  # pragma: no cover
            "DEBUG: saved_tensors %r" % type(inputs))
    if logger is not None:
        _log("cls._states.pop()")
    state = cls._states.pop()

    if logger is not None:
        _log("create OrtValueVector (through dlpack)")

    backward_inputs = cls.from_torch_to_ort(grad_outputs)

    backward_outputs = OrtValueVector()
    if logger is not None:
        _log("run_backward")
    cls._training_agent.run_backward(backward_inputs, backward_outputs, state)
    res = cls.from_ort_to_torch(backward_outputs)
    if len(res) == 1:
        res = res[0]
    else:
        if cls._debug:  # pragma: no cover
            print("DEBUG")
            for i, ov in enumerate(backward_outputs):
                print("BCK-RET: i=%d - ptr=%r - shape=%r" % (
                    i, ov.shape(), ov.data_ptr()))
        if logger is not None:
            _log("got %r gradients" % len(res))
    if logger is not None:
        _log("end")
    return res
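

# ort_forward pushes a PartialGraphExecutionState onto cls._states and
# ort_backward pops it, so every forward call made with gradients enabled
# must be matched by exactly one backward call on the same generated class.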


class TorchOrtFactory:
    """
    A class which dynamically creates another class implementing a
    custom function (see :epkg:`autograd functions`).
    Use ONNX inside a torch function. Only initializers
    can be trained, no parameters.

    :param onnx_model: onnx model
    :param weights_to_train: names of the weights to train
    :param input_names: input names or None for all
    :param output_names: output names or None for all
    :param class_name: class name
    :param sess_options: see :epkg:`SessionOptions`
    :param providers: see :epkg:`InferenceSession`
    :param provider_options: see :epkg:`InferenceSession`
    :param run_options: see :epkg:`RunOptions`
    :param graph_builder_config:
        see :epkg:`OrtModuleGraphBuilderConfiguration`
    :param device_index: used for cuda (0 for `cuda:0`,
        `cuda:1`, ...), 0 by default

    .. note::
        The current implementation of :epkg:`onnxruntime` forces
        the weights to train to appear in alphabetical order.
        The constructor checks that this condition is met.

    .. warning::
        This class does not consider subgraphs.
    """

    def __init__(self, onnx_model, weights_to_train,
                 input_names=None, output_names=None,
                 class_name=None,
                 sess_options=None, providers=None,
                 provider_options=None, run_options=None,
                 graph_builder_config=None,
                 device_index=0):
        self.onnx_model = onnx_model
        self.input_names = input_names
        self.output_names = output_names
        self.class_name = class_name
        self.weights_to_train = weights_to_train
        self.device_index = device_index

        self.provider_options = provider_options
        self.sess_options = sess_options
        self.providers = providers
        self.run_options = run_options
        self.graph_builder_config = graph_builder_config

        # default
        if self.weights_to_train is None:
            raise ValueError(  # pragma: no cover
                "weights_to_train must be specified.")
        if self.input_names is None:
            self.input_names = [obj.name
                                for obj in self.onnx_model.graph.input]
        if self.output_names is None:
            self.output_names = [obj.name
                                 for obj in self.onnx_model.graph.output]
        if self.class_name is None:
            self.class_name = "TorchOrtFunction_%r" % id(self)
        if hasattr(self.providers, 'type'):
            if self.providers.type != 'cpu':
                self.device_index = self.providers.index
            self.providers = self.providers.type
        if self.providers in (None, 'cpu'):
            self.providers = ["CPUExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        elif self.providers in ('cuda', 'cuda:0', 'gpu'):
            self.providers = [
                "CUDAExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        if self.run_options is None:
            self.run_options = RunOptions()
            self.run_options.training_mode = True

        if len(self.input_names) != len(self.providers):
            raise ValueError(  # pragma: no cover
                "input_names and providers must have the same length.")
        if len(self.input_names) != len(self.provider_options):
            raise ValueError(  # pragma: no cover
                "input_names and provider_options must have the same length.")

        if list(sorted(self.weights_to_train)) != self.weights_to_train:
            raise ValueError(
                "List of weights to train must be sorted but is not in %r. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % self.weights_to_train)

        if self.graph_builder_config is None:
            initializer_names = [
                i.name for i in self.onnx_model.graph.initializer]
            input_names = [i.name for i in self.onnx_model.graph.input]

            config = OrtModuleGraphBuilderConfiguration()
            config.initializer_names = [init for init in initializer_names
                                        if init in self.weights_to_train]
            config.initializer_names_to_train = self.weights_to_train
            config.input_names_require_grad = input_names
            config.build_gradient_graph = True

            if (len(config.initializer_names) !=  # noqa
                    len(config.initializer_names_to_train)):
                raise RuntimeError(
                    "Unable to automatically fill "
                    "OrtModuleGraphBuilderConfiguration, mismatch between "
                    "%r and %r." % (config.initializer_names,
                                    config.initializer_names_to_train))

            p = TrainingGraphTransformerConfiguration()
            config.graph_transformer_config = p

            # config.enable_caching = True
            # config.loglevel =
            # config.use_memory_efficient_gradient = True
            self.graph_builder_config = config

    def __repr__(self):
        "usual"
        return "%s(...)" % self.__class__.__name__

    @staticmethod
    def _repr_helper_(obj, indent=0):
        "used to improve logging messages"
        if obj is None:
            return 'None'
        rows = []
        for c in sorted(dir(obj)):
            if c[0] == '_':
                continue
            try:
                value = getattr(obj, c)
            except AttributeError:  # pragma: no cover
                continue
            rows.append("%s=%r" % (c, value))

        if indent == 0:
            return "%s(%s)" % (obj.__class__.__name__, ", ".join(rows))
        return "%s(\n    %s)" % (
            obj.__class__.__name__,
            "\n    ".join(rows))

    @staticmethod
    def _provider_name_to_device_type(provider_name):
        if provider_name == 'CPUExecutionProvider':
            return OrtDevice.cpu()
        if provider_name == 'CUDAExecutionProvider':  # pragma: no cover
            return OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            'Unexpected provider name %r.' % provider_name)

    def create_class(self, enable_logging=False, keep_models=False,
                     debug=False):
        """
        Creates a class which inherits from
        :func:`torch.autograd.Function` and implements forward,
        backward methods using ONNX. The function dynamically
        creates a new class and pushes every needed object
        as a static attribute of the new class.

        :param enable_logging: used to debug, logs every building step
            at info level, logs information while processing forward
            and backward at debug level
        :param keep_models: stores additional information as
            static attributes
        :param debug: display information
        :return: a new class

        The pattern follows the documentation described in
        :epkg:`autograd functions`. Methods forward and backward
        are replaced by onnx implementations, the runtime is
        :epkg:`onnxruntime-training`.

        ::

            class CustomClass(torch.autograd.Function):

                @staticmethod
                def forward(ctx, *input):
                    ctx.save_for_backward(*input)
                    return ...

                @staticmethod
                def backward(ctx, *grad_output):
                    input, = ctx.saved_tensors
                    grad_input = grad_output.clone()
                    grad_input[input < 0] = 0
                    return grad_input

        The new class has the following attributes:

        * `__doc__`: doc string
        * `__module__`: module name (this file)
        * `_run_options`: see :epkg:`RunOptions`
        * `_sess`: :epkg:`InferenceSession` with the original graph
        * `_sess_eval`: :epkg:`InferenceSession` on the graph
          with weights as inputs
        * `_training_agent`: :epkg:`TrainingAgent`
        * `_cache`: :epkg:`OrtValueCache`
        * `_update_cache`: update the cache or not
        * `_states`: a list
        * `_logger`: logger
        * `_input_names`: input names
        * `_debug`: use debug mode
        * `_grad_input_names`: gradient input names
        * `_output_names`: output names
        * `_weights_to_train`: names of the weights to train

        Torch API:

        * `forward`: forward static method
        * `backward`: backward static method

        Training attributes:

        * `_bw_fetches_names`: bw_fetches_names
        * `_fw_outputs_device_info`: fw_outputs_device_info
        * `_bw_outputs_device_info`: bw_outputs_device_info
        * `_fw_no_grad_output_device_info`: fw_no_grad_output_device_info
        * `_graph_info`: graph_info

        Additional attributes added if *keep_models* is True:

        * `_trained_onnx`: ONNX graph for the gradient
        * `_optimized_pre_grad_model`: evaluation ONNX graph taking
          weights as inputs
        * `_graph_builder`: :epkg:`OrtModuleGraphBuilder`
        """
        if enable_logging:
            logger = logging.getLogger("deeponnxcustom")
        else:
            logger = None  # pragma: no cover

        doc = dedent("""Use onnxruntime to compute the gradient
                     in a pytorch function.""")

        if logger is not None:
            logger.info("[TorchOrtFactory] create training onnx")
            logger.info("[TorchOrtFactory] input_names=%r",
                        self.input_names)
            logger.info("[TorchOrtFactory] output_names=%r",
                        self.output_names)
            logger.info("[TorchOrtFactory] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.initialize")
            logger.info(
                "[TorchOrtFactory] graph_builder_config=%s",
                TorchOrtFactory._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[TorchOrtFactory] graph_builder_config."
                "graph_transformer_config=%s",
                TorchOrtFactory._repr_helper_(cf, indent=4))
            logger.info(
                "[TorchOrtFactory] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                TorchOrtFactory._repr_helper_(cfp, indent=4))

        builder.initialize(
            self.onnx_model.SerializeToString(),
            self.graph_builder_config)

        if logger is not None:
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.get_model")

        train_onnx_model_serialized = builder.get_model()

        optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info("[TorchOrtFactory] graph_info=%s",
                        TorchOrtFactory._repr_helper_(
                            graph_info, indent=4))
            logger.info("[TorchOrtFactory] create TrainSession")
            logger.info("[TorchOrtFactory] sess_options=%s",
                        TorchOrtFactory._repr_helper_(
                            self.sess_options, indent=4))
            logger.info("[TorchOrtFactory] providers=%r", self.providers)

        sess = InferenceSession(
            train_onnx_model_serialized,
            sess_options=self.sess_options,
            provider_options=self.provider_options,
            providers=self.providers)

        if logger is not None:
            logger.info("[TorchOrtFactory] create InferenceSession")

        sess_eval = InferenceSession(
            optimized_pre_grad_model,
            sess_options=self.sess_options,
            provider_options=self.provider_options,
            providers=self.providers)

        if logger is not None:
            logger.info("[TorchOrtFactory] create training agent")

        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        fw_outputs_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(i),
                OrtDevice.default_memory(),
                self.device_index)
            for i in self.providers]
        bw_outputs_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(),
                self.device_index)
            for i in bw_fetches_names]
        fw_no_grad_output_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(),
                self.device_index)
            for i in self.output_names]

        training_agent = TrainingAgent(
            sess._sess,
            grad_input_names,
            fw_outputs_device_info,
            bw_fetches_names,
            bw_outputs_device_info)

        if logger is not None:
            logger.info(
                "[TorchOrtFactory] instantiate dynamic class %r",
                self.class_name)
            logger.info(
                "[TorchOrtFactory] weights_to_train=%r",
                self.weights_to_train)
            logger.info(
                "[TorchOrtFactory] grad_input_names=%r",
                grad_input_names)
            logger.info(
                "[TorchOrtFactory] bw_fetches_names=%r",
                bw_fetches_names)
            logger.info(
                "[TorchOrtFactory] device_index=%r",
                self.device_index)

        kwargs = {
            '__doc__': doc,
            '__module__': __name__,
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_update_cache': False,
            '_states': [],
            '_logger': logger,
            '_input_names': self.input_names,
            '_debug': debug,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(
                self.weights_to_train)),
            'forward': staticmethod(ort_forward),
            'backward': staticmethod(ort_backward),
            '_graph_info': graph_info}

        if keep_models:
            kwargs.update(dict(
                _trained_onnx=onnx.load(BytesIO(train_onnx_model_serialized)),
                _optimized_pre_grad_model=onnx.load(
                    BytesIO(optimized_pre_grad_model)),
                _graph_builder=builder,
                _factory=self))

            onx_inp = [o.name for o in kwargs['_trained_onnx'].graph.input]
            onx_out = [o.name for o in kwargs['_trained_onnx'].graph.output]
            if len(onx_inp) != len(onx_out):
                raise RuntimeError(
                    "Gradient input and output are inconsistent: "
                    "%r != %r" % (onx_inp, onx_out))

        newclass = type(self.class_name, (TorchOrtFunction,), kwargs)
        return newclass
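

# A minimal end-to-end sketch (illustrative; assumes onnxruntime-training is
# installed and the ONNX initializers were renamed in alphabetical order,
# e.g. with onnx_rename_weights; names and shapes below are hypothetical):
#
#   import torch
#   fact = TorchOrtFactory(onx_model, weights_to_train=['I0_coef'])
#   TrainFct = fact.create_class()
#   x = torch.randn(4, 10)
#   w = torch.randn(10, 1, requires_grad=True)
#   y = TrainFct.apply(x, w)   # forward runs through onnxruntime
#   y.sum().backward()         # backward fills w.grad via the gradient graph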