Coverage for onnxcustom/training/grad_helper.py: 92%


# pylint: disable=E1101
"""
@file
@brief ONNX and gradient.
"""
from io import BytesIO
from enum import IntFlag
import onnx
from onnx.helper import make_model, make_graph, make_node, make_tensor
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    OrtModuleGraphBuilder,
    OrtModuleGraphBuilderConfiguration,
    TrainingGraphTransformerConfiguration)
from mlprodict.onnx_tools.optim.onnx_optimisation import onnx_remove_node
from ..utils.orttraining_helper import get_train_initializer

class DerivativeOptions(IntFlag):
    """
    Options defining how to build the onnx graph of the
    gradients.

    * `Zero`: default option, all options are disabled
    * `KeepYieldOp`: keeps the operator *YieldOp* in the graph,
      see @see fn onnx_derivative
    * `KeepOutputs`: keeps the outputs of the original graph
    * `FillGrad`: does not add an input to receive the gradient
      of the output but fills it with a constant tensor of ones
    * `Loss`: the function assumes the loss was added to the graph
    """

    Zero = 0
    KeepYieldOp = 1
    KeepOutputs = 2
    FillGrad = 4
    Loss = 5
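
# ``DerivativeOptions`` is an ``IntFlag``: options combine with the bitwise
# OR operator. A minimal sketch (kept as a comment, not executed on import):
#
#   opts = DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad
#   assert opts & DerivativeOptions.FillGrad
#   assert not opts & DerivativeOptions.KeepYieldOp
#
# Note that ``Loss = 5`` numerically equals ``KeepYieldOp | FillGrad``,
# one more reason ``Loss`` must be used alone: ``onnx_derivative`` dispatches
# on it before any flag test happens.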

def onnx_derivative(onx, weights=None, inputs=None,
                    options=DerivativeOptions.Zero,
                    loss=None, label=None, path_name=None):
    """
    Builds the gradient for an onnx graph.

    :param onx: onnx graph
    :param weights: gradient against those weights, None for all real weights
    :param inputs: gradient against inputs, None for all real inputs
    :param options: options of type @see cl DerivativeOptions
    :param loss: loss output in case a loss was added in the graph,
        *options* must be equal to `DerivativeOptions.Loss`
    :param label: if *loss* is specified, then the label must be
        specified as well
    :param path_name: if *options* is equal to `DerivativeOptions.Loss`,
        the gradient is saved to that path
    :return: onnx graph

    The function calls :epkg:`OrtModuleGraphBuilderConfiguration`
    from :epkg:`onnxruntime-training`. This graph is meant to be used
    with @see cl OrtGradientForwardBackward and includes
    operator `YieldOp`. That is why the graph looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepYieldOp)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    The *YieldOp* operators mark the outputs of the
    initial graph and must be replaced by the gradients of these
    outputs to compute the gradients of the weights and the inputs.
    After they are replaced, the graph looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.Zero)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    The user can still compute the outputs.

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepOutputs)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    The input gradient can be filled with a constant matrix
    of ones with the expected shape.

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=(
            DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad))

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    """
    if not isinstance(options, DerivativeOptions):
        raise TypeError(
            f"Options must be of type DerivativeOptions not {type(options)!r}.")

    if options == DerivativeOptions.Loss:
        return _onnx_derivative_loss(onx, weights=weights, inputs=inputs,
                                     options=options, loss=loss, label=label,
                                     path_name=path_name)
    return _onnx_derivative_fw(onx, weights=weights, inputs=inputs,
                               options=options)
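
# A hedged sketch of the default forward/backward path, reusing the model
# built in the docstring examples above. Exact gradient names depend on the
# installed onnxruntime-training version; 'Y_grad' and 'X_grad' below are
# illustrative only:
#
#   new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepOutputs)
#   print([i.name for i in new_onx.graph.input])   # e.g. ['X', 'Y_grad']
#   print([o.name for o in new_onx.graph.output])  # e.g. ['Y', 'X_grad']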

def _default_inputs(onx):
    "Guesses default inputs (float ones) if not specified."
    inputs_name = []
    for i in onx.graph.input:
        try:
            elem_type = i.type.tensor_type.elem_type
        except AttributeError:  # pragma: no cover
            # not a vector
            continue
        if elem_type in (
                onnx.TensorProto.FLOAT16,
                onnx.TensorProto.FLOAT,
                onnx.TensorProto.DOUBLE):
            inputs_name.append(i.name)
    return inputs_name
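
# A small sketch of the helper above, assuming a model whose only input is
# 'X' with type float32 (as in the docstring examples of onnx_derivative):
#
#   _default_inputs(onx)  # -> ['X']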

def _onnx_derivative_fw(onx, weights, inputs, options):
    """
    Implements a gradient based on class `OrtModuleGraphBuilder`.
    """
    if weights is None:
        # Defaults to all real (float) initializers.
        inits = get_train_initializer(onx)
        weights = list(inits)
    builder = OrtModuleGraphBuilder()

    config = OrtModuleGraphBuilderConfiguration()
    config.initializer_names = weights
    config.initializer_names_to_train = weights
    if inputs is None:
        inputs_name = _default_inputs(onx)
        if len(inputs_name) > 0:
            config.input_names_require_grad = inputs_name
    config.build_gradient_graph = True

    p = TrainingGraphTransformerConfiguration()
    config.graph_transformer_config = p

    builder.initialize(onx.SerializeToString(), config)
    builder.build()
    try:
        # newer onnxruntime-training exposes get_gradient_model
        train_onnx_model_serialized = builder.get_gradient_model()
    except AttributeError:
        # fall back for older versions
        train_onnx_model_serialized = builder.get_model()

    # optimized_pre_grad_model = builder.get_inference_optimized_model()
    grad_yield = onnx.load(BytesIO(train_onnx_model_serialized))
    if options & DerivativeOptions.KeepYieldOp:
        if options != DerivativeOptions.KeepYieldOp:
            raise ValueError(
                "Option KeepYieldOp cannot be combined with any other.")
        return grad_yield

    yields_op = [
        (index, node) for index, node in enumerate(grad_yield.graph.node)
        if node.op_type == 'YieldOp']
    if len(yields_op) == 0:
        raise RuntimeError(  # pragma: no cover
            "No YieldOp was found. The input graph must be wrong.")

    other_nodes = [
        (index, node) for index, node in enumerate(grad_yield.graph.node)
        if node.op_type != 'YieldOp']
    inputs = list(grad_yield.graph.input)
    if options & DerivativeOptions.KeepOutputs:
        outputs = list(grad_yield.graph.output)
    else:
        original = set(i.name for i in onx.graph.output)
        outputs = [o for o in grad_yield.graph.output
                   if o.name not in original]
    map_out = {o.name: o for o in onx.graph.output}
    for index, yn in yields_op:
        if len(yn.input) != 1 or len(yn.output) != 1:
            raise NotImplementedError(  # pragma: no cover
                f"Unexpected configuration for YieldOp node {yn!r}.")
        if yn.input[0] not in map_out:
            raise RuntimeError(  # pragma: no cover
                f"Unable to find output {yn.input[0]!r} in {list(map_out)!r}.")
        if not (options & DerivativeOptions.FillGrad):  # pylint: disable=C0325
            # Exposes the gradient of the output as a new graph input.
            out = map_out[yn.input[0]]
            new_input = onnx.ValueInfoProto()
            new_input.name = yn.output[0]
            new_input.doc_string = "from yieldop"
            new_input.type.CopyFrom(out.type)
            inputs.append(new_input)
        else:
            if not (options & DerivativeOptions.KeepOutputs):  # pylint: disable=C0325
                raise ValueError(  # pragma: no cover
                    "FillGrad should be set with KeepOutputs.")
            # Replaces the gradient of the output by a constant of ones
            # with the same shape as the output.
            name = f"{yn.input[0]}_shape"
            node = make_node('Shape', [yn.input[0]], [name])
            other_nodes.append((index + 0.1, node))
            out = map_out[yn.input[0]]
            elem_type = out.type.tensor_type.elem_type
            node = make_node(
                'ConstantOfShape', [name], [yn.output[0]],
                value=make_tensor(
                    "value", elem_type, (1, ), [1]))
            other_nodes.append((index + 0.2, node))
        if options & DerivativeOptions.KeepOutputs:
            # Keeps output from the original graph.
            outputs.append(out)

    # Final graph. Nodes are sorted back into topological order; the
    # fractional indices keep the inserted Shape/ConstantOfShape nodes
    # next to the YieldOp they replace.
    other_nodes.sort()
    other_nodes = [o[1] for o in other_nodes]
    graph = make_graph(
        other_nodes, grad_yield.graph.name, inputs, outputs,
        list(grad_yield.graph.initializer))
    new_model = make_model(graph)
    new_model.ir_version = grad_yield.ir_version
    new_model.producer_name = grad_yield.producer_name
    new_model.producer_version = grad_yield.producer_version
    new_model.domain = grad_yield.domain
    new_model.model_version = grad_yield.model_version
    new_model.doc_string = grad_yield.doc_string
    if hasattr(onx, 'value_info'):
        graph.value_info.extend(grad_yield.value_info)
    del new_model.opset_import[:]
    for oimp in grad_yield.opset_import:
        op_set = new_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    return onnx_remove_node(new_model)
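
# A hedged sketch showing the intermediate graph this function consumes:
# with ``DerivativeOptions.KeepYieldOp`` the returned model still contains
# the ``YieldOp`` nodes that this function otherwise rewrites:
#
#   grad_yield = onnx_derivative(onx, options=DerivativeOptions.KeepYieldOp)
#   yields = [n for n in grad_yield.graph.node if n.op_type == 'YieldOp']
#   assert len(yields) > 0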

def _onnx_derivative_loss(onx, weights, inputs, options, loss, label,
                          path_name):
    """
    Implements a gradient based on class `GradientGraphBuilder`.
    """
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,C0415
        GradientGraphBuilder)
    if path_name is None:
        raise ValueError(
            "path_name must not be None if options is 'Loss'.")
    if weights is not None:
        raise ValueError(
            "weights must be None if options is 'Loss'.")
    if label is None:
        raise ValueError(
            "label must not be None if options is 'Loss'.")
    if loss is None or not isinstance(loss, str):
        raise ValueError(
            "loss must not be None and must be a string "
            "if options is 'Loss'.")
    if isinstance(label, str):
        label = {label}
    else:
        label = set(label)
    if inputs is None:
        inputs_name = _default_inputs(onx)
        inputs = inputs_name
    if isinstance(inputs, str):
        inputs = {inputs}
    else:
        inputs = set(inputs)
    # The label is not an input the gradient is computed against.
    inputs = set(x for x in inputs if x not in label)

    builder = GradientGraphBuilder(
        onx.SerializeToString(), label, inputs, loss)
    builder.build()
    builder.save(path_name)
    with open(path_name, "rb") as f:
        return onnx.load(f)
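
# A hedged sketch of this Loss path, assuming a loss output named 'loss'
# and a label input named 'label' were already added to the graph (both
# names and 'grad_model.onnx' are hypothetical):
#
#   grad_onx = onnx_derivative(
#       onx_with_loss, options=DerivativeOptions.Loss,
#       loss='loss', label='label', path_name='grad_model.onnx')
#
# The gradient model is also serialized to ``path_name`` as a side effect.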