Coverage for mlprodict/tools/onnx_inference_ort_helper.py: 70% (23 statements), generated by coverage.py v7.1.0 on 2023-02-04 02:28 +0100.

1# pylint: disable=C0302 

2""" 

3@file 

4@brief Helpers for :epkg:`onnxruntime`. 

5""" 

6 

7 

def _parse_device_string(device):
    """
    Parses a device name into a ``(kind, index)`` pair without
    importing :epkg:`onnxruntime`.

    :param device: string such as ``'cpu'``, ``'gpu'``, ``'cuda'``,
        ``'gpu:<index>'`` or ``'cuda:<index>'``
    :return: tuple ``(kind, index)`` where *kind* is ``'cpu'`` or
        ``'cuda'`` and *index* is a non-negative integer
    :raises ValueError: if the string is not recognized or its index
        is not a non-negative integer
    """
    if device == 'cpu':
        return 'cpu', 0
    if device in ('gpu', 'cuda'):
        return 'cuda', 0
    # 'gpu' and 'cuda' are aliases: both map to a CUDA device.
    for prefix in ('gpu:', 'cuda:'):
        if device.startswith(prefix):
            suffix = device[len(prefix):]
            try:
                idx = int(suffix)
            except ValueError as e:
                # Re-raise with a message naming the whole device string
                # instead of int()'s generic "invalid literal" message.
                raise ValueError(
                    f"Unable to interpret string {device!r} as a device, "
                    f"{suffix!r} is not a valid device index.") from e
            if idx < 0:
                raise ValueError(
                    f"Unable to interpret string {device!r} as a device, "
                    "the device index must be >= 0.")
            return 'cuda', idx
    raise ValueError(  # pragma: no cover
        f"Unable to interpret string {device!r} as a device.")


def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string,
        one of ``'cpu'``, ``'gpu'``, ``'cuda'``, ``'gpu:<index>'``,
        ``'cuda:<index>'`` (``gpu`` is an alias for ``cuda``)
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if *device* is a string that cannot be interpreted
    :raises TypeError: if *device* is neither a string nor
        a :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    # Imported inside the function so that loading this module does not
    # require onnxruntime.
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
        OrtDevice as C_OrtDevice)  # delayed
    if isinstance(device, C_OrtDevice):
        return device
    if isinstance(device, str):
        kind, idx = _parse_device_string(device)
        device_type = (
            C_OrtDevice.cpu() if kind == 'cpu' else C_OrtDevice.cuda())
        return C_OrtDevice(device_type, C_OrtDevice.default_memory(), idx)
    raise TypeError(  # pragma: no cover
        f"Unable to interpret type {type(device)!r}, ({device!r}) as a device.")

47 

48 

def device_to_providers(device):
    """
    Maps a device onto the :epkg:`onnxruntime` execution providers
    to use for it.

    :param device: :epkg:`C_OrtDevice`, or a string which is first
        converted with :func:`get_ort_device`
    :return: list of provider names; a CUDA device gets
        ``'CUDAExecutionProvider'`` followed by ``'CPUExecutionProvider'``,
        a CPU device gets ``'CPUExecutionProvider'`` only
    :raises ValueError: if the device type is neither CPU nor CUDA
    """
    ort_device = get_ort_device(device) if isinstance(device, str) else device
    kind = ort_device.device_type()
    if kind == ort_device.cpu():
        return ['CPUExecutionProvider']
    if kind == ort_device.cuda():
        # CPU is listed second so that it serves as the fallback provider.
        return ['CUDAExecutionProvider', 'CPUExecutionProvider']
    raise ValueError(  # pragma: no cover
        f"Unexpected device {ort_device!r}.")