Coverage for onnxcustom/training/_base_onnx_function.py: 95%

119 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1# pylint: disable=W0105 

2""" 

3@file 

4@brief Helper for :epkg:`onnxruntime-training`. 

5""" 

6import inspect 

7from io import BytesIO 

8import numpy 

9import onnx 

10from onnxruntime import SessionOptions, InferenceSession, RunOptions 

11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

12 OrtValue as C_OrtValue) 

13from ..utils.onnxruntime_helper import ort_device_to_string 

14from .excs import ProviderError 

15from ._base import BaseOnnxClass 

16 

17 

class BaseLearningOnnx(BaseOnnxClass):
    """
    Class handling ONNX function to manipulate OrtValue.
    Base class for @see cl BaseLearningRate and
    @see cl BaseLearningLoss.

    Attributes whose name ends with an underscore are runtime state
    and receive a special treatment when an instance is pickled:

    * ``*_onnx_``: ONNX model, stored as serialized bytes,
    * ``*_sess_``: :epkg:`InferenceSession`, only its providers are kept,
      the session is rebuilt from the matching ``*_onnx_`` attribute,
    * ``*_bind_``: one IO binding, rebuilt from the attribute whose name
      is the same without the trailing ``bind_``,
    * ``*_binds_``: a list of IO bindings, only its length is kept,
    * ``ro_``: :epkg:`RunOptions`, recreated with default values,
    * any other attribute ending with ``_`` (such as the caches) is
      not pickled.
    """

    def __init__(self):
        # Caches of bound data pointers, see _bio_cache.
        # Structure: {id(bind): {input or output name: data pointer}}.
        self.cache_in_ = {}
        self.cache_out_ = {}

    def __getstate__(self):
        """
        Overwrites getstate to get rid of InferenceSession.
        Objects which cannot be pickled (sessions, bindings, run options)
        are replaced by what :meth:`__setstate__` needs to rebuild them.
        """
        state = {k: getattr(self, k)
                 for k in self.__dict__ if not k.endswith('_')}
        if hasattr(self, 'ro_'):
            state['ro_'] = True
        for k in self.__dict__:
            if k.endswith('_onnx_'):
                state[k] = getattr(self, k).SerializeToString()
        for k in self.__dict__:
            if k.endswith('_bind_'):
                state[k] = True
        for k in self.__dict__:
            if k.endswith('_binds_'):
                # Only the number of bindings is kept.
                state[k] = len(getattr(self, k))
        for k in self.__dict__:
            if k.endswith('_sess_'):
                # Only the providers are kept, the session is rebuilt.
                state[k] = getattr(self, k).get_providers()
        return state

    def __setstate__(self, state):
        """
        Overwrites setstate to restore the InferenceSession, the
        IO bindings and the run options removed by :meth:`__getstate__`.
        """
        for k, v in state.items():
            if k == 'ro_':
                self.ro_ = RunOptions()
            elif not k.endswith('_onnx_') and not k.endswith('_sess_'):
                setattr(self, k, v)

        so = SessionOptions()
        so.log_severity_level = 4
        for k, v in state.items():
            if k.endswith('_onnx_'):
                setattr(self, k, onnx.load(BytesIO(v)))
                # Only the suffix is swapped, str.replace would also
                # rewrite any 'onnx' appearing earlier in the name.
                k2 = k[:-len('_onnx_')] + '_sess_'
                prov = state[k2]
                # v already holds the serialized model.
                setattr(self, k2, InferenceSession(v, so, providers=prov))

        # Bindings are created once all sessions exist.
        for k, v in state.items():
            if k.endswith('_bind_'):
                sess = getattr(self, k[:-len('bind_')])
                setattr(self, k, sess.io_binding()._iobinding)
            elif k.endswith('_binds_'):
                sess = getattr(self, k[:-len('binds_')])
                setattr(self, k, [
                    sess.io_binding()._iobinding for _ in range(v)])

        # Caches are keyed by id(bind); the bindings are new objects.
        self.cache_in_ = {}
        self.cache_out_ = {}
        return self

    def __repr_extended__(self):
        """
        Returns a string appended to the representation built by
        :meth:`__repr__`, empty here (presumably overridden by
        subclasses which need to display more information).
        """
        return ''

    def __repr__(self):
        """
        Returns a representation ``ClassName(param=value, ...)`` built
        from the constructor parameters stored as attributes, followed
        by :meth:`__repr_extended__`.
        """
        param = self._get_param_names()
        ps = []
        for k, v in param:
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            # NOTE(review): this condition holds for every parameter with
            # a default value and, for the others, whenever the value
            # differs from inspect._empty: nearly every parameter is shown.
            if v is not inspect._empty or ov != v:
                ro = repr(ov)
                ps.append(f"{k}={ro}")
        return f"{self.__class__.__name__}({', '.join(ps)}){self.__repr_extended__()}"

    def build_onnx_function(self, opset, device, *args):
        """
        This class computes a function represented as an ONNX graph.
        This method builds the graph and creates the
        :epkg:`InferenceSession` which computes it.
        It must be implemented by subclasses.

        :param opset: opset to use
        :param device: :epkg:`C_OrtDevice`
        :param args: additional arguments
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @staticmethod
    def _cache_in_clear(cache, name, bind):
        """
        Tells whether clearing the input bindings of *bind* can be
        skipped and updates the cache accordingly.

        :param cache: cache of bound pointers (see :meth:`_bio_cache`)
        :param name: input name the caller wants cleared
        :param bind: binding structure
        :return: True if clearing can be skipped (nothing cached for
            *name* or already cleared), False if the caller must clear
        """
        entry = cache.get(id(bind), None)
        if entry is None or entry.get(name, 0) == 0:
            return True
        # Clearing a binding removes every bound input, not only *name*,
        # so every cached pointer for this binding is now stale.
        for key in entry:
            entry[key] = 0
        return False

    def clear_binding_inputs(self, name, bind, cache=False):
        """
        Clears binding and empty cache.

        :param name: input name
        :param bind: binding structure
        :param cache: if True, skips clearing when the cache shows
            there is nothing to clear, and invalidates the cached
            pointers of *bind* otherwise
        """
        if cache and self._cache_in_clear(self.cache_in_, name, bind):
            return
        bind.clear_binding_inputs()

    @staticmethod
    def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
        """
        Records that *ptr2* is bound to *name* in *bind*.

        :param cache: dictionary ``{id(bind): {name: pointer}}``
        :param name: input or output name
        :param bind: binding structure
        :param c_ortvalue: bound value (unused)
        :param ptr2: data pointer of the value to bind
        :return: True if the same pointer is already bound to *name*
            (binding again can be skipped), False otherwise
        """
        key = id(bind)
        entry = cache.get(key, None)
        if entry is None:
            cache[key] = {name: ptr2}
            return False
        if name in entry and entry[name] == ptr2:
            return True
        entry[name] = ptr2
        return False

    @staticmethod
    def _bio_do_bind_in(name, bind, c_ortvalue):
        "Binds :epkg:`C_OrtValue` *c_ortvalue* as input *name*."
        bind.bind_ortvalue_input(name, c_ortvalue)

    @staticmethod
    def _bio_ptr(c):
        "Returns the data pointer of a :epkg:`C_OrtValue`."
        return c.data_ptr()

    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
                             cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can be also a numpy array
        :param device: device
        :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises ProviderError: numpy array given with a device which is
            not CPU
        :raises TypeError: unsupported type for *c_ortvalue*
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_in(name, bind, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # The device type is read from the device itself,
            # this class does not define any device_type method.
            if device.device_type() != device.cpu():
                raise ProviderError(  # pragma: no cover
                    f"device={ort_device_to_string(device)} is not CPU.")
            ptr = c_ortvalue.__array_interface__['data'][0]
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue, ptr):
                return
            bind.bind_input(
                name, device, c_ortvalue.dtype, c_ortvalue.shape, ptr)
        else:
            raise TypeError(  # pragma: no cover
                f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.")

    @staticmethod
    def _bio_do_bind_out(name, bind, c_ortvalue):
        "Binds :epkg:`C_OrtValue` *c_ortvalue* as output *name*."
        bind.bind_ortvalue_output(name, c_ortvalue)

    def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
        :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises TypeError: *c_ortvalue* is not a :epkg:`C_OrtValue`

        This method can be used for inplace computation.
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_out_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_out(name, bind, c_ortvalue)
        else:
            raise TypeError(  # pragma: no cover
                f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.")