Coverage for deeponnxcustom/tools/onnx_helper.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-06 02:28 +0200

1""" 

2@file 

3@brief Helpers about ONNX. 

4""" 

5import math 

6 

7 

def save_as_onnx(model, filename, size=None, target_opset=14,
                 batch_size=1, device='cpu',
                 keep_initializers_as_inputs=False):
    """
    Converts a torch model into ONNX using
    :func:`torch.onnx.export`. The function works
    on models with only one input.

    :param model: torch model
    :param filename: output filename
    :param size: input size excluding the batch dimension, either an
        integer or a tuple/list of integers for multi-dimensional inputs;
        left None to guess it from the last dimension of the first
        parameter whose name ends with ``'weight'``
    :param target_opset: opset to use for the conversion, forwarded to
        :func:`torch.onnx.export` as ``opset_version``
    :param batch_size: batch size of the random dummy input used to
        trace the model (the exported batch axis is dynamic)
    :param device: device the dummy input is moved to
    :param keep_initializers_as_inputs: see :func:`torch.onnx.export`
    :raises RuntimeError: if *size* is None and no parameter name
        ends with ``'weight'``

    .. exref::
        :title: Export a torch model into ONNX

        ::

            import torch
            from deeponnxcustom.tools.onnx_helper import save_as_onnx

            class MyModel(torch.nn.Module):
                # ...

            nn = MyModel()
            save_as_onnx(nn, "my_model.onnx")
    """
    import torch  # pylint: disable=C0415

    if size is None:
        # Guess the input size from the last dimension of the first
        # parameter named '...weight'.
        for name, value in model.named_parameters():
            if name.endswith('weight'):
                size = value.shape[-1]
                break
        if size is None:
            raise RuntimeError(
                "Unable to guess size from the following list of "
                "parameters:\n%s" % ("\n".join(
                    "%r: shape=%r - dtype=%r" % (name, tuple(v.shape), v.dtype)
                    for name, v in model.named_parameters())))

    # A scalar size is a single feature dimension (original behaviour);
    # a sequence (tuple, list, torch.Size) describes several dimensions.
    if isinstance(size, (tuple, list)):
        dims = tuple(size)
    else:
        dims = (size, )
    shape = (batch_size, ) + dims
    x = torch.randn(  # pylint: disable=E1101
        shape, requires_grad=True).to(device)
    torch.onnx.export(
        model, x, filename,
        opset_version=target_opset,
        do_constant_folding=False,
        export_params=False,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        input_names=['input'], output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}})

64 

65 

def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their names
    follow the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    Each initializer is renamed ``I<index>_<old name>``. *index* is its
    position in ``onx.graph.initializer``, zero-padded to a common width
    so that lexicographic order matches the original order.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
        A model without initializers is accepted: an empty
        replacement mapping is forwarded to ``onnx_rename_names``.
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    init = [i.name for i in onx.graph.initializer]
    # Width of the zero-padded index. math.log(0) raises ValueError, so an
    # empty initializer list needs a guard; the non-empty formula is kept
    # unchanged so existing renamings stay identical.
    ninit = (max(1, int(math.log(len(init)) / math.log(10) + 1))
             if init else 1)
    fmt = "I%0{}d_%s".format(ninit)
    new_names = [fmt % (i, name) for i, name in enumerate(init)]
    repl = dict(zip(init, new_names))
    return onnx_rename_names(onx, recursive=False, replace=repl)