Coverage for onnxcustom/training/_base_onnx_function.py: 95%
119 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1# pylint: disable=W0105
2"""
3@file
4@brief Helper for :epkg:`onnxruntime-training`.
5"""
6import inspect
7from io import BytesIO
8import numpy
9import onnx
10from onnxruntime import SessionOptions, InferenceSession, RunOptions
11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
12 OrtValue as C_OrtValue)
13from ..utils.onnxruntime_helper import ort_device_to_string
14from .excs import ProviderError
15from ._base import BaseOnnxClass
class BaseLearningOnnx(BaseOnnxClass):
    """
    Class handling ONNX function to manipulate OrtValue.
    Base class for @see cl BaseLearningRate and
    @see cl BaseLearningLoss.

    Pickling relies on a naming convention for attributes ending
    with ``_`` (see `__getstate__` and `__setstate__`):

    * ``X_onnx_``: ONNX model, pickled as its serialized bytes,
    * ``X_sess_``: :epkg:`InferenceSession`, pickled as its list of
      providers and rebuilt from ``X_onnx_`` when unpickled,
    * ``Y_bind_``: one IO binding, rebuilt from the session stored
      in attribute ``Y_``,
    * ``Y_binds_``: a list of IO bindings, pickled as its length and
      rebuilt from the session stored in attribute ``Y_``,
    * ``ro_``: :epkg:`RunOptions`, recreated with default options,
    * any other attribute ending with ``_`` is not pickled
      (``cache_in_`` and ``cache_out_`` are reset to empty dictionaries).
    """

    def __init__(self):
        # Binding caches, ``{ id(bind): { name: data pointer } }``,
        # one for inputs, one for outputs.
        self.cache_in_ = {}
        self.cache_out_ = {}

    def __getstate__(self):
        """
        Overwrites getstate to get rid of InferenceSession.

        Sessions, IO bindings and run options cannot be pickled,
        they are replaced by the information `__setstate__` needs
        to rebuild them.
        """
        state = {k: getattr(self, k) for k in self.__dict__
                 if not k.endswith('_')}
        if hasattr(self, 'ro_'):
            # Only the presence of RunOptions is stored.
            state['ro_'] = True
        keys = list(self.__dict__)
        for k in keys:
            if k.endswith('_onnx_'):
                state[k] = getattr(self, k).SerializeToString()
        for k in keys:
            if k.endswith('_bind_'):
                # A flag only, the binding is rebuilt from its session.
                state[k] = True
        for k in keys:
            if k.endswith('_binds_'):
                # Number of bindings to rebuild.
                state[k] = len(getattr(self, k))
        for k in keys:
            if k.endswith('_sess_'):
                # Providers needed to recreate the session.
                state[k] = getattr(self, k).get_providers()
        return state

    def __setstate__(self, state):
        """
        Overwrites setstate to rebuild the InferenceSession, the IO bindings
        and the RunOptions removed by `__getstate__`.
        Binding caches are reset.
        """
        for k, v in state.items():
            if k == 'ro_':
                self.ro_ = RunOptions()
            elif not k.endswith('_onnx_') and not k.endswith('_sess_'):
                setattr(self, k, v)

        so = SessionOptions()
        # 4 = fatal: silences warnings while sessions are recreated.
        so.log_severity_level = 4
        for k, v in state.items():
            if k.endswith('_onnx_'):
                setattr(self, k, onnx.load(BytesIO(v)))
                # Replaces the suffix only, str.replace would also modify
                # any other occurrence of 'onnx' in the attribute name.
                k2 = k[:-len('_onnx_')] + '_sess_'
                prov = state[k2]
                # v already holds the serialized model.
                setattr(self, k2, InferenceSession(v, so, providers=prov))

        # Bindings are created once all sessions exist.
        for k, v in state.items():
            if k.endswith('_bind_'):
                sess = getattr(self, k[:-len('bind_')])
                setattr(self, k, sess.io_binding()._iobinding)
            elif k.endswith('_binds_'):
                sess = getattr(self, k[:-len('binds_')])
                setattr(self, k, [
                    sess.io_binding()._iobinding for _ in range(v)])
        self.cache_in_ = {}
        self.cache_out_ = {}
        return self

    def __repr_extended__(self):
        """
        Returns a string appended to the representation,
        empty by default.
        """
        return ''

    def __repr__(self):
        """
        Usual representation: class name, constructor parameters stored
        as attributes, followed by `__repr_extended__`.
        """
        ps = []
        # presumably pairs (parameter name, default value or
        # inspect.Parameter.empty) - see _get_param_names
        for k, v in self._get_param_names():
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            # NOTE(review): this condition is true for practically every
            # parameter (a parameter with a default is always shown, one
            # without default is shown unless its value is the empty marker
            # itself), so all parameters are printed. Printing only
            # non-default values would need
            # ``v is inspect.Parameter.empty or ov != v``;
            # left unchanged to keep the representation stable.
            if v is not inspect.Parameter.empty or ov != v:
                ps.append(f"{k}={ov!r}")
        return (f"{self.__class__.__name__}({', '.join(ps)})"
                f"{self.__repr_extended__()}")

    def build_onnx_function(self, opset, device, *args):
        """
        This class computes a function represented as an ONNX graph.
        This method builds it.
        This function creates :epkg:`InferenceSession`
        which do that.

        :param opset: opset to use
        :param device: :epkg:`C_OrtDevice`
        :param args: additional arguments
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @staticmethod
    def _cache_in_clear(cache, name, bind):
        """
        Tells if clearing the inputs of *bind* can be skipped.

        :param cache: dictionary ``{ id(bind): { name: data_ptr } }``
        :param name: input name the decision is based on
        :param bind: binding object
        :return: True if the caller may skip clearing (nothing cached
            for *bind* and *name*, or the pointer cached for *name*
            was already reset to 0), False if the caller must clear,
            every cached pointer of *bind* is then reset to 0
        """
        entries = cache.get(id(bind), None)
        if entries is None or name not in entries:
            return True
        if entries[name] == 0:
            return True
        # clear_binding_inputs removes every input of the binding, not only
        # *name*: every cached pointer becomes stale. Keeping one would let
        # a later cached bind with an unchanged pointer be skipped and leave
        # that input unbound.
        for input_name in entries:
            entries[input_name] = 0
        return False

    def clear_binding_inputs(self, name, bind, cache=False):
        """
        Clears binding and empty cache.

        :param name: input name, only used when *cache* is True
        :param bind: binding object
        :param cache: if True, clearing is skipped when the input cache
            says nothing needs to be cleared, and the cache is invalidated
            for *bind* when clearing happens
        """
        if cache and self._cache_in_clear(self.cache_in_, name, bind):
            return
        bind.clear_binding_inputs()

    @staticmethod
    def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
        """
        Returns True if *ptr2* is the pointer already cached for
        (*bind*, *name*), meaning binding again can be skipped,
        otherwise stores *ptr2* in the cache and returns False.
        *c_ortvalue* is not used.
        """
        key = id(bind)
        if key in cache:
            if name in cache[key]:
                ptr = cache[key][name]
                if ptr == ptr2:
                    return True
            cache[key][name] = ptr2
        else:
            cache[key] = {name: ptr2}
        return False

    @staticmethod
    def _bio_do_bind_in(name, bind, c_ortvalue):
        "Binds *c_ortvalue* as input *name*."
        bind.bind_ortvalue_input(name, c_ortvalue)

    @staticmethod
    def _bio_ptr(c):
        "Returns the data pointer of an OrtValue."
        return c.data_ptr()

    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
                             cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can be also a numpy array, it must then be C-contiguous,
            located on CPU and stay alive as long as the binding is used
            (only its data pointer is bound)
        :param device: device
        :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises ProviderError: numpy array with a device which is not CPU
        :raises ValueError: numpy array which is not C-contiguous
        :raises TypeError: unsupported type for *c_ortvalue*
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_in(name, bind, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # The check is done on the device given as argument,
            # the one the error message refers to.
            if device.device_type() != device.cpu():
                raise ProviderError(  # pragma: no cover
                    f"device={ort_device_to_string(device)} is not CPU.")
            # Only the data pointer and the shape are bound, strides are
            # ignored: a non contiguous array would be read incorrectly.
            if not c_ortvalue.flags['C_CONTIGUOUS']:
                raise ValueError(
                    f"Unable to bind a non C-contiguous array for name "
                    f"{name!r}, use numpy.ascontiguousarray first.")
            ptr = c_ortvalue.__array_interface__['data'][0]
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue, ptr):
                return
            bind.bind_input(
                name, device, c_ortvalue.dtype, c_ortvalue.shape, ptr)
        else:
            raise TypeError(  # pragma: no cover
                f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.")

    @staticmethod
    def _bio_do_bind_out(name, bind, c_ortvalue):
        "Binds *c_ortvalue* as output *name*."
        bind.bind_ortvalue_output(name, c_ortvalue)

    def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
        :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises TypeError: unsupported type for *c_ortvalue*

        This method can be used for inplace computation.
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_out_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_out(name, bind, c_ortvalue)
        else:
            raise TypeError(  # pragma: no cover
                f"Unable to bind type {type(c_ortvalue)!r} for name {name!r}.")