Coverage for mlprodict/tools/onnx_inference_ort_helper.py: 70%
23 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# pylint: disable=C0302
2"""
3@file
4@brief Helpers for :epkg:`onnxruntime`.
5"""
def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: either a :epkg:`C_OrtDevice` (returned unchanged)
        or a string among ``'cpu'``, ``'gpu'``, ``'cuda'``,
        ``'gpu:<index>'``, ``'cuda:<index>'`` where *index* is
        the integer identifying the device
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if *device* is a string which cannot be
        interpreted as a device, including a non integer index
    :raises TypeError: if *device* is neither a string
        nor a :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
        OrtDevice as C_OrtDevice)  # delayed

    if isinstance(device, C_OrtDevice):
        return device
    if not isinstance(device, str):
        raise TypeError(  # pragma: no cover
            f"Unable to interpret type {type(device)!r}, ({device!r}) as a device.")

    if device == 'cpu':
        return C_OrtDevice(
            C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    if device in {'gpu', 'cuda'}:
        # No explicit index means the first CUDA device.
        return C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)

    # 'gpu:<index>' and 'cuda:<index>' are two spellings of the same device.
    for prefix in ('gpu:', 'cuda:'):
        if not device.startswith(prefix):
            continue
        try:
            idx = int(device[len(prefix):])
        except ValueError as e:
            # An empty or non integer index is reported with the same
            # message as any other string which cannot be interpreted.
            raise ValueError(
                f"Unable to interpret string {device!r} as a device.") from e
        return C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)

    raise ValueError(  # pragma: no cover
        f"Unable to interpret string {device!r} as a device.")
def device_to_providers(device):
    """
    Maps a device onto the list of execution providers to use with it.

    :param device: :epkg:`C_OrtDevice`, a string is first converted
        with function `get_ort_device`
    :return: list of provider names, ``['CPUExecutionProvider']`` for
        a CPU device, ``['CUDAExecutionProvider', 'CPUExecutionProvider']``
        for a CUDA device
    :raises ValueError: if the device type is neither CPU nor CUDA
    """
    ort_device = get_ort_device(device) if isinstance(device, str) else device

    # Device kinds are tested in this order, the first match wins.
    known_kinds = (
        ('cpu', ['CPUExecutionProvider']),
        ('cuda', ['CUDAExecutionProvider', 'CPUExecutionProvider']),
    )
    for kind_name, providers in known_kinds:
        if ort_device.device_type() == getattr(ort_device, kind_name)():
            return providers

    raise ValueError(  # pragma: no cover
        f"Unexpected device {ort_device!r}.")