Coverage for mlprodict/onnxrt/ops_cpu/op_lstm.py: 96%
78 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class CommonLSTM(OpRun):
    """
    Shared numpy implementation of a single-direction LSTM runtime
    operator. Only ``num_directions == 1`` is supported; any other
    value raises ``NotImplementedError`` in ``_run``.
    """

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        """
        @param      onnx_node               node being executed, its
                                            ``output`` list gives the number
                                            of outputs to return
        @param      expected_attributes     forwarded to ``OpRun.__init__``
        @param      desc                    forwarded to ``OpRun.__init__``
        @param      options                 forwarded to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        self.nb_outputs = len(onnx_node.output)
        # An LSTM has four gates (input, output, forget, cell); this is
        # the value used to size the default bias in _run.
        self.number_of_gates = 4

    def f(self, x):
        """Gate activation: logistic sigmoid."""
        return 1 / (1 + numpy.exp(-x))

    def g(self, x):
        """Cell candidate activation: hyperbolic tangent."""
        return numpy.tanh(x)

    def h(self, x):
        """Hidden output activation: hyperbolic tangent."""
        return numpy.tanh(x)

    def _step(self, X, R, B, W, H_0, C_0, P):
        """
        Runs the recurrence over the first axis of *X*.

        @param      X       input, iterated along axis 0, ``X.shape[1]`` is
                            used as the batch size
        @param      R       recurrence weights (transposed before use)
        @param      B       bias, split in two halves which are summed
        @param      W       input weights (transposed before use)
        @param      H_0     initial hidden state, its last dimension is
                            the hidden size
        @param      C_0     initial cell state
        @param      P       peephole weights, split in three equal parts
                            (input, output, forget)
        @return             tuple ``(Y, Y_h)``, all hidden states and the
                            last hidden state
        """
        seq_length = X.shape[0]
        hidden_size = H_0.shape[-1]
        batch_size = X.shape[1]

        [p_i, p_o, p_f] = numpy.split(P, 3)  # pylint: disable=W0632

        # Loop invariants: neither the weights nor the bias change
        # from one time step to the next.
        W_t = numpy.transpose(W)
        R_t = numpy.transpose(R)
        bias = numpy.add(*numpy.split(B, 2))

        h_list = []
        H_t = H_0
        C_t = C_0
        for x in numpy.split(X, X.shape[0], axis=0):
            gates = numpy.dot(x, W_t) + numpy.dot(H_t, R_t) + bias
            # Gates are stored in the order input, output, forget, cell.
            i, o, f, c = numpy.split(gates, 4, -1)  # pylint: disable=W0632
            i = self.f(i + p_i * C_t)
            f = self.f(f + p_f * C_t)
            c = self.g(c)
            C = f * C_t + i * c
            # The output gate peeks at the updated cell state.
            o = self.f(o + p_o * C)
            H = o * self.h(C)
            h_list.append(H)
            H_t = H
            C_t = C

        concatenated = numpy.concatenate(h_list)

        # Allocate with the dtype of the computed states: numpy.empty
        # defaults to float64, which silently upcast float32 results.
        Y = numpy.empty(
            [seq_length, self.num_directions, batch_size, hidden_size],
            dtype=concatenated.dtype)
        if self.num_directions == 1:
            Y[:, 0, :, :] = concatenated

        if self.layout == 0:
            Y_h = Y[-1]
        else:
            # Batch-first layout: move the batch axis to the front.
            Y = numpy.transpose(Y, [2, 0, 1, 3])
            Y_h = Y[:, :, -1, :]

        return Y, Y_h

    def _run(self, X, W, R, B=None, sequence_lens=None,  # pylint: disable=W0221
             initial_h=None, initial_c=None, P=None,
             attributes=None, verbose=0, fLOG=None):
        """
        Computes the LSTM outputs.

        @param      X               input sequences, axes 0 and 1 are
                                    swapped first when ``self.layout != 0``
        @param      W               input weights, ``W.shape[0]`` is the
                                    number of directions
        @param      R               recurrence weights, ``R.shape[-1]`` is
                                    the hidden size
        @param      B               optional bias, zeros if None
        @param      sequence_lens   optional, squeezed but not used by the
                                    computation
        @param      initial_h       optional initial hidden state,
                                    zeros if None
        @param      initial_c       optional initial cell state,
                                    zeros if None
        @param      P               optional peephole weights, zeros if None
        @param      attributes      not used by this method
        @param      verbose         not used by this method
        @param      fLOG            not used by this method
        @return                     ``(Y, )`` if the node has a single
                                    output, ``(Y, Y_h)`` otherwise
        @raise      NotImplementedError if the number of directions
                                    is not 1
        """
        number_of_peepholes = 3

        self.num_directions = W.shape[0]

        if self.num_directions != 1:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported value %r for num_directions and operator %r." % (
                    self.num_directions, self.__class__.__name__))

        # Single direction: drop the leading axis of every tensor given.
        R = numpy.squeeze(R, axis=0)
        W = numpy.squeeze(W, axis=0)
        if B is not None:
            B = numpy.squeeze(B, axis=0)
        if sequence_lens is not None:
            # NOTE(review): numpy.squeeze raises if axis 0 is not of
            # length 1, and the result is never used afterwards;
            # confirm whether sequence_lens should be supported.
            sequence_lens = numpy.squeeze(sequence_lens, axis=0)
        if initial_h is not None:
            initial_h = numpy.squeeze(initial_h, axis=0)
        if initial_c is not None:
            initial_c = numpy.squeeze(initial_c, axis=0)
        if P is not None:
            P = numpy.squeeze(P, axis=0)

        hidden_size = R.shape[-1]

        if self.layout != 0:
            X = numpy.swapaxes(X, 0, 1)
        # Read the batch size only after the layout swap so that it is
        # the same axis _step uses; reading it before the swap returned
        # the other axis for batch-first inputs.
        batch_size = X.shape[1]

        if B is None:
            B = numpy.zeros(2 * self.number_of_gates *
                            hidden_size, dtype=numpy.float32)
        if P is None:
            P = numpy.zeros(number_of_peepholes *
                            hidden_size, dtype=numpy.float32)
        if initial_h is None:
            initial_h = numpy.zeros(
                (batch_size, hidden_size), dtype=numpy.float32)
        if initial_c is None:
            initial_c = numpy.zeros(
                (batch_size, hidden_size), dtype=numpy.float32)

        Y, Y_h = self._step(X, R, B, W, initial_h, initial_c, P)

        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)
class LSTM(CommonLSTM):
    """
    Runtime for operator LSTM. Declares the attributes the operator
    expects with their default values and delegates the computation
    to the parent class.
    """

    # Expected attributes and their default values.
    atts = {
        'activation_alpha': [0.0],
        'activation_beta': [0.0],
        'activations': [b'Tanh', b'Tanh'],
        'clip': [],
        'direction': b'forward',
        'hidden_size': None,
        'layout': 0,
        'input_forget': 0,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   node being executed
        @param      desc        forwarded to the parent constructor
        @param      options     forwarded to the parent constructor
        """
        expected = LSTM.atts
        CommonLSTM.__init__(
            self, onnx_node, expected_attributes=expected, desc=desc,
            **options)