Coverage for mlprodict/onnxrt/ops_cpu/op_momentum.py: 100%
26 statements
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from ._op import OpRun
10def _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta):
11 # Add gradient of regularization term.
12 g_regularized = norm_coefficient * x + g
13 # Coefficient of gradient should be 1 at the first iteration.
14 beta_adjusted = beta if t > 0 else 1
15 # Update momentum.
16 v_new = alpha * v + beta_adjusted * g_regularized
17 # Apply SG with momentum update rule.
18 x_new = x - r * v_new
19 return x_new, v_new
class Momentum(OpRun):
    """
    Runtime for the ONNX training operator *Momentum*: applies one
    stochastic-gradient step with momentum to one or several tensors.

    Inputs are ``R, T, X_1..X_n, G_1..G_n, V_1..V_n`` and outputs are
    ``X_1_new..X_n_new, V_1_new..V_n_new``. Attribute *mode* selects the
    update rule, ``'standard'`` or ``'nesterov'``.
    """

    atts = {'alpha': 0,
            'beta': 0,
            'mode': b'standard',
            'norm_coefficient': 0.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Momentum.atts,
                       **options)

    def _run(self, *data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Updates every ``(X_i, G_i, V_i)`` triple.

        :param data: ``R, T`` followed by the ``n`` tensors ``X_i``, then the
            ``n`` gradients ``G_i``, then the ``n`` momentums ``V_i``
        :param attributes: not used by this operator
        :param verbose: not used by this operator
        :param fLOG: not used by this operator
        :return: tuple ``(X_1_new, ..., X_n_new, V_1_new, ..., V_n_new)``
        :raises RuntimeError: if the number of inputs is not ``2 + 3 * n``
            with ``n >= 1``
        """
        if len(data) < 5 or (len(data) - 2) % 3 != 0:
            # Any other count would silently drop or misalign tensors.
            raise RuntimeError(
                "Momentum expects 2 + 3 * n inputs (n >= 1) but got %d."
                % len(data))
        n = (len(data) - 2) // 3
        r, t = data[:2]
        xs = data[2:2 + n]
        gs = data[2 + n:2 + 2 * n]
        vs = data[2 + 2 * n:]
        updates = [self._run1(r, t, x, g, v) for x, g, v in zip(xs, gs, vs)]
        new_xs = [x_new for x_new, _ in updates]
        new_vs = [v_new for _, v_new in updates]
        return tuple(new_xs + new_vs)

    def _mode_name(self):
        """Returns attribute *mode* as a str, whether stored as bytes or str."""
        mode = self.mode
        if isinstance(mode, bytes):
            mode = mode.decode('ascii')
        return mode

    def _run1(self, r, t, x, g, v):  # pylint: disable=W0221
        """
        Updates a single tensor *x* given its gradient *g* and momentum *v*.

        :param r: step size
        :param t: iteration counter, the gradient coefficient is 1 when
            ``t <= 0`` (handled by ``_apply_momentum``)
        :param x: tensor to update
        :param g: gradient with respect to *x*
        :param v: previous momentum
        :return: tuple ``(x_new, v_new)``
        :raises RuntimeError: if attribute *mode* is neither ``'standard'``
            nor ``'nesterov'``
        """
        mode = self._mode_name()
        if mode not in ('standard', 'nesterov'):
            raise RuntimeError(
                "Unknown mode %r for operator Momentum, expected "
                "'standard' or 'nesterov'." % mode)
        x_new, v_new = _apply_momentum(
            r, t, x, g, v, self.norm_coefficient, self.alpha, self.beta)
        if mode == 'nesterov':
            # Both rules share V_new; Nesterov moves X along the look-ahead
            # direction G_regularized + alpha * V_new instead of V_new.
            g_regularized = self.norm_coefficient * x + g
            x_new = x - r * (g_regularized + self.alpha * v_new)
        return x_new, v_new