# deeponnxcustom/onnxtorch/torchort.py

1""" 

2@file 

3@brief Experimental. 

4""" 

5import warnings 

6import logging 

7from textwrap import dedent 

8from io import BytesIO 

9import onnx 

10from onnxruntime import InferenceSession 

11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

12 SessionIOBinding) 

13try: 

14 from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

15 TrainingAgent, OrtValueCache, OrtModuleGraphBuilder, 

16 OrtModuleGraphBuilderConfiguration, OrtDevice, 

17 TrainingGraphTransformerConfiguration, OrtValueVector, 

18 PartialGraphExecutionState) 

19except ImportError: # pragma: no cover 

20 # onnxruntime-training is not installed. 

21 warnings.warn( 

22 "TorchOrtFactory cannot work without onnxruntime-training.") 

23from onnxruntime import RunOptions 

24from torch import is_grad_enabled # pylint: disable=E0611 

25from torch.autograd import Function 

26from torch.utils.dlpack import from_dlpack, to_dlpack 

27from torch._C import _from_dlpack 



class TorchOrtFunction(Function):
    """
    Ancestor to all classes created by @see cl TorchOrtFactory.
    It implements simple functions to move the ownership of tensors
    from *onnxruntime* to *pytorch* (or the other way around)
    through :epkg:`DLPack` structures.
    This class requires :epkg:`onnxruntime_training`.

    .. faqref::
        :title: Differences between onnxruntime and onnxruntime-training

        onnxruntime-training is an extension of onnxruntime
        that supports training. Version 1.10 is obtained by compiling
        onnxruntime from source with different flags.
        One example:

        ::

            python ./tools/ci_build/build.py --build_dir ./build/debian \\
                --config Release --build_wheel --numpy_version= \\
                --skip_tests --build_shared_lib --enable_training \\
                --enable_training_ops --enable_training_torch_interop \\
                --parallel
    """

    @staticmethod
    def from_torch_to_ort(tensors):
        "Converts a list of pytorch tensors into an OrtValueVector."
        vect = OrtValueVector()
        vect.reserve(len(tensors))
        for t in tensors:
            if t is None:
                # if gradient then
                # grad_output = torch.zeros(shape, device=device, dtype=dtype)
                raise NotImplementedError(  # pragma: no cover
                    "Empty vector found.")
            if not t.is_contiguous():
                # grad = grad.contiguous()
                raise NotImplementedError(  # pragma: no cover
                    "Non contiguous gradient found.")
            vect.push_back(to_dlpack(t), False)
        return vect

    @staticmethod
    def from_ort_to_torch(ort_values):
        "Converts an OrtValueVector into a tuple of pytorch tensors."
        # return tuple(_from_dlpack(ov.to_dlpack()) for ov in ort_values)
        if hasattr(ort_values, 'to_dlpack'):
            return tuple(ort_values.to_dlpack(_from_dlpack))
        if len(ort_values) == 0:
            raise RuntimeError(  # pragma: no cover
                "The conversion fails on an empty vector.")
        if hasattr(ort_values[0], '__dlpack__'):
            return tuple(  # pragma: no cover
                from_dlpack(ov) for ov in ort_values)
        return tuple(_from_dlpack(ov.to_dlpack()) for ov in ort_values)
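
    # A minimal round-trip sketch (illustrative, not executed here), assuming
    # a contiguous CPU tensor ``t``; both directions go through DLPack, so
    # only ownership moves, no data is copied:
    #
    #     import torch
    #     t = torch.randn(2, 3)
    #     vect = TorchOrtFunction.from_torch_to_ort([t])
    #     back, = TorchOrtFunction.from_ort_to_torch(vect)
    #     assert back.shape == t.shape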



def ort_forward(ctx, *inputs):
    """
    Implements the forward function.
    See :epkg:`autograd functions`.
    """
    cls = ctx._forward_cls
    logger = cls._logger
    training = is_grad_enabled() or any(ctx.needs_input_grad)

    def _log(msg):
        logger.debug("[%s.forward] (%dI) %s" % (
            cls.__name__, len(inputs), msg))

    if logger is not None:
        if training:
            _log("begin with gradient")
        else:
            _log("begin")
        _log("torch function %r" % type(ctx))
        _log("ort class %r" % cls)
        _log("create OrtValueVector (through dlpack)")

    forward_inputs = cls.from_torch_to_ort(inputs)

    if training:
        forward_outputs = OrtValueVector()
        state = PartialGraphExecutionState()
        cls._states.append(state)
        if logger is not None:
            _log("run_forward")
        cls._training_agent.run_forward(
            forward_inputs, forward_outputs, state, cls._cache)

        ctx.save_for_backward(*inputs)

        if cls._update_cache:
            if logger is not None:
                _log("update_cache")
            raise NotImplementedError("Cache is not implemented.")

            # for i in range(self._cache_start, len(forward_outputs)):
            #     self.cache.insert(
            #         self._cached_node_arg_names[i - cls._cache_start],
            #         forward_outputs[i])
            # self._update_cache = False
            # if logger is not None:
            #     _log("to torch.tensor")
            # return tuple(_utils._ortvalue_to_torch_tensor
            #              (forward_outputs[i], device)
            #              for i in range(self._cache_start))

        else:
            if logger is not None:
                _log("to torch.tensor")
            res = cls.from_ort_to_torch(forward_outputs)
            if len(res) == 1:
                res = res[0]
            if logger is not None:
                _log("end")
            return res
    else:
        # what about bind_input (+ data_ptr)
        if len(forward_inputs) != len(cls._grad_input_names):
            raise RuntimeError(  # pragma: no cover
                "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % (
                    len(forward_inputs), len(cls._grad_input_names)))
        iobinding = SessionIOBinding(cls._sess_eval._sess)
        if logger is not None:
            _log("bind inputs %r" % cls._grad_input_names)
        for name, inp in zip(
                cls._grad_input_names, forward_inputs):
            iobinding.bind_ortvalue_input(name, inp)

        # bind output
        if logger is not None:
            _log("bind outputs %r" % cls._output_names)
        for name, dev in zip(
                cls._output_names, cls._fw_no_grad_output_device_info):
            iobinding.bind_output(name, dev)

        # if the shape is known in advance
        # iobinding.bind_output(
        #     output_desc.name, torch_tensor.device.type,
        #     _utils.get_device_index(target_device),
        #     _utils.dtype_torch_to_numpy(torch_tensor.dtype),
        #     list(torch_tensor.size()), torch_tensor.data_ptr())

        if logger is not None:
            _log("grad_enabled=False (run_with_iobinding)")
        cls._sess_eval._sess.run_with_iobinding(iobinding, cls._run_options)
        if logger is not None:
            _log("get_outputs")
        ortvalues = iobinding.get_outputs()
        if logger is not None:
            _log("to torch.tensor")
        res = cls.from_ort_to_torch(ortvalues)
        if len(res) == 1:
            res = res[0]
        if logger is not None:
            _log("end")
        return res



def ort_backward(ctx, *grad_outputs):
    """
    Implements the backward function.
    See :epkg:`autograd functions`.
    """
    cls = ctx._forward_cls
    logger = cls._logger

    def _log(msg):
        logger.debug("[%s.backward] (%dI) %s" % (
            cls.__name__, len(grad_outputs), msg))

    if logger is not None:
        _log("begin")
        _log("torch function %r" % type(ctx))
        _log("ort class %r" % cls)
        _log("saved_tensors")

    inputs = ctx.saved_tensors
    if cls._debug:
        print(  # pragma: no cover
            "DEBUG: saved_tensors %r" % type(inputs))
    if logger is not None:
        _log("cls._states.pop()")
    state = cls._states.pop()

    if logger is not None:
        _log("create OrtValueVector (through dlpack)")

    backward_inputs = cls.from_torch_to_ort(grad_outputs)

    backward_outputs = OrtValueVector()
    if logger is not None:
        _log("run_backward")
    cls._training_agent.run_backward(backward_inputs, backward_outputs, state)
    res = cls.from_ort_to_torch(backward_outputs)
    if len(res) == 1:
        res = res[0]
    else:
        if cls._debug:  # pragma: no cover
            print("DEBUG")
            for i, ov in enumerate(backward_outputs):
                print("BCK-RET: i=%d - shape=%r - ptr=%r" % (
                    i, ov.shape(), ov.data_ptr()))
        if logger is not None:
            _log("got %r gradients" % len(res))
    if logger is not None:
        _log("end")
    return res



class TorchOrtFactory:
    """
    A class which dynamically creates another class implementing a
    custom function (see :epkg:`autograd functions`).
    Use ONNX inside a torch function. Only initializers
    can be trained, no parameters.

    :param onnx_model: onnx model
    :param weights_to_train: names of the weights to train
    :param input_names: input names or None for all
    :param output_names: output names or None for all
    :param class_name: class name
    :param sess_options: see :epkg:`SessionOptions`
    :param providers: see :epkg:`InferenceSession`
    :param provider_options: see :epkg:`InferenceSession`
    :param run_options: see :epkg:`RunOptions`
    :param graph_builder_config:
        see :epkg:`OrtModuleGraphBuilderConfiguration`
    :param device_index: used for cuda (0 for `cuda:0`,
        1 for `cuda:1`, ...), 0 by default

    .. note::
        The current implementation of :epkg:`onnxruntime` forces
        the weights to train to appear in alphabetical order.
        The constructor checks that this condition is met.

    .. warning::
        This class does not consider subgraphs.
    """


    def __init__(self, onnx_model, weights_to_train,
                 input_names=None, output_names=None,
                 class_name=None,
                 sess_options=None, providers=None,
                 provider_options=None, run_options=None,
                 graph_builder_config=None,
                 device_index=0):
        self.onnx_model = onnx_model
        self.input_names = input_names
        self.output_names = output_names
        self.class_name = class_name
        self.weights_to_train = weights_to_train
        self.device_index = device_index

        self.provider_options = provider_options
        self.sess_options = sess_options
        self.providers = providers
        self.run_options = run_options
        self.graph_builder_config = graph_builder_config

        # default
        if self.weights_to_train is None:
            raise ValueError(  # pragma: no cover
                "weights_to_train must be specified.")
        if self.input_names is None:
            self.input_names = [obj.name
                                for obj in self.onnx_model.graph.input]
        if self.output_names is None:
            self.output_names = [obj.name
                                 for obj in self.onnx_model.graph.output]
        if self.class_name is None:
            self.class_name = "TorchOrtFunction_%r" % id(self)
        if hasattr(self.providers, 'type'):
            if self.providers.type != 'cpu':
                self.device_index = self.providers.index
            self.providers = self.providers.type
        if self.providers in (None, 'cpu'):
            self.providers = [
                "CPUExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        elif self.providers in ('cuda', 'cuda:0', 'gpu'):
            self.providers = [
                "CUDAExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        if self.run_options is None:
            self.run_options = RunOptions()
            self.run_options.training_mode = True

        if len(self.input_names) != len(self.providers):
            raise ValueError(  # pragma: no cover
                "input_names and providers must have the same length.")
        if len(self.input_names) != len(self.provider_options):
            raise ValueError(  # pragma: no cover
                "input_names and provider_options must have the same length.")

        if list(sorted(self.weights_to_train)) != self.weights_to_train:
            raise ValueError(
                "List of weights to train must be sorted but is not in %r. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % self.weights_to_train)

        if self.graph_builder_config is None:
            initializer_names = [
                i.name for i in self.onnx_model.graph.initializer]
            input_names = [i.name for i in self.onnx_model.graph.input]

            config = OrtModuleGraphBuilderConfiguration()
            config.initializer_names = [init for init in initializer_names
                                        if init in self.weights_to_train]
            config.initializer_names_to_train = self.weights_to_train
            config.input_names_require_grad = input_names
            config.build_gradient_graph = True

            if (len(config.initializer_names) !=  # noqa
                    len(config.initializer_names_to_train)):
                raise RuntimeError(
                    "Unable to automatically fill "
                    "OrtModuleGraphBuilderConfiguration, mismatch between "
                    "%r and %r." % (config.initializer_names,
                                    config.initializer_names_to_train))

            p = TrainingGraphTransformerConfiguration()
            config.graph_transformer_config = p

            # config.enable_caching = True
            # config.loglevel =
            # config.use_memory_efficient_gradient = True
            self.graph_builder_config = config


    def __repr__(self):
        "usual"
        return "%s(...)" % self.__class__.__name__

    @staticmethod
    def _repr_helper_(obj, indent=0):
        "used to improve logging messages"
        if obj is None:
            return 'None'
        rows = []
        for c in sorted(dir(obj)):
            if c[0] == '_':
                continue
            try:
                value = getattr(obj, c)
            except AttributeError:  # pragma: no cover
                continue
            rows.append("%s=%r" % (c, value))

        if indent == 0:
            return "%s(%s)" % (obj.__class__.__name__, ", ".join(rows))
        return "%s(\n    %s)" % (
            obj.__class__.__name__,
            "\n    ".join(rows))

    @staticmethod
    def _provider_name_to_device_type(provider_name):
        if provider_name == 'CPUExecutionProvider':
            return OrtDevice.cpu()
        if provider_name == 'CUDAExecutionProvider':  # pragma: no cover
            return OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            'Unexpected provider name %r.' % provider_name)


    def create_class(self, enable_logging=False, keep_models=False,
                     debug=False):
        """
        Creates a class which inherits from
        :func:`torch.autograd.Function` and implements the forward
        and backward methods using ONNX. The function dynamically
        creates a new class and pushes every needed object
        as a static attribute of the new class.

        :param enable_logging: used to debug, logs every building step
            at info level, logs information while processing forward
            and backward at debug level
        :param keep_models: stores additional information as
            static attributes
        :param debug: display information
        :return: a new class

        The pattern follows the documentation described in
        :epkg:`autograd functions`. Methods forward and backward
        are replaced by onnx implementations, the runtime is
        :epkg:`onnxruntime-training`.

        ::

            class CustomClass(torch.autograd.Function):

                @staticmethod
                def forward(ctx, *input):
                    ctx.save_for_backward(*input)
                    return ...

                @staticmethod
                def backward(ctx, *grad_output):
                    input, = ctx.saved_tensors
                    grad_input = grad_output.clone()
                    grad_input[input < 0] = 0
                    return grad_input

        The new class has the following attributes:

        * `__doc__`: doc string
        * `__module__`: module name (this file)
        * `_run_options`: see :epkg:`RunOptions`
        * `_sess`: :epkg:`InferenceSession` with the original graph
        * `_sess_eval`: :epkg:`InferenceSession` on the graph
          with weights as inputs
        * `_training_agent`: :epkg:`TrainingAgent`
        * `_cache`: :epkg:`OrtValueCache`
        * `_update_cache`: update the cache or not
        * `_states`: a list
        * `_logger`: logger
        * `_input_names`: input names
        * `_debug`: use debug mode
        * `_grad_input_names`: gradient input names
        * `_output_names`: output names
        * `_weights_to_train`: names of the weights to train

        Torch API:

        * `forward`: forward static method
        * `backward`: backward static method

        Training attributes:

        * `_bw_fetches_names`: bw_fetches_names
        * `_fw_outputs_device_info`: fw_outputs_device_info
        * `_bw_outputs_device_info`: bw_outputs_device_info
        * `_fw_no_grad_output_device_info`: fw_no_grad_output_device_info
        * `_graph_info`: graph_info

        Additional attributes added if *keep_models* is True:

        * `_trained_onnx`: ONNX graph for the gradient
        * `_optimized_pre_grad_model`: evaluation ONNX graph taking
          weights as inputs
        * `_graph_builder`: :epkg:`OrtModuleGraphBuilder`
        """

        if enable_logging:
            logger = logging.getLogger("deeponnxcustom")
        else:
            logger = None  # pragma: no cover

        doc = dedent("""Use onnxruntime to compute the gradient
        in a pytorch function.""")

        if logger is not None:
            logger.info("[TorchOrtFactory] create training onnx")
            logger.info("[TorchOrtFactory] input_names=%r",
                        self.input_names)
            logger.info("[TorchOrtFactory] output_names=%r",
                        self.output_names)
            logger.info("[TorchOrtFactory] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.initialize")
            logger.info(
                "[TorchOrtFactory] graph_builder_config=%s",
                TorchOrtFactory._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[TorchOrtFactory] graph_builder_config."
                "graph_transformer_config=%s",
                TorchOrtFactory._repr_helper_(cf, indent=4))
            logger.info(
                "[TorchOrtFactory] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                TorchOrtFactory._repr_helper_(cfp, indent=4))

        builder.initialize(
            self.onnx_model.SerializeToString(),
            self.graph_builder_config)

        if logger is not None:
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info("[TorchOrtFactory] OrtModuleGraphBuilder.get_model")

        train_onnx_model_serialized = builder.get_model()

        optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info("[TorchOrtFactory] graph_info=%s",
                        TorchOrtFactory._repr_helper_(
                            graph_info, indent=4))
            logger.info("[TorchOrtFactory] create TrainSession")
            logger.info("[TorchOrtFactory] sess_options=%s",
                        TorchOrtFactory._repr_helper_(
                            self.sess_options, indent=4))
            logger.info("[TorchOrtFactory] providers=%r", self.providers)

        sess = InferenceSession(
            train_onnx_model_serialized,
            sess_options=self.sess_options,
            provider_options=self.provider_options,
            providers=self.providers)

        if logger is not None:
            logger.info("[TorchOrtFactory] create InferenceSession")

        sess_eval = InferenceSession(
            optimized_pre_grad_model,
            sess_options=self.sess_options,
            provider_options=self.provider_options,
            providers=self.providers)

        if logger is not None:
            logger.info("[TorchOrtFactory] create training agent")

        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        fw_outputs_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(i),
                OrtDevice.default_memory(),
                self.device_index)
            for i in self.providers]
        bw_outputs_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(),
                self.device_index)
            for i in bw_fetches_names]
        fw_no_grad_output_device_info = [
            OrtDevice(
                TorchOrtFactory._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(),
                self.device_index)
            for i in self.output_names]

        training_agent = TrainingAgent(
            sess._sess,
            grad_input_names,
            fw_outputs_device_info,
            bw_fetches_names,
            bw_outputs_device_info)
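
        # Note (author's understanding): the TrainingAgent splits execution
        # of the gradient graph in two parts. run_forward evaluates the
        # forward half and stores intermediate tensors in a
        # PartialGraphExecutionState; run_backward resumes from that state
        # (see ort_forward and ort_backward above).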


        if logger is not None:
            logger.info(
                "[TorchOrtFactory] instantiate dynamic class %r",
                self.class_name)
            logger.info(
                "[TorchOrtFactory] weights_to_train=%r",
                self.weights_to_train)
            logger.info(
                "[TorchOrtFactory] grad_input_names=%r",
                grad_input_names)
            logger.info(
                "[TorchOrtFactory] bw_fetches_names=%r",
                bw_fetches_names)
            logger.info(
                "[TorchOrtFactory] device_index=%r",
                self.device_index)

        kwargs = {
            '__doc__': doc,
            '__module__': __name__,
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_update_cache': False,
            '_states': [],
            '_logger': logger,
            '_input_names': self.input_names,
            '_debug': debug,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(
                self.weights_to_train)),
            'forward': staticmethod(ort_forward),
            'backward': staticmethod(ort_backward),
            '_graph_info': graph_info}

        if keep_models:
            kwargs.update(dict(
                _trained_onnx=onnx.load(BytesIO(train_onnx_model_serialized)),
                _optimized_pre_grad_model=onnx.load(
                    BytesIO(optimized_pre_grad_model)),
                _graph_builder=builder,
                _factory=self))

            onx_inp = [o.name for o in kwargs['_trained_onnx'].graph.input]
            onx_out = [o.name for o in kwargs['_trained_onnx'].graph.output]
            if len(onx_inp) != len(onx_out):
                raise RuntimeError(
                    "Gradient input and output are inconsistent: "
                    "%r != %r" % (onx_inp, onx_out))

        newclass = type(self.class_name, (TorchOrtFunction,), kwargs)
        return newclass