Coverage for onnxcustom/training/grad_helper.py: 92%
126 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1# pylint: disable=E1101
2"""
3@file
4@brief ONNX and gradient.
5"""
6from io import BytesIO
7from enum import IntFlag
8import onnx
9from onnx.helper import make_model, make_graph, make_node, make_tensor
10from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
11 OrtModuleGraphBuilder,
12 OrtModuleGraphBuilderConfiguration,
13 TrainingGraphTransformerConfiguration)
14from mlprodict.onnx_tools.optim.onnx_optimisation import onnx_remove_node
15from ..utils.orttraining_helper import get_train_initializer
class DerivativeOptions(IntFlag):
    """
    Flags driving how @see fn onnx_derivative builds the onnx graph
    of the gradients.

    * `Zero`: default, no option enabled
    * `KeepYieldOp`: keeps the operator *YieldOp* in the graph,
      see @see fn onnx_derivative
    * `KeepOutputs`: keeps the outputs of the original graph
    * `FillGrad`: does not add any output to specify the gradient
      of the output but assumes it is one
    * `Loss`: the function assumes the loss was added to the graph
    """

    Zero = 0
    KeepYieldOp = 1 << 0
    KeepOutputs = 1 << 1
    FillGrad = 1 << 2
    # Numerically equal to KeepYieldOp | FillGrad (1 | 4 == 5) but
    # treated as a distinct, exclusive mode by onnx_derivative.
    Loss = 5
def onnx_derivative(onx, weights=None, inputs=None,
                    options=DerivativeOptions.Zero,
                    loss=None, label=None, path_name=None):
    """
    Builds the gradient for an onnx graph.

    :param onx: onnx graph
    :param weights: gradient against those weights, None for all real weights
    :param inputs: gradient against inputs, None for all real inputs
    :param options: options of type @see cl DerivativeOptions
    :param loss: loss output in case a loss was added in the graph,
        *options* must be equal to `DerivativeOptions.Loss`
    :param label: if *loss* is specified, then the label must be
        specified as well
    :param path_name: if *options* equal to `DerivativeOptions.Loss`,
        the gradient is saved to that path
    :return: onnx graph

    The function calls :epkg:`OrtModuleGraphBuilderConfiguration`
    from :epkg:`onnxruntime-training`. This graph is meant to be used
    with @see cl OrtGradientForwardBackward and includes
    operator `YieldOp`. The graph looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepYieldOp)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    These operators are the outputs of the
    initial graph and must be replaced by the gradient of these
    outputs to compute the gradient of the weights and the inputs.
    After they are replaced, it looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.Zero)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    The user can still compute the outputs.

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepOutputs)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    The input gradient can be filled with a constant matrix
    filled with one and with the expected shape.

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=(
            DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad))

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    """
    if not isinstance(options, DerivativeOptions):
        raise TypeError(
            f"Options must be from type DerivativeOptions not {type(options)!r}.")

    # Selects the implementation matching the requested mode and the
    # keyword arguments it needs; only the Loss mode uses loss/label/path.
    if options == DerivativeOptions.Loss:
        build, extra = _onnx_derivative_loss, dict(
            loss=loss, label=label, path_name=path_name)
    else:
        build, extra = _onnx_derivative_fw, {}
    return build(onx, weights=weights, inputs=inputs,
                 options=options, **extra)
def _default_inputs(onx):
    "Guesses default inputs (float ones) if not specified."
    # Only floating point tensors are considered as trainable inputs.
    float_types = (
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.DOUBLE)
    names = []
    for inp in onx.graph.input:
        try:
            elem_type = inp.type.tensor_type.elem_type
        except AttributeError:  # pragma: no cover
            # not a vector
            continue
        if elem_type in float_types:
            names.append(inp.name)
    return names
