Coverage for onnxcustom/utils/onnxruntime_helper.py: 87%

71 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1""" 

2@file 

3@brief Onnxruntime helper. 

4""" 

5from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

6 OrtDevice as C_OrtDevice, OrtValue as C_OrtValue) 

7 

8 

def provider_to_device(provider_name):
    """
    Maps an execution provider name onto the matching device name.

    :param provider_name: provider name, ``'CPUExecutionProvider'``
        or ``'CUDAExecutionProvider'``
    :return: ``'cpu'`` or ``'cuda'``
    :raises ValueError: for any other provider name

    .. runpython::
        :showcode:

        from onnxcustom.utils.onnxruntime_helper import provider_to_device
        print(provider_to_device('CPUExecutionProvider'))
    """
    # Checked in order with ``==`` (not a dict lookup) so that unhashable
    # inputs still end up in the ValueError below.
    known = (('CPUExecutionProvider', 'cpu'),
             ('CUDAExecutionProvider', 'cuda'))
    for name, device_name in known:
        if provider_name == name:
            return device_name
    raise ValueError(
        f"Unexpected value for provider_name={provider_name!r}.")

28 

29 

def get_ort_device_type(device):
    """
    Returns the onnxruntime device type matching *device*.

    :param device: a string (``'cpu'`` or ``'cuda'``) or any object with a
        ``device_type()`` method returning ``'cpu'``/``0`` or ``'cuda'``/``1``
    :return: ``C_OrtDevice.cpu()`` or ``C_OrtDevice.cuda()``
    :raises ValueError: the name or device type is not recognized
    :raises TypeError: *device* is neither a string nor has ``device_type``
    """
    if isinstance(device, str):
        # A string can never equal the integer codes 0/1, so strings only
        # match 'cuda' or 'cpu' in the checks below.
        kind = device
    else:
        if not hasattr(device, 'device_type'):
            raise TypeError(f'Unsupported device type: {type(device)!r}.')
        kind = device.device_type()
    if kind in ('cuda', 1):
        return C_OrtDevice.cuda()
    if kind in ('cpu', 0):
        return C_OrtDevice.cpu()
    raise ValueError(  # pragma: no cover
        f'Unsupported device type: {kind!r}.')

53 

54 

def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string:
        ``'cpu'``, ``'gpu'``, ``'cuda'``, ``'gpu:<idx>'`` or ``'cuda:<idx>'``
        where ``<idx>`` is a non-negative integer
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: the string cannot be interpreted as a device
    :raises TypeError: *device* is neither a string nor a :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    if isinstance(device, C_OrtDevice):
        return device
    if not isinstance(device, str):
        raise TypeError(  # pragma: no cover
            f"Unable to interpret type {type(device)!r}, ({device!r}) "
            f"as a device.")
    if device == 'cpu':
        return C_OrtDevice(
            C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    if device in ('gpu', 'cuda'):
        idx = 0
    else:
        # 'gpu:<idx>' and 'cuda:<idx>' share the same parsing logic.
        prefix, sep, suffix = device.partition(':')
        if not sep or prefix not in ('gpu', 'cuda'):
            raise ValueError(
                f"Unable to interpret string {device!r} as a device.")
        try:
            idx = int(suffix)
        except ValueError as e:
            raise ValueError(
                f"Unable to interpret string {device!r} as a device, "
                f"invalid device index {suffix!r}.") from e
        if idx < 0:
            raise ValueError(
                f"Unable to interpret string {device!r} as a device, "
                f"device index must be >= 0.")
    return C_OrtDevice(
        C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)

92 

93 

def ort_device_to_string(device):
    """
    Builds the string form of a device.
    Opposite of function @see fn get_ort_device.

    :param device: see :epkg:`C_OrtDevice`
    :return: ``'cpu'`` or ``'cuda'`` when the device index is 0,
        ``'<type>:<index>'`` otherwise
    :raises TypeError: *device* is not a :epkg:`C_OrtDevice`
    """
    if not isinstance(device, C_OrtDevice):
        raise TypeError(
            f"device must be of type C_OrtDevice not {type(device)!r}.")
    kind = device.device_type()
    if kind == C_OrtDevice.cpu():
        name = 'cpu'
    elif kind == C_OrtDevice.cuda():
        name = 'cuda'
    else:
        raise NotImplementedError(  # pragma: no cover
            f"Unable to guess device for {device!r} and type={kind!r}.")
    index = device.device_id()
    # Index 0 is the default device and is left implicit.
    return name if index == 0 else f"{name}:{index:d}"

117 

118 

def numpy_to_ort_value(arr, device=None):
    """
    Wraps a numpy array into a :epkg:`C_OrtValue`.

    :param arr: numpy array
    :param device: :epkg:`C_OrtDevice`, defaults to the CPU device
        when None
    :return: :epkg:`C_OrtValue`
    """
    target = get_ort_device('cpu') if device is None else device
    return C_OrtValue.ortvalue_from_numpy(arr, target)

130 

131 

def device_to_providers(device):
    """
    Returns the corresponding providers for a specific device.

    :param device: :epkg:`C_OrtDevice` or a string accepted by
        @see fn get_ort_device
    :return: list containing a single provider name
    :raises ValueError: the device type is neither cpu nor cuda
    """
    if isinstance(device, str):
        device = get_ort_device(device)
    # Query the type once and compare against the class-level constants,
    # as the other helpers of this module do.
    device_type = device.device_type()
    if device_type == C_OrtDevice.cpu():
        return ['CPUExecutionProvider']
    if device_type == C_OrtDevice.cuda():
        return ['CUDAExecutionProvider']
    raise ValueError(  # pragma: no cover
        f"Unexpected device {device!r}.")

147 

148 

def get_ort_device_from_session(sess):
    """
    Retrieves the device from an object :epkg:`InferenceSession`.

    Two configurations are recognized: the session only uses
    ``CPUExecutionProvider``, or ``CUDAExecutionProvider`` is the first
    provider. In the second case, the device index is read from the
    ``device_id`` entry of the CUDA provider options (0 if no provider
    options are returned at all).

    :param sess: :epkg:`InferenceSession`
    :return: :epkg:`C_OrtDevice`
    :raises NotImplementedError: the device cannot be inferred from the
        session providers or provider options
    """
    providers = sess.get_providers()
    if providers == ["CPUExecutionProvider"]:
        return C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    # An empty provider list must reach the NotImplementedError below
    # instead of failing on providers[0].
    if providers and providers[0] == "CUDAExecutionProvider":
        options = sess.get_provider_options()
        if not options:
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
        if "CUDAExecutionProvider" not in options:
            raise NotImplementedError(
                f"Unable to guess 'device_id' in {options}.")
        cuda = options["CUDAExecutionProvider"]
        if "device_id" not in cuda:
            raise NotImplementedError(
                f"Unable to guess 'device_id' in {options}.")
        # Provider options may store the index as a string.
        device_id = int(cuda["device_id"])
        return C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), device_id)
    raise NotImplementedError(
        f"Not able to guess the model device from {providers}.")