Coverage for mlprodict/onnxrt/ops_cpu/op_gru.py: 97%
67 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class CommonGRU(OpRun):
    """
    Shared implementation of the ONNX GRU operator.

    Only a single direction is supported (``W.shape[0] == 1``). The update
    and reset gates go through :meth:`f` (sigmoid) and the candidate hidden
    state through :meth:`g` (hyperbolic tangent).

    NOTE(review): the attributes *activations*, *activation_alpha*,
    *activation_beta*, *clip* and *direction* are never read here, and
    *sequence_lens* is ignored. A single-direction ``'reverse'`` node is
    presumably evaluated as if it were ``'forward'``; confirm before relying
    on it.
    """

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # Y is always returned; Y_h only when the node declares 2 outputs.
        self.nb_outputs = len(onnx_node.output)
        # update (z), reset (r) and hidden (h) gates
        self.number_of_gates = 3

    def f(self, x):
        "Gate activation: sigmoid."
        return 1 / (1 + numpy.exp(-x))

    def g(self, x):
        "Candidate hidden state activation: hyperbolic tangent."
        return numpy.tanh(x)

    def _step(self, X, R, B, W, H_0):
        """
        Runs the recurrence over all time steps of *X*.

        :param X: input, ``[seq_length, batch_size, input_size]``
            (already in layout 0, see :meth:`_run`)
        :param R: recurrence weights ``[3 * hidden_size, hidden_size]``,
            stacked as z, r, h
        :param B: biases ``[6 * hidden_size]``, stacked as
            Wbz, Wbr, Wbh, Rbz, Rbr, Rbh
        :param W: input weights ``[3 * hidden_size, input_size]``,
            stacked as z, r, h
        :param H_0: initial hidden state ``[batch_size, hidden_size]``
        :return: ``(Y, Y_h)`` with the same dtype as *X*. In layout 0 the
            shapes are ``[seq_length, num_directions, batch_size, hidden_size]``
            and ``[num_directions, batch_size, hidden_size]``. In layout 1 they
            are ``[batch_size, seq_length, num_directions, hidden_size]`` and
            ``[batch_size, num_directions, hidden_size]``.
        """
        seq_length = X.shape[0]
        hidden_size = H_0.shape[-1]
        batch_size = X.shape[1]

        # Outputs share the input type T (previously always float64).
        Y = numpy.empty(
            [seq_length, self.num_directions, batch_size, hidden_size],
            dtype=X.dtype)
        h_list = []

        [w_z, w_r, w_h] = numpy.split(W, 3)  # pylint: disable=W0632
        [r_z, r_r, r_h] = numpy.split(R, 3)  # pylint: disable=W0632
        [w_bz, w_br, w_bh, r_bz, r_br, r_bh] = numpy.split(  # pylint: disable=W0632
            B, 6)  # pylint: disable=W0632
        # z and r are computed together with a single matrix product.
        gates_w = numpy.transpose(numpy.concatenate((w_z, w_r)))
        gates_r = numpy.transpose(numpy.concatenate((r_z, r_r)))
        gates_b = numpy.add(numpy.concatenate((w_bz, w_br)),
                            numpy.concatenate((r_bz, r_br)))
        # Loop invariants used by the candidate hidden state.
        w_h_t = numpy.transpose(w_h)
        r_h_t = numpy.transpose(r_h)

        H_t = H_0
        # Each x has shape [1, batch_size, input_size].
        for x in numpy.split(X, seq_length, axis=0):
            gates = numpy.dot(x, gates_w) + numpy.dot(H_t, gates_r) + gates_b
            z, r = numpy.split(gates, 2, -1)  # pylint: disable=W0632
            z = self.f(z)
            r = self.f(r)
            x_h = numpy.dot(x, w_h_t)
            if self.linear_before_reset:
                # reset gate applied after the recurrent linear transform
                h = self.g(x_h + r * (numpy.dot(H_t, r_h_t) + r_bh) + w_bh)
            else:
                # reset gate applied to the hidden state before the transform
                h = self.g(x_h + numpy.dot(r * H_t, r_h_t) + w_bh + r_bh)
            H_t = (1 - z) * h + z * H_t
            h_list.append(H_t)

        concatenated = numpy.concatenate(h_list)
        if self.num_directions == 1:
            Y[:, 0, :, :] = concatenated

        if self.layout == 0:
            Y_h = Y[-1]
        else:
            # [seq, dir, batch, hidden] -> [batch, seq, dir, hidden]
            Y = numpy.transpose(Y, [2, 0, 1, 3])
            # Last time step for every direction: [batch, dir, hidden].
            Y_h = Y[:, -1, :, :]

        return Y, Y_h

    def _run(self, X, W, R, B=None, attributes=None, sequence_lens=None,  # pylint: disable=W0221
             initial_h=None, verbose=0, fLOG=None):
        """
        Evaluates the operator.

        NOTE(review): *attributes* sits between *B* and *sequence_lens*. If
        the caller passes *sequence_lens* / *initial_h* positionally, they
        land in the wrong parameters. Confirm how ``OpRun.run`` invokes this
        method.

        :param X: input sequence; ``[seq_length, batch_size, input_size]``
            when *layout* is 0, ``[batch_size, seq_length, input_size]``
            when it is 1
        :param W: input weights ``[num_directions, 3 * hidden_size, input_size]``
        :param R: recurrence weights
            ``[num_directions, 3 * hidden_size, hidden_size]``
        :param B: optional biases ``[num_directions, 6 * hidden_size]``,
            zeros when missing
        :param sequence_lens: ignored by this implementation
        :param initial_h: optional initial hidden state, zeros when missing
        :return: ``(Y, )`` or ``(Y, Y_h)`` depending on the number of outputs
        :raises NotImplementedError: when there is more than one direction
        """
        self.num_directions = W.shape[0]
        if self.num_directions != 1:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported value %r for num_directions and operator %r." % (
                    self.num_directions, self.__class__.__name__))

        # Drop the leading num_directions axis of the weights and biases.
        R = numpy.squeeze(R, axis=0)
        W = numpy.squeeze(W, axis=0)
        if B is not None:
            B = numpy.squeeze(B, axis=0)
        # sequence_lens is a [batch_size] vector (no num_directions axis) and
        # is unused: squeezing its axis 0 failed as soon as batch_size > 1.

        if self.layout == 1:
            # X: [batch, seq, input] -> [seq, batch, input] for _step.
            X = numpy.swapaxes(X, 0, 1)
            # initial_h is [batch, num_directions, hidden] in this layout.
            direction_axis = 1
        else:
            direction_axis = 0
        if initial_h is not None:
            initial_h = numpy.squeeze(initial_h, axis=direction_axis)

        hidden_size = R.shape[-1]
        batch_size = X.shape[1]

        if B is None:
            B = numpy.zeros(2 * self.number_of_gates * hidden_size,
                            dtype=X.dtype)
        H_0 = (initial_h if initial_h is not None else
               numpy.zeros((batch_size, hidden_size), dtype=X.dtype))

        Y, Y_h = self._step(X, R, B, W, H_0)

        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)
class GRU(CommonGRU):
    """
    ONNX ``GRU`` operator.

    Declares the attributes the node accepts, with their default values, and
    delegates the computation to :class:`CommonGRU`.
    """

    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        # ONNX default for GRU is f=Sigmoid (update/reset gates) and g=Tanh
        # (candidate hidden state), the functions hard-coded in
        # CommonGRU.f / CommonGRU.g. The previous [Tanh, Tanh] default came
        # from RNN.
        'activations': [b'Sigmoid', b'Tanh'],
        'clip': [],
        'direction': b'forward',
        'hidden_size': None,
        # 0: X is [seq, batch, input]; 1: X is [batch, seq, input]
        'layout': 0,
        'linear_before_reset': 0,
    }

    def __init__(self, onnx_node, desc=None, **options):
        CommonGRU.__init__(self, onnx_node, desc=desc,
                           expected_attributes=GRU.atts,
                           **options)