def _onnx_derivative_fw(onx, weights, inputs, options):
    """
    Implements a gradient based on class `OrtModuleGraphBuilder`.

    :param onx: onnx graph (ModelProto)
    :param weights: names of the initializers to train, None to take
        every trainable initializer returned by
        @see fn get_train_initializer
    :param inputs: input names requiring a gradient, None to take
        every floating point input
    :param options: see @see cl DerivativeOptions
    :return: onnx graph holding the gradient
    """
    if weights is None:
        inits = get_train_initializer(onx)
        weights = list(inits)
    builder = OrtModuleGraphBuilder()

    config = OrtModuleGraphBuilderConfiguration()
    config.initializer_names = weights
    config.initializer_names_to_train = weights
    if inputs is None:
        inputs_name = _default_inputs(onx)
        if len(inputs_name) > 0:
            config.input_names_require_grad = inputs_name
    config.build_gradient_graph = True

    p = TrainingGraphTransformerConfiguration()
    config.graph_transformer_config = p

    builder.initialize(onx.SerializeToString(), config)
    builder.build()
    try:
        # recent onnxruntime-training API
        train_onnx_model_serialized = builder.get_gradient_model()
    except AttributeError:
        # fallback for older onnxruntime-training versions
        train_onnx_model_serialized = builder.get_model()

    # optimized_pre_grad_model = builder.get_inference_optimized_model()
    grad_yield = onnx.load(BytesIO(train_onnx_model_serialized))
    if options & DerivativeOptions.KeepYieldOp:
        if options != DerivativeOptions.KeepYieldOp:
            # fixed typo: was "Option YieldOd"
            raise ValueError(
                "Option YieldOp cannot be combined with any other.")
        return grad_yield

    # Splits the nodes: YieldOp nodes mark the gradients to expose,
    # the other nodes are kept in the final graph.
    yields_op = [
        (index, node) for index, node in enumerate(grad_yield.graph.node)
        if node.op_type == 'YieldOp']
    if len(yields_op) == 0:
        raise RuntimeError(  # pragma: no cover
            "No YieldOp was found. The input graph must be wrong.")

    other_nodes = [
        (index, node) for index, node in enumerate(grad_yield.graph.node)
        if node.op_type != 'YieldOp']
    inputs = list(grad_yield.graph.input)
    if options & DerivativeOptions.KeepOutputs:
        outputs = list(grad_yield.graph.output)
    else:
        # drops the outputs of the original graph, only gradients remain
        original = set(i.name for i in onx.graph.output)
        outputs = [o for o in grad_yield.graph.output
                   if o.name not in original]
    map_out = {o.name: o for o in onx.graph.output}
    for index, yn in yields_op:
        if len(yn.input) != 1 or len(yn.output) != 1:
            raise NotImplementedError(  # pragma: no cover
                f"Unexpected configuration for YieldOp node {yn!r}.")
        if yn.input[0] not in map_out:
            raise RuntimeError(  # pragma: no cover
                f"Unable to find output {yn.input[0]!r} in {list(map_out)!r}.")
        if not (options & DerivativeOptions.FillGrad):  # pylint: disable=C0325
            # the gradient of every original output becomes an input
            # of the new graph, with the same type as that output
            out = map_out[yn.input[0]]
            new_input = onnx.ValueInfoProto()
            new_input.name = yn.output[0]
            new_input.doc_string = "from yieldop"
            new_input.type.CopyFrom(out.type)
            inputs.append(new_input)
        else:
            if not (options & DerivativeOptions.KeepOutputs):  # pylint: disable=C0325
                raise ValueError(  # pragma: no cover
                    "FillGrad should be set with KeepOutputs.")
            # replaces the output gradient by a constant tensor of ones
            # with the same shape as the output itself
            name = f"{yn.input[0]}_shape"
            node = make_node('Shape', [yn.input[0]], [name])
            # fractional indices keep the inserted nodes sorted right
            # after the node at position *index*
            other_nodes.append((index + 0.1, node))
            out = map_out[yn.input[0]]
            elem_type = out.type.tensor_type.elem_type
            node = make_node(
                'ConstantOfShape', [name], [yn.output[0]],
                value=make_tensor(
                    "value", elem_type, (1, ), [1]))
            other_nodes.append((index + 0.2, node))
        if options & DerivativeOptions.KeepOutputs:
            # Keeps output from the original graph.
            outputs.append(out)

    # Final graph.
    other_nodes.sort()
    other_nodes = [o[1] for o in other_nodes]
    graph = make_graph(
        other_nodes, grad_yield.graph.name, inputs, outputs,
        list(grad_yield.graph.initializer))
    new_model = make_model(graph)
    new_model.ir_version = grad_yield.ir_version
    new_model.producer_name = grad_yield.producer_name
    new_model.producer_version = grad_yield.producer_version
    new_model.domain = grad_yield.domain
    new_model.model_version = grad_yield.model_version
    new_model.doc_string = grad_yield.doc_string
    # NOTE(review): ModelProto has no top-level 'value_info' field, so
    # this branch looks unreachable; 'grad_yield.graph.value_info' may
    # have been the intent -- confirm before changing.
    if hasattr(onx, 'value_info'):
        graph.value_info.extend(grad_yield.value_info)
    del new_model.opset_import[:]
    for oimp in grad_yield.opset_import:
        op_set = new_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # removes identity-like nodes left over by the builder
    return onnx_remove_node(new_model)
def _onnx_derivative_loss(onx, weights, inputs, options, loss, label,
                          path_name):
    """
    Implements a gradient based on class `PyGradientGraphBuilder`.

    :param onx: onnx graph (ModelProto)
    :param weights: must be None in this mode
    :param inputs: input names requiring a gradient, None to take
        every floating point input
    :param options: see @see cl DerivativeOptions, expected to be `Loss`
    :param loss: name of the loss output (a string)
    :param label: label name or iterable of label names, they are
        excluded from the inputs requiring a gradient
    :param path_name: file path the gradient graph is saved to
    :return: onnx graph with the gradient
    :raises ValueError: if any argument does not match the 'Loss' mode
    """
    # Validate the arguments before importing the optional dependency
    # so that argument errors are reported even when
    # onnxruntime-training is not installed.
    if path_name is None:
        raise ValueError(
            "path_name must not be None if options is 'Loss'.")
    if weights is not None:
        raise ValueError(
            "weights must be None if options is 'Loss'.")
    if label is None:
        raise ValueError(
            "label must not be None if options is 'Loss'.")
    if loss is None or not isinstance(loss, str):
        # fixed grammar: was "loss must not None and a string"
        raise ValueError(
            "loss must not be None and must be a string "
            "if options is 'Loss'.")

    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,C0415
        GradientGraphBuilder)

    if isinstance(label, str):
        label = {label}
    else:
        label = set(label)
    if inputs is None:
        inputs_name = _default_inputs(onx)
        inputs = inputs_name
    if isinstance(inputs, str):
        inputs = {inputs}
    else:
        inputs = set(inputs)
    # labels are targets, not inputs requiring a gradient
    inputs = set(x for x in inputs if x not in label)

    builder = GradientGraphBuilder(
        onx.SerializeToString(), label, inputs, loss)
    builder.build()
    builder.save(path_name)
    with open(path_name, "rb") as f:
        return onnx.load(f)