Coverage for mlprodict/onnxrt/ops_cpu/op_momentum.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from ._op import OpRun 

8 

9 

10def _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta): 

11 # Add gradient of regularization term. 

12 g_regularized = norm_coefficient * x + g 

13 # Coefficient of gradient should be 1 at the first iteration. 

14 beta_adjusted = beta if t > 0 else 1 

15 # Update momentum. 

16 v_new = alpha * v + beta_adjusted * g_regularized 

17 # Apply SG with momentum update rule. 

18 x_new = x - r * v_new 

19 return x_new, v_new 

20 

21 

class Momentum(OpRun):
    """
    Runtime for operator *Momentum*: applies one gradient step with
    momentum to one or several tensors sharing the same *R* and *T*.

    Inputs are laid out as ``R, T, X_1..X_n, G_1..G_n, V_1..V_n`` and
    outputs as ``X_1_new..X_n_new, V_1_new..V_n_new``.

    Attribute *mode* selects the update rule: ``'standard'`` (default)
    or ``'nesterov'``. It may be stored either as ``bytes`` or ``str``.
    """

    atts = {'alpha': 0,
            'beta': 0,
            'mode': b'standard',
            'norm_coefficient': 0.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Momentum.atts,
                       **options)

    def _run(self, *data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Updates every ``(X_i, G_i, V_i)`` triple using the shared *R*, *T*.

        :param data: ``R, T`` followed by ``n`` tensors, ``n`` gradients
            and ``n`` momentums
        :return: tuple of the ``n`` updated tensors followed by the
            ``n`` updated momentums
        :raises RuntimeError: if the number of inputs is not
            ``2 + 3 * n`` with ``n >= 1``
        """
        # A malformed input list used to be silently mis-sliced (or to
        # produce an empty result); reject it explicitly instead.
        if len(data) < 5 or (len(data) - 2) % 3 != 0:
            raise RuntimeError(
                "Operator Momentum expects 2 + 3 * n inputs with n >= 1, "
                "got %d." % len(data))
        n = (len(data) - 2) // 3
        r, t = data[0], data[1]
        xs = data[2:2 + n]
        gs = data[2 + n:2 + 2 * n]
        vs = data[2 + 2 * n:]
        updates = [self._run1(r, t, x, g, v) for x, g, v in zip(xs, gs, vs)]
        # Outputs are grouped: all new tensors first, then all momentums.
        new_xs, new_vs = zip(*updates)
        return tuple(new_xs) + tuple(new_vs)

    def _run1(self, r, t, x, g, v):  # pylint: disable=W0221
        """
        Updates a single tensor *x* given its gradient *g* and momentum *v*.

        :return: tuple ``(x_new, v_new)``
        :raises RuntimeError: if attribute *mode* is not supported
        """
        mode = self._normalized_mode()
        x_new, v_new = _apply_momentum(
            r, t, x, g, v, self.norm_coefficient, self.alpha, self.beta)
        if mode == 'nesterov':
            # The momentum update is identical; only the step on x differs:
            # it follows the regularized gradient plus alpha * v_new.
            g_regularized = self.norm_coefficient * x + g
            x_new = x - r * (g_regularized + self.alpha * v_new)
        return x_new, v_new

    def _normalized_mode(self):
        """
        Returns attribute *mode* as a ``str``.

        :raises RuntimeError: for values other than ``'standard'`` and
            ``'nesterov'``
        """
        mode = self.mode
        if isinstance(mode, bytes):
            # The default declared in ``atts`` is bytes.
            mode = mode.decode('utf-8', 'replace')
        if mode not in ('standard', 'nesterov'):
            raise RuntimeError(
                "Unknown mode %r for operator Momentum." % (self.mode, ))
        return mode