Coverage for deeponnxcustom/tools/onnx_helper.py: 100%
22 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 02:28 +0200
1"""
2@file
3@brief Helpers about ONNX.
4"""
5import math
def save_as_onnx(model, filename, size=None, target_opset=14,
                 batch_size=1, device='cpu',
                 keep_initializers_as_inputs=False):
    """
    Converts a torch model into ONNX using
    :func:`torch.onnx.export`. The function works
    on models with only one input.

    :param model: torch model
    :param filename: output filename
    :param size: input size (integer or tuple of dimensions, without
        the batch dimension) or left None to guess it from the last
        dimension of the first parameter whose name ends with ``'weight'``
    :param target_opset: opset to use for the conversion
    :param batch_size: batch size
    :param device: device
    :param keep_initializers_as_inputs: see :func:`torch.onnx.export`
    :raises RuntimeError: if *size* is None and no parameter name
        ends with ``'weight'``

    .. exref::
        :title: Export a torch model into ONNX

        ::

            import torch
            from deeponnxcustom.tools.onnx_helper import save_as_onnx

            class MyModel(torch.nn.Module):
                # ...

            nn = MyModel()
            save_as_onnx(nn, "my_model.onnx")
    """
    import torch  # pylint: disable=C0415

    if size is None:
        # Guess the input size from the last dimension of the first
        # parameter whose name ends with 'weight'.
        for name, value in model.named_parameters():
            if name.endswith('weight'):
                size = value.shape[-1]
                break
        if size is None:
            raise RuntimeError(
                "Unable to guess size from the following list of "
                "parameters:\n%s" % ("\n".join(
                    "%r: shape=%r - dtype=%r" % (name, tuple(v.shape), v.dtype)
                    for name, v in model.named_parameters())))
    # Accept an integer (single feature dimension) or a tuple of
    # dimensions; the batch dimension is prepended in both cases.
    if isinstance(size, int):
        size = (size, )
    size = (batch_size, ) + tuple(size)
    x = torch.randn(  # pylint: disable=E1101
        size, requires_grad=True).to(device)
    torch.onnx.export(
        model, x, filename,
        do_constant_folding=False,
        export_params=False,
        opset_version=target_opset,  # was silently ignored before
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        input_names=['input'], output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}})
def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their name
    follows the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    names = [ini.name for ini in onx.graph.initializer]
    # Number of digits needed so that zero-padded indices sort
    # alphabetically. Guard the argument of math.log: with no
    # initializers, log(0) would raise ``ValueError: math domain error``.
    ninit = max(1, int(math.log(max(len(names), 1)) / math.log(10) + 1))
    fmt = "I%0{}d_%s".format(ninit)
    new_names = [fmt % (i, name) for i, name in enumerate(names)]
    repl = dict(zip(names, new_names))
    return onnx_rename_names(onx, recursive=False, replace=repl)