Coverage for onnxcustom/training/sgd_learning_loss.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1# pylint: disable=W0105 

2""" 

3@file 

4@brief Helper for :epkg:`onnxruntime-training`. 

5""" 

6from onnxruntime import SessionOptions, InferenceSession, RunOptions 

7from ..utils.onnx_function import function_onnx_graph 

8from ..utils.onnxruntime_helper import device_to_providers 

9from ..utils.onnx_rewriter import unreduced_onnx_loss 

10from ._base_onnx_function import BaseLearningOnnx 

11 

12 

class BaseLearningLoss(BaseLearningOnnx):
    """
    Class handling the loss for class
    @see cl OrtGradientForwardBackwardOptimizer.
    All classes inheriting from this one create one ONNX function,
    returning the loss and the gradient of the loss against the
    outputs. Method `loss_gradient` is the main method, it computes
    the loss and the gradient defined by one ONNX graph and
    executed by an instance of :epkg:`InferenceSession`.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        # Run options reused by every call to `_call_iobinding`.
        self.ro_ = RunOptions()

    def build_onnx_score_function(self, opset, device, weight_name):
        """
        Assuming the loss function was created, this method
        takes its onnx graph and generates the onnx graph
        used by method `loss_scores`.

        :param opset: not used by this method, kept so that the
            signature matches `build_onnx_function`
        :param device: device the score inference session runs on
        :param weight_name: not used by this method, kept so that the
            signature matches `build_onnx_function`
        :raises RuntimeError: if attribute `loss_grad_onnx_` is missing,
            i.e. `build_onnx_function` was not called first
        """
        if not hasattr(self, 'loss_grad_onnx_'):
            raise RuntimeError(  # pragma: no cover
                "Missing attribute 'loss_grad_onnx_'. "
                "Method 'build_onnx_function' should be called first.")

        # score
        so = SessionOptions()
        # 4 = fatal messages only, keeps onnxruntime quiet.
        so.log_severity_level = 4
        # Derives the score graph from the loss/gradient graph, output 'Y'
        # being rewritten by `unreduced_onnx_loss` (presumably so that it
        # holds one value per observation, see `loss_scores`).
        self.loss_score_onnx_ = unreduced_onnx_loss(
            self.loss_grad_onnx_, 'Y')  # pylint: disable=E1101
        self.loss_score_sess_ = InferenceSession(
            self.loss_score_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        # Keeps the low-level binding object, it is used directly
        # with `run_with_iobinding` in `_call_iobinding`.
        self.loss_score_sess_bind_ = (
            self.loss_score_sess_.io_binding()._iobinding)

    def _call_iobinding(self, sess, bind):
        """
        Runs session *sess* on binding *bind* with the run options
        stored in `self.ro_`.
        """
        sess.run_with_iobinding(bind, self.ro_)

    def _bind_loss_inputs(self, bind, device, expected, predicted, weight):
        """
        Binds inputs 'X1' (*expected*), 'X2' (*predicted*) and, if
        *weight* is not None, 'weight' to *bind*. Input 'weight' is
        cleared when *weight* is None.
        """
        if weight is not None:
            self._bind_input_ortvalue(
                "weight", bind, weight, device, cache=True)
        else:
            self.clear_binding_inputs("weight", bind, cache=True)
        self._bind_input_ortvalue("X1", bind, expected, device, cache=True)
        self._bind_input_ortvalue("X2", bind, predicted, device, cache=True)

    def loss_gradient(  # pylint: disable=E1101
            self, device, expected, predicted, weight=None):
        """
        Returns the loss and the gradient as OrtValue.

        :param device: device where the training takes place
        :param expected: expected value
        :param predicted: predicted value
        :param weight: optional, training weights
            (same dimension as expected and predicted tensors)
        :return: loss and gradient (outputs 'Y' and 'Y_grad')
        :raises RuntimeError: if `build_onnx_function` was not called
        """
        if (not hasattr(self, "loss_grad_sess_") or
                not hasattr(self, "loss_grad_sess_bind_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'loss_grad_sess_bind_' or 'loss_grad_sess_' is "
                "missing. Method 'build_onnx_function' has not been called.")
        bind = self.loss_grad_sess_bind_
        self._bind_loss_inputs(bind, device, expected, predicted, weight)
        bind.bind_output('Y', device)
        bind.bind_output('Y_grad', device)
        self._call_iobinding(self.loss_grad_sess_._sess, bind)
        loss, grad = bind.get_outputs()
        return loss, grad

    def loss_scores(  # pylint: disable=E1101
            self, device, expected, predicted, weight=None):
        """
        Returns the weighted loss (or score)
        for every observation as OrtValue.

        :param device: device where the training takes place
        :param expected: expected value
        :param predicted: predicted value
        :param weight: optional, training weights
            (same dimension as expected and predicted tensors)
        :return: a score for every observation
        :raises RuntimeError: if `build_onnx_function` was not called
        """
        if (not hasattr(self, "loss_score_sess_") or
                not hasattr(self, "loss_score_sess_bind_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'loss_score_sess_bind_' or 'loss_score_sess_' is "
                "missing. Method 'build_onnx_function' has not been called.")
        bind = self.loss_score_sess_bind_
        self._bind_loss_inputs(bind, device, expected, predicted, weight)
        bind.bind_output('Y', device)
        self._call_iobinding(self.loss_score_sess_._sess, bind)
        return bind.get_outputs()[0]

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class initialized with
        *kwargs*.

        :param class_name: an instance of @see cl BaseLearningLoss
            (returned as is, *kwargs* are then ignored),
            the name of one of the classes listed below,
            or one of their aliases
        :return: instance of @see cl BaseLearningLoss
        :raises ValueError: if *class_name* is not recognized

        Possible values for *class_name*:
        * `'square_error'`, `'square'`: see @see cl SquareLearningLoss
        * `'absolute_error'`, `'absolute'`: see @see cl AbsoluteLearningLoss
        * `'elastic_error'`, `'elastic'`: see @see cl ElasticLearningLoss
        * `'log'`, `'neglog'`, `'logloss'`: see @see cl NegLogLearningLoss
        """
        if isinstance(class_name, BaseLearningLoss):
            return class_name
        cls = {SquareLearningLoss: ['square_error', 'square'],
               AbsoluteLearningLoss: ['absolute_error', 'absolute'],
               ElasticLearningLoss: ['elastic_error', 'elastic'],
               NegLogLearningLoss: ['log', 'neglog', 'logloss']}
        for cl, aliases in cls.items():
            # `cl` is a class: its name is `cl.__name__`
            # (`cl.__class__.__name__` would always be 'type').
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        names = [c.__name__ for c in cls]
        raise ValueError(  # pragma: no cover
            f"Unexpected class name {class_name!r}. "
            f"It should be one of {names!r}.")

