Coverage for onnxcustom/utils/onnxruntime_helper.py: 87%
71 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1"""
2@file
3@brief Onnxruntime helper.
4"""
5from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
6 OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
def provider_to_device(provider_name):
    """
    Maps an onnxruntime execution provider name onto a short device name.

    :param provider_name: execution provider name, only
        ``'CPUExecutionProvider'`` and ``'CUDAExecutionProvider'``
        are recognized
    :return: ``'cpu'`` or ``'cuda'``
    :raises ValueError: for any other provider name

    .. runpython::
        :showcode:

        from onnxcustom.utils.onnxruntime_helper import provider_to_device
        print(provider_to_device('CPUExecutionProvider'))
    """
    # Ordered pairs (provider, device); compared with == so any value,
    # hashable or not, ends up in the ValueError below when unknown.
    for known_provider, device_name in (
            ('CPUExecutionProvider', 'cpu'),
            ('CUDAExecutionProvider', 'cuda')):
        if provider_name == known_provider:
            return device_name
    raise ValueError(
        f"Unexpected value for provider_name={provider_name!r}.")
def get_ort_device_type(device):
    """
    Converts a device description into an onnxruntime device type.

    :param device: either a string or an object exposing a
        ``device_type()`` method.

        Accepted strings are ``'cpu'``, ``'cuda'`` and ``'gpu'``. The GPU
        names may also be followed by ``':<index>'`` (decimal digits), the
        same spellings that ``get_ort_device`` understands.

        For an object, ``device_type()`` must return ``'cpu'``/``0`` or
        ``'cuda'``/``1``.
    :return: ``C_OrtDevice.cpu()`` or ``C_OrtDevice.cuda()``
    :raises ValueError: unsupported string or unsupported ``device_type()``
        value
    :raises TypeError: if *device* is neither a string nor has a
        ``device_type`` attribute
    """
    if isinstance(device, str):
        # 'gpu' is an alias of 'cuda' (as in get_ort_device); an optional
        # ':<index>' suffix only selects the card, not the device type,
        # so it is validated but otherwise ignored here.
        name, sep, index = device.partition(':')
        if name in ('cuda', 'gpu') and (not sep or index.isdecimal()):
            return C_OrtDevice.cuda()
        if device == 'cpu':
            return C_OrtDevice.cpu()
        raise ValueError(  # pragma: no cover
            f'Unsupported device type: {device!r}.')
    if not hasattr(device, 'device_type'):
        raise TypeError(f'Unsupported device type: {type(device)!r}.')
    device_type = device.device_type()
    # Both the string and the integer code are accepted.
    if device_type in ('cuda', 1):
        return C_OrtDevice.cuda()
    if device_type in ('cpu', 0):
        return C_OrtDevice.cpu()
    raise ValueError(  # pragma: no cover
        f'Unsupported device type: {device_type!r}.')
def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string
        among ``'cpu'``, ``'cuda'``, ``'gpu'``, ``'cuda:<index>'``,
        ``'gpu:<index>'``
    :return: :epkg:`C_OrtDevice` using the default memory
    :raises ValueError: if the string cannot be interpreted, including a
        non-integer index
    :raises TypeError: for any other type

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    if isinstance(device, C_OrtDevice):
        return device
    if isinstance(device, str):
        if device == 'cpu':
            return C_OrtDevice(
                C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
        if device in {'gpu', 'cuda'}:
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
        # 'gpu' is an alias of 'cuda'; the suffix is the device index
        # ('cuda:0' and 'gpu:0' land here too and give index 0).
        for prefix in ('gpu:', 'cuda:'):
            if device.startswith(prefix):
                try:
                    idx = int(device[len(prefix):])
                except ValueError as e:
                    raise ValueError(
                        f"Unable to interpret string {device!r} "
                        f"as a device.") from e
                return C_OrtDevice(
                    C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
        raise ValueError(
            f"Unable to interpret string {device!r} as a device.")
    raise TypeError(  # pragma: no cover
        f"Unable to interpret type {type(device)!r}, ({device!r}) as a device.")
def ort_device_to_string(device):
    """
    Builds the textual name of an onnxruntime device.
    Reverse operation of function @see fn get_ort_device.

    :param device: instance of :epkg:`C_OrtDevice`
    :return: ``'cpu'`` or ``'cuda'`` when the device id is 0,
        ``'<type>:<id>'`` otherwise
    :raises TypeError: if *device* is not a :epkg:`C_OrtDevice`
    :raises NotImplementedError: for a device type other than cpu or cuda
    """
    if not isinstance(device, C_OrtDevice):
        raise TypeError(
            f"device must be of type C_OrtDevice not {type(device)!r}.")
    kind = device.device_type()
    if kind == C_OrtDevice.cpu():
        prefix = 'cpu'
    elif kind == C_OrtDevice.cuda():
        prefix = 'cuda'
    else:
        raise NotImplementedError(  # pragma: no cover
            f"Unable to guess device for {device!r} and type={kind!r}.")
    index = device.device_id()
    # Index 0 is implicit, matching the strings get_ort_device accepts.
    return prefix if index == 0 else "%s:%d" % (prefix, index)
def numpy_to_ort_value(arr, device=None):
    """
    Wraps a numpy array into a :epkg:`C_OrtValue`.

    :param arr: numpy array
    :param device: target :epkg:`C_OrtDevice`; ``None`` selects the cpu
        device built by ``get_ort_device('cpu')``
    :return: :epkg:`C_OrtValue`
    """
    target = get_ort_device('cpu') if device is None else device
    return C_OrtValue.ortvalue_from_numpy(arr, target)
def device_to_providers(device):
    """
    Returns the execution providers matching a device.

    :param device: :epkg:`C_OrtDevice`, or a string converted first with
        ``get_ort_device``
    :return: ``['CPUExecutionProvider']`` or ``['CUDAExecutionProvider']``
    :raises ValueError: for any other device type
    """
    ort_device = get_ort_device(device) if isinstance(device, str) else device
    kind = ort_device.device_type()
    if kind == ort_device.cpu():
        return ['CPUExecutionProvider']
    if kind == ort_device.cuda():
        return ['CUDAExecutionProvider']
    raise ValueError(  # pragma: no cover
        f"Unexpected device {ort_device!r}.")
def get_ort_device_from_session(sess):
    """
    Guesses the device an :epkg:`InferenceSession` runs on.

    :param sess: :epkg:`InferenceSession`
    :return: :epkg:`C_OrtDevice`
    :raises NotImplementedError: when the providers are neither exactly
        ``['CPUExecutionProvider']`` nor start with
        ``'CUDAExecutionProvider'``, or when the CUDA provider options
        do not expose a ``device_id``
    """
    providers = sess.get_providers()
    if providers == ["CPUExecutionProvider"]:
        return C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    if providers[0] != "CUDAExecutionProvider":
        raise NotImplementedError(
            f"Not able to guess the model device from {providers}.")
    options = sess.get_provider_options()
    if len(options) == 0:
        # No options at all: assume the first card.
        device_id = 0
    else:
        if ("CUDAExecutionProvider" not in options or
                "device_id" not in options["CUDAExecutionProvider"]):
            raise NotImplementedError(
                f"Unable to guess 'device_id' in {options}.")
        device_id = int(options["CUDAExecutionProvider"]["device_id"])
    return C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), device_id)