Coverage for onnxcustom/training/sgd_learning_rate.py: 100%

134 statements  

coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

# pylint: disable=W0105
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
"""
import numpy
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnxruntime import SessionOptions, InferenceSession, RunOptions
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    OrtValue as C_OrtValue)
from ..utils.onnx_function import function_onnx_graph
from ..utils.onnxruntime_helper import device_to_providers
from ._base_onnx_function import BaseLearningOnnx


class BaseLearningRate(BaseLearningOnnx):
    """
    Class handling the learning rate update after every
    gradient iteration. Two methods need to be overwritten:
    `init_learning_rate` and `update_learning_rate`. The first one
    starts the loop, the second one returns the next learning rate.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        self.ro_ = RunOptions()

    def _call_iobinding(self, sess, bind):
        sess.run_with_iobinding(bind, self.ro_)

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @property
    def value(self):
        "Returns the current learning rate."
        raise NotImplementedError(
            "This method must be overwritten.")

    def __repr_extended__(self):
        return (
            f', value={self.value!r}'
            if hasattr(self, 'value_') and self.value_ is not None  # pylint: disable=E1101
            else '')

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def update_weights(self, device, statei, gradienti, batch_size,
                       velocity=None):
        """
        Updates weights based on the algorithm this class
        is setting up.

        :param device: device
        :param statei: current weight
        :param gradienti: gradient
        :param batch_size: batch size
        :param velocity: same shape as the gradient
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def loop(self, n=1000):
        """
        Loops over learning rate values, *n* to be precise.
        :param n: number of requested iterations
        :return: iterator
        """
        self.init_learning_rate()
        for i in range(n):
            yield self.value
            self.update_learning_rate(i + 1)
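
    # Illustrative usage (not part of the original module): how `loop` is
    # typically consumed, with a concrete class such as LearningRateSGD
    # (defined further below).
    #
    #   lr = LearningRateSGD(eta0=0.01, learning_rate='invscaling')
    #   rates = list(lr.loop(n=3))
    #   # first value is eta0, the following ones decay via update_learning_rate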

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class name initialized with
        *kwargs*.
        :param class_name: an instance of @see cl BaseLearningRate
            or a string among the following class names (see below),
            it can also be a float and in that case, class
            @see cl LearningRateSGD is used
        :return: instance of @see cl BaseLearningRate

        Possible values for *class_name*:
        * `'SGD'` or `'LearningRateSGD'`: see @see cl LearningRateSGD
        * `'Nesterov'`, `'SGDNesterov'` or `'LearningRateSGDNesterov'`:
          see @see cl LearningRateSGDNesterov
        """
        if isinstance(class_name, BaseLearningRate):
            return class_name
        if isinstance(class_name, float):
            return LearningRateSGD(class_name)
        cls = {LearningRateSGD: ['SGD'],
               LearningRateSGDNesterov: ['SGDNesterov', 'Nesterov']}
        for cl, aliases in cls.items():
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        raise ValueError(  # pragma: no cover
            "Unexpected class name %r. It should be one of %r." % (
                class_name, list(map(lambda c: c.__name__, cls))))

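# Illustrative sketch (not part of the original module): the three accepted
# forms of BaseLearningRate.select.
#
#   BaseLearningRate.select('SGD', eta0=0.1)            # alias of LearningRateSGD
#   BaseLearningRate.select('Nesterov', momentum=0.8)   # alias of LearningRateSGDNesterov
#   BaseLearningRate.select(0.01)                       # float -> LearningRateSGD(0.01)
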

class LearningRateSGD(BaseLearningRate):
    """
    Implements the learning rate update the same way as
    :class:`sklearn.linear_model.SGDRegressor`.

    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
        or `'adaptive'` schedules.
    :param alpha: constant that multiplies the regularization term,
        the higher the value, the stronger the regularization.
        Also used to compute the learning rate when *learning_rate*
        is set to `'optimal'`.
    :param power_t: exponent for inverse scaling learning rate
    :param learning_rate: learning rate schedule:
        * `'constant'`: `eta = eta0`
        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
          by a heuristic proposed by Leon Bottou, this number is multiplied
          by a constant C to make the first number equal to *eta0*
        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`

    Created attributes:
    * `eta0_`: initial eta0
    * `optimal_init_`: used when `learning_rate=='optimal'`
    * `value_`: value to be returned by property `value`
    """

    def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
                 learning_rate='invscaling'):
        BaseLearningRate.__init__(self)
        if learning_rate not in ('invscaling', 'optimal', 'constant'):
            raise ValueError(
                f"Unexpected value for learning_rate={learning_rate!r}.")
        self.eta0 = eta0
        self.alpha = alpha
        self.power_t = power_t
        self.learning_rate = learning_rate.lower()
        self.value_ = None

    @property
    def value(self):
        "Returns the current learning rate."
        if self.value_ is None:
            raise RuntimeError(  # pragma: no cover
                "Method init_learning_rate was never called.")
        return self.value_

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return False

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        self.eta0_ = self.eta0
        if self.learning_rate == "optimal":
            typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
            eta0 = typw / max(1.0, (1 + typw) * 2)
            self.optimal_init_ = 1.0 / (eta0 * self.alpha)
            eta = 1. / (self.alpha * self.optimal_init_)
            self.optimal_fact_ = self.eta0 / eta
            self.eta0_ = self.eta0
        else:
            self.eta0_ = self.eta0
        self.value_ = self.eta0_
        return self
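
    # Worked example (illustrative, not part of the original module) with the
    # defaults eta0=0.01, alpha=0.0001 and learning_rate='optimal':
    #   typw          = sqrt(1 / sqrt(0.0001))         = 10
    #   eta0_bottou   = 10 / max(1.0, (1 + 10) * 2)    ~= 0.4545
    #   optimal_init_ = 1 / (0.4545 * 0.0001)          = 22000
    #   optimal_fact_ = 0.01 / (1 / (0.0001 * 22000))  ~= 0.022
    # so that the 'optimal' formula optimal_fact_ / (alpha * (optimal_init_ + t))
    # equals the requested eta0 (0.01) at t=0 and then decays with t.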

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        eta = self.value_
        if self.learning_rate == "optimal":
            eta = self.optimal_fact_ / (self.alpha * (self.optimal_init_ + t))
        elif self.learning_rate == "invscaling":
            eta = self.eta0_ / numpy.power(t + 1, self.power_t)
        self.value_ = eta
        return self
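
    # Illustrative decay (not part of the original module) of the default
    # 'invscaling' schedule with eta0=0.01 and power_t=0.25:
    #   t=1 -> 0.01 / 2 ** 0.25   ~= 0.00841
    #   t=3 -> 0.01 / 4 ** 0.25   ~= 0.00707
    #   t=9 -> 0.01 / 10 ** 0.25  ~= 0.00562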

    def build_onnx_function(self, opset, device, n_tensors):
        so = SessionOptions()
        so.log_severity_level = 4

        # Builds the 'axpy' graph and one IO binding per tensor to update.
        self.axpy_onnx_ = function_onnx_graph("axpy")
        self.axpy_sess_ = InferenceSession(
            self.axpy_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.axpy_sess_binds_ = [
            self.axpy_sess_.io_binding()._iobinding
            for i in range(n_tensors)]
        # One-element array holding alpha, typed like the graph's first input.
        self.alpha_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpy_onnx_.graph.input[0].type.tensor_type.elem_type])

    def update_weights(self, n_bind, device, statei,  # pylint: disable=W0237
                       gradienti, batch_size, velocity=None):
        if velocity is not None:
            raise RuntimeError(  # pragma: no cover
                "Velocity must be None for this way of updating weights.")
        if (not hasattr(self, "axpy_onnx_") or
                not hasattr(self, "axpy_sess_binds_")):
            raise RuntimeError(  # pragma: no cover
                "Attribute 'axpy_sess_binds_' or "
                "'axpy_onnx_' is missing. Method "
                "'build_onnx_function' has not been called.")
        bind = self.axpy_sess_binds_[n_bind]
        self._bind_input_ortvalue("X1", bind, gradienti, device, cache=True)
        self._bind_input_ortvalue("X2", bind, statei, device, cache=True)
        self.alpha_[0] = - self.value / batch_size  # pylint: disable=E1130
        ort_alpha = C_OrtValue.ortvalue_from_numpy(self.alpha_, device)
        self._bind_input_ortvalue("alpha", bind, ort_alpha, device, cache=True)
        self._bind_output_ortvalue('Y', bind, statei, cache=True)
        self._call_iobinding(self.axpy_sess_._sess, bind)
        new_weights = bind.get_outputs()[0]
        return new_weights
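
    # Rough numpy equivalent of the update above (illustrative only; it assumes
    # the 'axpy' graph computes Y = alpha * X1 + X2, which is how the bindings
    # X1=gradient, X2=weights, Y=weights are used here):
    #
    #   weights -= (learning_rate / batch_size) * gradient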

class LearningRateSGDNesterov(LearningRateSGD):
    """
    Implements the learning rate update the same way as
    :class:`sklearn.linear_model.SGDRegressor`.

    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
        or `'adaptive'` schedules.
    :param alpha: constant that multiplies the regularization term,
        the higher the value, the stronger the regularization.
        Also used to compute the learning rate when *learning_rate*
        is set to `'optimal'`.
    :param power_t: exponent for inverse scaling learning rate
    :param learning_rate: learning rate schedule:
        * `'constant'`: `eta = eta0`
        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
          by a heuristic proposed by Leon Bottou, this number is multiplied
          by a constant C to make the first number equal to *eta0*
        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`
    :param momentum: float, default=0.9
        Value of momentum used, must be larger than or equal to 0.
    :param nesterov: bool, default=True
        Whether to use Nesterov's momentum or not. Uses Nesterov's if True.
        Not using Nesterov's momentum is equivalent to class
        @see cl LearningRateSGD.

    Created attributes:
    * `eta0_`: initial eta0
    * `optimal_init_`: used when `learning_rate=='optimal'`
    * `value_`: value to be returned by property `value`

    ::

        updates = [
            self.momentum * velocity - self.learning_rate * grad
            for velocity, grad in zip(self.velocities, grads)]
        self.velocities = updates

        if self.nesterov:
            updates_nesterov = [
                self.momentum * velocity - self.learning_rate * grad
                for velocity, grad in zip(self.velocities, grads)]
            return updates, updates_nesterov  # --> new gradient and velocities
        else:
            return updates  # --> new gradient
    """

    def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
                 learning_rate='invscaling', momentum=0.9, nesterov=True):
        LearningRateSGD.__init__(
            self, eta0=eta0, alpha=alpha, power_t=power_t,
            learning_rate=learning_rate)
        self.momentum = momentum
        self.nesterov = nesterov

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return True

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        return LearningRateSGD.init_learning_rate(self)

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        return LearningRateSGD.update_learning_rate(self, t)

    def build_onnx_function(self, opset, device, n_tensors):
        so = SessionOptions()
        so.log_severity_level = 4

        # axpyw: momentum update, axpyw2: its Nesterov variant
        if self.nesterov:
            self.axpyw_onnx_ = function_onnx_graph("axpyw2")
        else:
            self.axpyw_onnx_ = function_onnx_graph("axpyw")
        self.axpyw_sess_ = InferenceSession(
            self.axpyw_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.axpyw_sess_binds_ = [
            self.axpyw_sess_.io_binding()._iobinding
            for n in range(n_tensors)]

        # One-element arrays holding alpha and beta, typed like the graph input.
        self.alpha_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpyw_onnx_.graph.input[0].type.tensor_type.elem_type])
        self.beta_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpyw_onnx_.graph.input[0].type.tensor_type.elem_type])

    def update_weights(self, n_bind, device, statei, gradienti, batch_size,
                       velocity=None):
        if (not hasattr(self, "axpyw_onnx_") or
                not hasattr(self, "axpyw_sess_binds_")):
            raise RuntimeError(  # pragma: no cover
                "Attribute 'axpyw_sess_binds_' or "
                "'axpyw_onnx_' is missing. Method "
                "'build_onnx_function' has not been called.")
        if velocity is None:
            raise RuntimeError(  # pragma: no cover
                "Velocity must not be None for this way of updating weights.")
        bind = self.axpyw_sess_binds_[n_bind]
        self._bind_input_ortvalue("X1", bind, gradienti, device, cache=True)
        self._bind_input_ortvalue("X2", bind, statei, device, cache=True)
        self._bind_input_ortvalue("G", bind, velocity, device, cache=True)
        self.alpha_[0] = - self.value / batch_size  # pylint: disable=E1130
        self.beta_[0] = self.momentum
        ort_alpha = C_OrtValue.ortvalue_from_numpy(self.alpha_, device)
        ort_beta = C_OrtValue.ortvalue_from_numpy(self.beta_, device)
        self._bind_input_ortvalue("alpha", bind, ort_alpha, device, cache=True)
        self._bind_input_ortvalue("beta", bind, ort_beta, device, cache=True)
        self._bind_output_ortvalue('Y', bind, statei, cache=True)
        self._bind_output_ortvalue('Z', bind, velocity, cache=True)
        self._call_iobinding(self.axpyw_sess_._sess, bind)
        return bind.get_outputs()  # new weights, new velocity
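
    # Rough numpy equivalent of the update above (illustrative only), following
    # the pseudocode in the class docstring, with alpha = -learning_rate /
    # batch_size and beta = momentum as bound above; it assumes the 'axpyw' and
    # 'axpyw2' graphs implement that pseudocode:
    #
    #   velocity = beta * velocity + alpha * gradient
    #   step = beta * velocity + alpha * gradient if nesterov else velocity
    #   weights += step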