141 

142 

class SquareLearningLoss(BaseLearningLoss):
    """
    Square loss :math:`(Y - Z)^2`, *Y* being the output and
    *Z* the expected output.
    The ONNX implementation is described in
    @see fn _onnx_grad_loss_square_error.
    """

    def __init__(self):
        BaseLearningLoss.__init__(self)

    def build_onnx_function(self, opset, device, weight_name):
        """
        Builds the ONNX graph computing the loss and its gradient,
        the inference session running it with its io binding, and
        then the graph and session used by `loss_scores`.

        :param opset: target opset of the generated graph
        :param device: device the inference sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        options = SessionOptions()
        # 4 = fatal messages only.
        options.log_severity_level = 4

        # graph and session returning the loss and its gradient
        graph = function_onnx_graph(
            "grad_loss_square_error", target_opset=opset,
            weight_name=weight_name, multiply=1)
        self.loss_grad_onnx_ = graph
        session = InferenceSession(
            graph.SerializeToString(), options,
            providers=device_to_providers(device))
        self.loss_grad_sess_ = session
        self.loss_grad_sess_bind_ = session.io_binding()._iobinding

        # graph and session returning one score per observation
        self.build_onnx_score_function(opset, device, weight_name)

170 

171 

class AbsoluteLearningLoss(BaseLearningLoss):
    """
    Implements an absolute loss :math:`|Y - Z|`
    where *Y* is the output and *Z* the expected output.
    See @see fn _onnx_grad_loss_absolute_error for the ONNX
    implementation.
    """

    def __init__(self):
        BaseLearningLoss.__init__(self)

    def build_onnx_function(self, opset, device, weight_name):
        """
        Builds the ONNX graph computing the loss and its gradient,
        the inference session running it with its io binding, and
        then the graph and session used by `loss_scores`.

        :param opset: target opset of the generated graph
        :param device: device the inference sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        so = SessionOptions()
        # 4 = fatal messages only.
        so.log_severity_level = 4

        # loss_grad
        self.loss_grad_onnx_ = function_onnx_graph(
            "grad_loss_absolute_error", target_opset=opset,
            weight_name=weight_name)
        self.loss_grad_sess_ = InferenceSession(
            self.loss_grad_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.loss_grad_sess_bind_ = (
            self.loss_grad_sess_.io_binding()._iobinding)

        # score
        self.build_onnx_score_function(opset, device, weight_name)

199 

200 

class ElasticLearningLoss(BaseLearningLoss):
    """
    Implements an elastic loss
    :math:`(Y - Z)^2 \\alpha + |Y - Z| * \\beta`
    where *Y* is the output and *Z* the expected output,
    :math:`\\alpha` is *l2_weight* and :math:`\\beta`
    is *l1_weight*.

    :param l1_weight: weight of L1 norm
    :param l2_weight: weight of L2 norm

    See @see fn _onnx_grad_loss_elastic_error for the ONNX
    implementation.
    """

    def __init__(self, l1_weight=0.5, l2_weight=0.5):
        BaseLearningLoss.__init__(self)
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight

    def build_onnx_function(self, opset, device, weight_name):
        """
        Builds the ONNX graph computing the loss and its gradient
        (using *l1_weight* and *l2_weight*), the inference session
        running it with its io binding, and then the graph and
        session used by `loss_scores`.

        :param opset: target opset of the generated graph
        :param device: device the inference sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        so = SessionOptions()
        # 4 = fatal messages only.
        so.log_severity_level = 4

        # loss_grad
        self.loss_grad_onnx_ = function_onnx_graph(
            "grad_loss_elastic_error", target_opset=opset,
            weight_name=weight_name, l1_weight=self.l1_weight,
            l2_weight=self.l2_weight)
        self.loss_grad_sess_ = InferenceSession(
            self.loss_grad_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.loss_grad_sess_bind_ = (
            self.loss_grad_sess_.io_binding()._iobinding)

        # score
        self.build_onnx_score_function(opset, device, weight_name)

238 

239 

class NegLogLearningLoss(BaseLearningLoss):
    """
    Implements a negative log loss
    :math:`log(yt, yp) = -(1-yt)\\log(1-yp) - yt\\log(yp)`.
    It only works for binary classification, where *yp* is the
    predicted probability and *yt* the expected probability.
    *yt* is expected to be binary, *yp* is a matrix with two
    columns, the sum on every line is 1.
    However, this loss is usually applied after a function softmax
    and the gradient is directly computed from the loss to the
    raw score before they are processed through the softmax function
    (see class `Log
    <https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/
    linear_model/_sgd_fast.pyx#L236>`_).

    :param eps: clipping value for probabilities,
        avoids computing `log(0)`
    :param probability_function: function to convert
        raw scores into probabilities, default value is `sigmoid`
        for a logistic regression
    """

    def __init__(self, eps=1e-5, probability_function='sigmoid'):
        BaseLearningLoss.__init__(self)
        self.eps = eps
        self.probability_function = probability_function

    def build_onnx_function(self, opset, device, weight_name):
        """
        Builds the ONNX graph computing the loss and its gradient,
        the inference session running it with its io binding, and
        then the graph and session used by `loss_scores`.

        :param opset: target opset of the generated graph
        :param device: device the inference sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        so = SessionOptions()
        # 4 = fatal messages only.
        so.log_severity_level = 4

        # loss_grad
        # The graph name depends on the probability function,
        # 'grad_sigmoid_neg_log_loss_error' with the default value.
        fct_name = f"grad_{self.probability_function}_neg_log_loss_error"
        self.loss_grad_onnx_ = function_onnx_graph(
            fct_name, target_opset=opset,
            weight_name=weight_name, eps=self.eps)
        self.loss_grad_sess_ = InferenceSession(
            self.loss_grad_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.loss_grad_sess_bind_ = (
            self.loss_grad_sess_.io_binding()._iobinding)

        # score
        self.build_onnx_score_function(opset, device, weight_name)