Coverage for onnxcustom/training/ortgradient.py: 95%

330 statements  


# pylint: disable=E1101
"""
@file
@brief Gradient with :epkg:`onnxruntime-training` forward backward.
"""
import os
import logging
import warnings
from io import BytesIO
import onnx
from onnx.numpy_helper import to_array
from onnxruntime import InferenceSession, RunOptions
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    SessionIOBinding, OrtValue as C_OrtValue)
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    TrainingAgent, OrtValueCache, OrtModuleGraphBuilder,
    OrtModuleGraphBuilderConfiguration, OrtDevice,
    TrainingGraphTransformerConfiguration, OrtValueVector,
    PartialGraphExecutionState)
from ..utils.orttraining_helper import get_train_initializer


class OrtGradientForwardBackward:
    """
    Implements a forward backward mechanism assuming the function
    to train is defined by an ONNX graph.

    :param onnx_model: onnx model
    :param weights_to_train: names of the weights to train,
        if None, all initializers of float type are included in the list
    :param input_names: input names or None for all
    :param output_names: output names or None for all
    :param class_name: name to give the dynamically created class
    :param sess_options: see :epkg:`SessionOptions`
    :param providers: see :epkg:`InferenceSession`
    :param provider_options: see :epkg:`InferenceSession`
    :param run_options: see :epkg:`RunOptions`
    :param graph_builder_config:
        see :epkg:`OrtModuleGraphBuilderConfiguration`
    :param device_index: used for cuda (0 for `cuda:0`,
        1 for `cuda:1`, ...), 0 by default
    :param enable_logging: enables logging while setting up the class
    :param debug: runs extra verifications while training

    .. note::
        The current implementation of :epkg:`onnxruntime` forces
        the weights to train to appear in alphabetical order.
        The constructor checks that this condition is met.

    .. warning::
        This class does not consider subgraphs.
    """

    def __init__(self, onnx_model, weights_to_train=None,
                 input_names=None, output_names=None, class_name=None,
                 sess_options=None, providers=None,
                 provider_options=None, run_options=None,
                 graph_builder_config=None,
                 device_index=0, enable_logging=False, debug=False):

        if weights_to_train is None:
            weights_to_train = (
                OrtGradientForwardBackward._select_initializer_names(
                    onnx_model))
            if len(weights_to_train) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess the weights to train from initializers: "
                    "%r." % [i.name for i in onnx_model.graph.initializer])

        self.onnx_model = onnx_model
        self.input_names = input_names
        self.output_names = output_names
        self.weights_to_train = weights_to_train
        self.device_index = device_index
        self.enable_logging = enable_logging
        self.class_name = (class_name if class_name is not None else
                           "OrtGradientForwardBackwardFunction_%d" % id(self))

        self.provider_options = provider_options
        self.sess_options = sess_options
        self.providers = providers
        self.run_options = run_options
        self.graph_builder_config = graph_builder_config
        self.debug = debug

        # default
        if self.weights_to_train is None:
            raise ValueError(  # pragma: no cover
                "weights_to_train must be specified.")
        if self.input_names is None:
            self.input_names = [obj.name
                                for obj in self.onnx_model.graph.input]
        if self.output_names is None:
            self.output_names = [obj.name
                                 for obj in self.onnx_model.graph.output]
        if self.class_name is None:
            self.class_name = f"TorchOrtFunction_{id(self)!r}"  # pragma: no cover
        if hasattr(self.providers, 'type'):
            if self.providers.type != 'cpu':
                self.device_index = self.providers.index
            self.providers = self.providers.type
        if self.providers in (None, 'cpu'):
            self.providers = ["CPUExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        elif self.providers in ('cuda', 'cuda:0', 'gpu'):
            self.providers = [
                "CUDAExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        if self.provider_options is None:
            self.provider_options = [{} for i in self.providers]

        if list(sorted(self.weights_to_train)) != self.weights_to_train:
            raise ValueError(  # pragma: no cover
                "List of weights to train must be sorted but %r is not. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % self.weights_to_train)
        set_weights = set(self.weights_to_train)
        if len(set_weights) != len(self.weights_to_train):
            raise ValueError(  # pragma: no cover
                f"One weight is not unique in {self.weights_to_train!r}.")
        found = []
        for i in self.onnx_model.graph.initializer:
            if i.name not in set_weights:
                continue
            found.append(i.name)
        if len(found) != len(self.weights_to_train):
            raise ValueError(
                "One weight name in self.weights_to_train was not found in "
                "the initializers %r found=%r init names=%r." % (
                    self.weights_to_train, found,
                    [i.name for i in self.onnx_model.graph.initializer]))
        if found != self.weights_to_train:
            raise ValueError(
                "List of weights to train must be sorted and follow the "
                "same order as the initializers in the graph. %r != %r. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % (
                    self.weights_to_train, found))

        if any(map(lambda v: v not in ['CPUExecutionProvider',
                                       'CUDAExecutionProvider'],
                   self.providers)):
            raise ValueError(
                f"Unexpected providers {self.providers!r} "
                f"(providers={providers!r}).")

        # complete initialisation
        self._init_next()
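
    # The weights to train must be alphabetically sorted. A minimal sketch of
    # how a caller might satisfy that requirement with the helper mentioned in
    # the error messages above (the exact import path of onnx_rename_weights
    # is an assumption and may differ across versions of onnxcustom):
    #
    #     from onnxcustom.utils.onnx_helper import onnx_rename_weights
    #     onx = onnx_rename_weights(onx)  # renames initializers I0_..., I1_...
    #     fb = OrtGradientForwardBackward(onx)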

    @staticmethod
    def _select_initializer_names(onnx_model):
        """
        Selects all initializers with float type.

        :param onnx_model: ONNX graph
        """
        inits = get_train_initializer(onnx_model)
        return list(inits)

    def _init_next(self):
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None  # pragma: no cover
        if self.run_options is None:
            self.run_options = RunOptions()
            self.run_options.training_mode = True

        if self.graph_builder_config is None:
            initializer_names = [
                i.name for i in self.onnx_model.graph.initializer]
            input_names = [i.name for i in self.onnx_model.graph.input]

            config = OrtModuleGraphBuilderConfiguration()
            config.initializer_names = [init for init in initializer_names
                                        if init in self.weights_to_train]
            config.initializer_names_to_train = self.weights_to_train
            config.input_names_require_grad = input_names
            config.build_gradient_graph = True

            if (len(config.initializer_names) !=  # noqa
                    len(config.initializer_names_to_train)):
                raise RuntimeError(  # pragma: no cover
                    "Unable to automatically fill "
                    "OrtModuleGraphBuilderConfiguration, mismatch between "
                    "%r and %r (initializer_names=%r)." % (
                        config.initializer_names,
                        config.initializer_names_to_train,
                        initializer_names))

            p = TrainingGraphTransformerConfiguration()
            config.graph_transformer_config = p

            # config.enable_caching = True
            # config.loglevel =
            # config.use_memory_efficient_gradient = True
            self.graph_builder_config = config

        attributes = self._create_onnx_graphs()
        attributes['__doc__'] = (
            "Inherits from @see cl OrtGradientForwardBackwardFunction.")
        attributes['__module__'] = (
            OrtGradientForwardBackwardFunction.__module__)
        self.cls_type_ = type(
            self.class_name, (OrtGradientForwardBackwardFunction,),
            attributes)

    def new_instance(self):
        """
        Creates an instance of class `self.cls_type_`.
        It implements methods *forward* and *backward*.
        """
        return self.cls_type_()
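
    # Minimal construction sketch (illustrative): the model file name is
    # hypothetical, and the ONNX graph is expected to compute the quantity
    # to differentiate.
    #
    #     import onnx
    #     onx = onnx.load("model.onnx")
    #     fb = OrtGradientForwardBackward(onx)
    #     inst = fb.new_instance()  # implements forward and backward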

    def __getstate__(self):
        "Removes any non-picklable attribute."
        atts = [k for k in self.__dict__ if not k.endswith('_')
                if k not in {'_logger', 'graph_builder_config',
                             'run_options'}]
        state = {att: getattr(self, att) for att in atts}
        state['run_options'] = None
        state['graph_builder_config'] = None
        return state

    def __setstate__(self, state):
        "Restores any non-picklable attribute."
        for att, v in state.items():
            setattr(self, att, v)
        self._init_next()
        return self
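
    # Pickling sketch (illustrative): __getstate__ drops the non-picklable
    # onnxruntime objects and __setstate__ rebuilds them through _init_next(),
    # so an instance survives a round trip (assuming `fb` as in the sketch
    # above):
    #
    #     import pickle
    #     fb2 = pickle.loads(pickle.dumps(fb))
    #     inst2 = fb2.new_instance()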

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}(...)"

    @staticmethod
    def _repr_helper_(obj, indent=0):
        "used to improve logging messages"
        if obj is None:
            return 'None'
        rows = []
        for c in sorted(dir(obj)):
            if c[0] == '_':
                continue
            try:
                value = getattr(obj, c)
            except AttributeError:  # pragma: no cover
                continue
            rows.append(f"{c}={value!r}")

        if indent == 0:
            return f"{obj.__class__.__name__}({', '.join(rows)})"
        return "%s(\n    %s)" % (
            obj.__class__.__name__,
            "\n    ".join(rows))

    @staticmethod
    def _provider_name_to_device_type(provider_name):
        if provider_name == 'CPUExecutionProvider':
            return OrtDevice.cpu()
        if provider_name == 'CUDAExecutionProvider':  # pragma: no cover
            return OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            f'Unexpected provider name {provider_name!r}.')

    def get_initializer(self, name, exc=True):
        """
        Returns an initializer as a numpy array.

        :param name: initializer name
        :param exc: raises an exception if not found, otherwise returns None
        :return: the initializer as a numpy array
        """
        for init in self.onnx_model.graph.initializer:
            if name == init.name:
                return to_array(init)
        if exc:
            raise RuntimeError(  # pragma: no cover
                "Unable to find name %r in %r." % (
                    name,
                    list(i.name for i in self.onnx_model.graph.initializer)))
        return None

    def _create_onnx_graphs(self):
        """
        Creates the forward and backward ONNX graphs.
        The new class has the following attributes:

        * `__doc__`: doc string
        * `__module__`: module name (this file)
        * `_run_options`: see :epkg:`RunOptions`
        * `_sess`: :epkg:`InferenceSession` with the original graph
        * `_sess_eval`: :epkg:`InferenceSession` on the graph
          with weights as inputs
        * `_training_agent`: :epkg:`TrainingAgent`
        * `_cache`: :epkg:`OrtValueCache`
        * `_logger`: logger
        * `_input_names`: input names
        * `_debug`: use debug mode
        * `_grad_input_names`: gradient input names
        * `_output_names`: output names
        * `_weights_to_train`: names of the weights to train

        Training attributes:

        * `_bw_fetches_names`: bw_fetches_names
        * `_fw_outputs_device_info`: fw_outputs_device_info
        * `_bw_outputs_device_info`: bw_outputs_device_info
        * `_fw_no_grad_output_device_info`: fw_no_grad_output_device_info
        * `_graph_info`: graph_info

        Additional attributes added if *keep_model* is True:

        * `_trained_onnx`: ONNX graph for the gradient
        * `_optimized_pre_grad_model`: evaluation ONNX graph taking
          weights as inputs
        * `_graph_builder`: :epkg:`OrtModuleGraphBuilder`
        """
        logger = self._logger
        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training onnx")
            logger.info("[OrtGradientForwardBackward] input_names=%r",
                        self.input_names)
            logger.info("[OrtGradientForwardBackward] output_names=%r",
                        self.output_names)
            logger.info("[OrtGradientForwardBackward] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info(
                "[OrtGradientForwardBackward] "
                "OrtModuleGraphBuilder.initialize")
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config=%s",
                OrtGradientForwardBackward._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config=%s",
                OrtGradientForwardBackward._repr_helper_(cf, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                OrtGradientForwardBackward._repr_helper_(cfp, indent=4))

        builder.initialize(
            self.onnx_model.SerializeToString(),
            self.graph_builder_config)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.get_model")

        try:
            train_onnx_model_serialized = builder.get_gradient_model()
        except AttributeError:
            # older version
            train_onnx_model_serialized = builder.get_model()
        try:
            optimized_pre_grad_model = builder.get_forward_model()
        except AttributeError:
            # older version
            optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] graph_info=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            graph_info, indent=4))
            logger.info("[OrtGradientForwardBackward] create TrainSession")
            logger.info("[OrtGradientForwardBackward] sess_options=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            self.sess_options, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] providers=%r", self.providers)

        sess = InferenceSession(
            train_onnx_model_serialized, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create InferenceSession")

        sess_eval = InferenceSession(
            optimized_pre_grad_model, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training agent")

        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        fw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(i),
                OrtDevice.default_memory(), self.device_index)
            for i in self.providers]
        bw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(), self.device_index)
            for i in bw_fetches_names]
        fw_no_grad_output_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(), self.device_index)
            for i in self.output_names]

        try:
            # onnxruntime>=1.12
            training_agent = TrainingAgent(
                sess._sess,
                grad_input_names,
                fw_outputs_device_info,
                bw_fetches_names,
                bw_outputs_device_info,
                0)
        except TypeError:
            # onnxruntime<=1.11
            training_agent = TrainingAgent(
                sess._sess,
                grad_input_names,
                fw_outputs_device_info,
                bw_fetches_names,
                bw_outputs_device_info)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] instantiate dynamic class %r",
                self.class_name)
            logger.info(
                "[OrtGradientForwardBackward] weights_to_train=%r",
                self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackward] grad_input_names=%r",
                grad_input_names)
            logger.info(
                "[OrtGradientForwardBackward] bw_fetches_names=%r",
                bw_fetches_names)
            logger.info(
                "[OrtGradientForwardBackward] device_index=%r",
                self.device_index)
        devices = list(fw_outputs_device_info)
        while len(devices) < len(grad_input_names):
            devices.append(devices[-1])

        trained_onnx = onnx.load(BytesIO(train_onnx_model_serialized))
        onnx_loss = onnx.load(BytesIO(optimized_pre_grad_model))
        for i, node in enumerate(trained_onnx.graph.node):
            if node.name == '':
                node.name = "N%d" % i
        for i, node in enumerate(onnx_loss.graph.node):
            if node.name == '':
                node.name = "N%d" % i

        kwargs = {
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_logger': logger,
            '_input_names': self.input_names,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(self.weights_to_train)),
            '_graph_info': graph_info,
            #
            '_trained_onnx': trained_onnx,
            '_optimized_pre_grad_model': onnx_loss,
            '_graph_builder': builder,
            '_devices': devices,
            '_debug': self.debug
        }
        graph = kwargs['_trained_onnx'].graph
        kwargs.update({
            '_onx_inp': [o.name for o in graph.input],
            '_onx_out': [o.name for o in graph.output]
        })

        if len(kwargs['_onx_inp']) != len(kwargs['_onx_out']):
            raise RuntimeError(  # pragma: no cover
                "Gradient input and output are inconsistent: "
                "%r != %r" % (kwargs['_onx_inp'], kwargs['_onx_out']))
        return kwargs


class OrtGradientForwardBackwardFunction:
    """
    Ancestor for a class implementing forward and backward
    and dynamically created by @see cl OrtGradientForwardBackward.

    Attributes stored by the *forward* method:

    * `saved_tensors_`: list of tensors saved during forward
      and retrieved during backward
    * `states_`: list of :epkg:`PartialGraphExecutionState`,
      appended by *forward* and popped by *backward*
    """

    def __init__(self):
        self.states_ = []
        self.saved_tensors_ = None

    @classmethod
    def save_onnx_graph(cls, folder, prefix=None, suffix=None):
        """
        Saves the ONNX graphs stored in this class.
        """
        if prefix is None:
            prefix = ''  # pragma: no cover
        if suffix is None:
            suffix = ''  # pragma: no cover
        if isinstance(folder, str) and not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                f"Folder {folder!r} does not exist.")
        saved = {}
        for k, v in cls.__dict__.items():
            if hasattr(v, "SerializeToString"):
                if isinstance(folder, str):
                    name = f"{prefix}{cls.__name__}{suffix}.{k}.onnx"
                    filename = os.path.join(folder, name)
                    if os.path.exists(filename):
                        warnings.warn(  # pragma: no cover
                            f"Filename {filename!r} already exists.")
                    with open(filename, "wb") as f:
                        f.write(v.SerializeToString())
                    saved[k] = filename
                else:
                    saved[k] = v.SerializeToString()
            elif hasattr(v, "save_onnx_graph"):
                saved[k] = v.save_onnx_graph(
                    folder, prefix=prefix, suffix=f"{suffix}.{k}")
        return saved
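
    # Saving sketch (illustrative): dumps every ONNX graph stored on the
    # dynamically created class, e.g. `_trained_onnx` and
    # `_optimized_pre_grad_model`. Assumes `inst` was created as in the
    # earlier sketch and that a folder named "dump" exists.
    #
    #     files = inst.__class__.save_onnx_graph("dump")
    #     # returns a dict mapping attribute names to the written file names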

    @staticmethod
    def device_name(device):
        """
        Returns the device name of a device.

        :param device: OrtDevice
        :return: string
        """
        if device.device_type() == OrtDevice.cpu():
            return 'Cpu'
        if device.device_type() == OrtDevice.cuda():  # pragma: no cover
            return 'Gpu'
        raise RuntimeError(  # pragma: no cover
            f"Unexpected value for device type {device.device_type()!r}.")

    @staticmethod
    def input_to_ort(tensors, devices, debug):
        "Converts a list of tensors into an :epkg:`OrtValueVector`."
        def _validate_(tensors):
            if any(map(
                    lambda tu: (
                        tu[0].device_name() !=
                        OrtGradientForwardBackwardFunction.device_name(
                            tu[1])),
                    zip(tensors, devices))):
                raise RuntimeError(  # pragma: no cover
                    "Not all inputs are on the same device %r != %r." % (
                        [OrtGradientForwardBackwardFunction.device_name(d)
                         for d in devices],
                        [x.device_name() for x in tensors]))

        if isinstance(tensors, OrtValueVector):
            if debug:
                _validate_(tensors)
            return tensors
        if all(map(lambda t: isinstance(t, C_OrtValue), tensors)):
            if debug:
                _validate_(tensors)
            vect = OrtValueVector()
            vect.reserve(len(tensors))
            for t in tensors:
                if t is None:
                    raise NotImplementedError(  # pragma: no cover
                        "Empty vector found.")
                vect.push_back(t)
            return vect

        # generic case
        vect = OrtValueVector()
        vect.reserve(len(tensors))
        for t, dev in zip(tensors, devices):
            if t is None:
                # if gradient then
                # grad_output = torch.zeros(shape, device=device, dtype=dtype)
                raise NotImplementedError(  # pragma: no cover
                    "Empty vector found.")
            if not t.data.contiguous:
                t = t.as_contiguous()  # pragma: no cover
            vect.push_back(C_OrtValue.ortvalue_from_numpy(t, dev))
        if debug:
            if len(vect) != len(tensors):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected array length %d != %d (len(devices)=%d)." % (
                        len(vect), len(tensors), len(devices)))
            _validate_(vect)
        return vect
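
    # Conversion sketch (illustrative) for the generic path above: wrap numpy
    # arrays into an OrtValueVector on CPU. The arrays `x` and `y` are
    # hypothetical.
    #
    #     import numpy
    #     dev = OrtDevice(OrtDevice.cpu(), OrtDevice.default_memory(), 0)
    #     x = numpy.random.randn(2, 3).astype(numpy.float32)
    #     y = numpy.random.randn(2, 3).astype(numpy.float32)
    #     vect = OrtGradientForwardBackwardFunction.input_to_ort(
    #         [x, y], [dev, dev], debug=True)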

    def save_for_backward(self, inputs):
        """
        Saves inputs during forward steps. The list of inputs
        is copied (simple copy, no deep copy).

        :param inputs: list of tensors to save
        """
        self.saved_tensors_ = list(inputs)

    @property
    def saved_tensors(self):
        """
        Returns the tensors saved during the forward step.
        """
        if self.saved_tensors_ is None:
            raise RuntimeError(  # pragma: no cover
                "No tensors were saved with save_for_backward.")
        return self.saved_tensors_

    def forward(self, inputs, training=False, forward_outputs_cache=None):
        """
        Implements the forward function.

        :param inputs: inputs
        :param training: only inference or training as well
        :param forward_outputs_cache: an :epkg:`OrtValueVector` to reuse
            for the outputs, a new one is created if None
        :return: outputs as an :epkg:`OrtValueVector`
        """
        logger = self._logger
        cls = self.__class__

        def _log(msg, *args):
            logger.debug("[%s.forward] (%dI) " + msg,
                         cls.__name__, len(inputs), *args)

        if logger is not None:
            _log("begin with gradient" if training else "begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("create OrtValueVector (through dlpack)")

        forward_inputs = cls.input_to_ort(inputs, cls._devices, cls._debug)

        if training:
            forward_outputs = forward_outputs_cache or OrtValueVector()
            state = PartialGraphExecutionState()
            self.states_.append(state)
            if logger is not None:
                _log("run_forward")
            cls._training_agent.run_forward(
                forward_inputs, forward_outputs, state, cls._cache)

            self.save_for_backward(inputs)
            if logger is not None:
                _log("end")
            return forward_outputs
        else:
            # what about bind_input (+ data_ptr)
            if len(forward_inputs) != len(cls._grad_input_names):
                raise RuntimeError(  # pragma: no cover
                    "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % (
                        len(forward_inputs), len(cls._grad_input_names)))
            iobinding = SessionIOBinding(cls._sess_eval._sess)
            if logger is not None:
                _log("bind inputs %r", cls._grad_input_names)
            for name, inp in zip(
                    cls._grad_input_names, forward_inputs):
                iobinding.bind_ortvalue_input(name, inp)

            # bind output
            if logger is not None:
                _log("bind outputs %r", cls._output_names)
            for name, dev in zip(
                    cls._output_names, cls._fw_no_grad_output_device_info):
                iobinding.bind_output(name, dev)

            # if the shape is known in advance
            # iobinding.bind_output(
            #     output_desc.name, torch_tensor.device.type,
            #     _utils.get_device_index(target_device),
            #     _utils.dtype_torch_to_numpy(torch_tensor.dtype),
            #     list(torch_tensor.size()), torch_tensor.data_ptr())

            if logger is not None:
                _log("grad_enabled=False (run_with_iobinding)")
            cls._sess_eval._sess.run_with_iobinding(
                iobinding, cls._run_options)
            if logger is not None:
                _log("get_outputs")
            ortvalues = iobinding.get_outputs()
            if logger is not None:
                _log("to torch.tensor (%d)", len(ortvalues))
                _log("end")
            return ortvalues

    def backward(self, grad_outputs, backward_outputs_cache=None):
        """
        Implements the backward function. The function returns
        an :epkg:`OrtValueVector`.

        :param grad_outputs: gradients of the forward outputs
        :param backward_outputs_cache: an :epkg:`OrtValueVector` to reuse
            for the outputs, a new one is created if None
        """
        cls = self.__class__
        logger = cls._logger

        def _log(msg, *args):
            logger.debug("[%s.backward] (%dI) " + msg,
                         cls.__name__, len(grad_outputs), *args)

        if logger is not None:
            _log("begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("saved_tensors")

        inputs = self.saved_tensors
        if logger is not None:
            _log("DEBUG: saved_tensors %r", type(inputs))
            _log("self.states_.pop()")
        state = self.states_.pop()

        if logger is not None:
            _log("create OrtValueVector")

        backward_inputs = cls.input_to_ort(
            grad_outputs, cls._bw_outputs_device_info, cls._debug)

        if logger is not None:
            _log("len(grad_outputs)=%d type(grad_outputs)=%r",
                 len(grad_outputs), type(grad_outputs))
            _log("len(backward_inputs)=%d type(backward_inputs)=%r",
                 len(backward_inputs), type(backward_inputs))
            for i in range(len(backward_inputs)):  # pylint: disable=C0200
                _log("backward_inputs[%d].shape=%r",
                     i, backward_inputs[i].shape())
            _log("run_backward")
        backward_outputs = backward_outputs_cache or OrtValueVector()
        cls._training_agent.run_backward(
            backward_inputs, backward_outputs, state)
        if logger is not None:  # pragma: no cover
            _log("DEBUG")
            for i, ov in enumerate(backward_outputs):
                _log("BCK-RET: i=%d - shape=%r - ptr=%r",
                     i, ov.shape(), ov.data_ptr())
            _log("got %r gradients", len(backward_outputs))
            _log("end")
        return backward_outputs
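

# End-to-end sketch (illustrative, not executed by this module): one training
# step through the dynamically created class. The file name, array shapes and
# the input ordering are assumptions; `_grad_input_names` usually lists the
# graph inputs first, then the weights to train, and the graph is expected to
# end with a scalar loss.
#
#     import numpy
#     import onnx
#     onx = onnx.load("model_with_loss.onnx")  # hypothetical file
#     fb = OrtGradientForwardBackward(onx)
#     inst = fb.new_instance()
#     x = numpy.random.randn(4, 10).astype(numpy.float32)
#     y = numpy.random.randn(4, 1).astype(numpy.float32)
#     weights = [fb.get_initializer(w) for w in fb.weights_to_train]
#     outputs = inst.forward([x, y] + weights, training=True)
#     # the usual gradient seed for a scalar loss is one
#     seed = numpy.array([1], dtype=numpy.float32)
#     grads = inst.backward([seed])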