Coverage for mlprodict/onnxrt/ops_cpu/op_adam.py: 97%

35 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ._op import OpRun


def _apply_adam(r, t, x, g, v, h,
                norm_coefficient, norm_coefficient_post,
                alpha, beta, epsilon):  # type: ignore
    # Add gradient of regularization term.
    g_regularized = norm_coefficient * x + g
    # Update momentum.
    v_new = alpha * v + (1 - alpha) * g_regularized
    # Update second-order momentum.
    h_new = beta * h + (1 - beta) * (g_regularized * g_regularized)
    # Compute element-wise square root.
    h_sqrt = numpy.sqrt(h_new) + epsilon
    # Adjust learning rate.
    r_adjusted = None
    if t > 0:
        # Apply bias correction to the momentums.
        r_adjusted = r * numpy.sqrt(1 - beta ** t) / (1 - alpha ** t)
    else:
        # No bias correction on the momentums.
        r_adjusted = r
    # Apply Adam update rule.
    x_new = x - r_adjusted * (v_new / h_sqrt)
    # Regularization can also be applied after the update.
    x_final = (1 - norm_coefficient_post) * x_new
    return x_final, v_new, h_new
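For illustration, here is a minimal sketch of a single update step computed with _apply_adam. It assumes mlprodict is installed so the private helper can be imported from the module shown above; the numeric values are arbitrary.

import numpy
from mlprodict.onnxrt.ops_cpu.op_adam import _apply_adam

# Arbitrary illustration values.
r = numpy.array(0.1, dtype=numpy.float32)           # learning rate R
t = 1                                               # update count T (> 0 enables bias correction)
x = numpy.array([1.0, 2.0], dtype=numpy.float32)    # parameters X
g = numpy.array([0.5, -0.5], dtype=numpy.float32)   # gradients G
v = numpy.zeros(2, dtype=numpy.float32)             # first-order momentum V
h = numpy.zeros(2, dtype=numpy.float32)             # second-order momentum H

x_new, v_new, h_new = _apply_adam(
    r, t, x, g, v, h,
    norm_coefficient=0., norm_coefficient_post=0.,
    alpha=0.9, beta=0.999, epsilon=1e-6)
print(x_new, v_new, h_new)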

class Adam(OpRun):

    # Default attribute values (float32 renderings of 0.9, 0.999 and 1e-6).
    atts = {'alpha': 0.8999999761581421,
            'beta': 0.9990000128746033,
            'epsilon': 9.999999974752427e-07,
            'norm_coefficient': 0.,
            'norm_coefficient_post': 0.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Adam.atts,
                       **options)

    def _run(self, *data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        if len(data) == 6:
            # Single tensor to update: R, T, X, G, V, H.
            return self._run1(*data)
        # Several tensors to update: R, T, then the X, G, V, H tensors
        # grouped by kind (see the layout sketch after the listing).
        n = (len(data) - 2) // 4
        xs = []
        vs = []
        hs = []
        for i in range(0, n):
            a, b, c = self._run1(*data[:2], data[2 + i],
                                 data[2 + n + i], data[2 + n * 2 + i],
                                 data[2 + n * 3 + i])
            xs.append(a)
            vs.append(b)
            hs.append(c)
        return tuple(xs + vs + hs)

    def _run1(self, r, t, x, g, v, h):  # pylint: disable=W0221
        x_new, v_new, h_new = _apply_adam(
            r, t, x, g, v, h, self.norm_coefficient,
            self.norm_coefficient_post, self.alpha, self.beta, self.epsilon)
        return x_new, v_new, h_new
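When a single node updates several tensors, _run expects the inputs flattened as R, T, X_1..X_n, G_1..G_n, V_1..V_n, H_1..H_n and applies _run1 to each group. A small sketch of that layout for n = 2; the tensor names are hypothetical and only illustrate the indexing used in the loop above.

import numpy

r = numpy.array(0.1, dtype=numpy.float32)
t = numpy.array(1, dtype=numpy.int64)
x1, g1 = numpy.ones(3, dtype=numpy.float32), numpy.full(3, 0.5, dtype=numpy.float32)
x2, g2 = numpy.ones(4, dtype=numpy.float32), numpy.full(4, -0.5, dtype=numpy.float32)
v1, h1 = numpy.zeros(3, dtype=numpy.float32), numpy.zeros(3, dtype=numpy.float32)
v2, h2 = numpy.zeros(4, dtype=numpy.float32), numpy.zeros(4, dtype=numpy.float32)

data = (r, t, x1, x2, g1, g2, v1, v2, h1, h2)
n = (len(data) - 2) // 4                      # n == 2
# Group i gathers data[2 + i], data[2 + n + i], data[2 + 2 * n + i], data[2 + 3 * n + i]:
# i == 0 -> (x1, g1, v1, h1), i == 1 -> (x2, g2, v2, h2).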