Coverage for mlprodict/onnxrt/ops_cpu/op_adam.py: 97%
35 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
11def _apply_adam(r, t, x, g, v, h,
12 norm_coefficient, norm_coefficient_post,
13 alpha, beta, epsilon): # type: ignore
14 # Add gradient of regularization term.
15 g_regularized = norm_coefficient * x + g
16 # Update momentum.
17 v_new = alpha * v + (1 - alpha) * g_regularized
18 # Update second-order momentum.
19 h_new = beta * h + (1 - beta) * (g_regularized * g_regularized)
20 # Compute element-wise square root.
21 h_sqrt = numpy.sqrt(h_new) + epsilon
22 # Adjust learning rate.
23 r_adjusted = None
24 if t > 0:
25 # Consider bias correction on momentums.
26 r_adjusted = r * numpy.sqrt(1 - beta**t) / (1 - alpha**t)
27 else:
28 # No bias correction on momentums.
29 r_adjusted = r
30 # Apply Adam update rule.
31 x_new = x - r_adjusted * (v_new / h_sqrt)
32 # It's possible to apply regularization in the end.
33 x_final = (1 - norm_coefficient_post) * x_new
34 return x_final, v_new, h_new
class Adam(OpRun):
    """
    Runtime operator applying the Adam update implemented by
    function *_apply_adam* to one or several groups of inputs.
    """

    # Expected attributes and their default values.
    atts = {'alpha': 0.8999999761581421,
            'beta': 0.9990000128746033,
            'epsilon': 9.999999974752427e-07,
            'norm_coefficient': 0.,
            'norm_coefficient_post': 0.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Adam.atts, **options)

    def _run(self, *data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the update on the inputs.

        With exactly six inputs ``(r, t, x, g, v, h)``, a single update
        is computed. Otherwise, the inputs following the first two
        are read as four consecutive groups of *n* items
        (all *x*, all *g*, all *v*, all *h*) and one update is
        computed per position, every one of them receiving
        the same first two inputs.

        :return: tuple ``(x_new, v_new, h_new)`` for six inputs,
            otherwise a tuple holding all updated *x*, followed by
            all updated *v*, followed by all updated *h*
        """
        if len(data) == 6:
            return self._run1(*data)
        n = (len(data) - 2) // 4
        shared = data[:2]
        # One (x_new, v_new, h_new) triple per position.
        triples = [
            self._run1(*shared, data[2 + pos], data[2 + n + pos],
                       data[2 + n * 2 + pos], data[2 + n * 3 + pos])
            for pos in range(n)]
        # Flattened by kind: all x first, then all v, then all h.
        return tuple(triple[kind] for kind in range(3)
                     for triple in triples)

    def _run1(self, r, t, x, g, v, h):  # pylint: disable=W0221
        """
        Computes a single update with the attributes stored on the
        instance and returns ``(x_new, v_new, h_new)``.
        """
        return _apply_adam(
            r, t, x, g, v, h,
            self.norm_coefficient, self.norm_coefficient_post,
            self.alpha, self.beta, self.epsilon)