Coverage for mlprodict/onnxrt/ops_cpu/op_rnn.py: 92%
64 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
class CommonRNN(OpRun):
    """
    Common runtime for the ONNX RNN operator.

    Only the one-directional case (``direction`` is ``forward`` or
    ``reverse``) is implemented by :meth:`_run`; a bidirectional node is
    recognised but raises *NotImplementedError*.
    """

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

        if self.direction in (b"forward", b"reverse"):
            self.num_directions = 1
        elif self.direction == b"bidirectional":
            # Bug fix: the attribute is stored as bytes (see the default
            # b'forward' in the atts dictionaries below); the previous
            # comparison against the str 'bidirectional' never matched,
            # sending bidirectional nodes to the error branch below.
            self.num_directions = 2
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unknown direction '{self.direction}'.")

        if len(self.activation_alpha) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_alpha must have the same size as num_directions={}".format(
                    self.num_directions))
        if len(self.activation_beta) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_beta must have the same size as num_directions={}".format(
                    self.num_directions))

        # Activation (and optional alpha/beta parameters) for the first
        # direction; a second activation is only resolved when provided.
        self.f1 = self.choose_act(
            self.activations[0],
            self.activation_alpha[0] if len(
                self.activation_alpha) > 0 else None,
            self.activation_beta[0] if len(self.activation_beta) > 0 else None)
        if len(self.activations) > 1:
            self.f2 = self.choose_act(
                self.activations[1],
                self.activation_alpha[1] if len(
                    self.activation_alpha) > 1 else None,
                self.activation_beta[1] if len(self.activation_beta) > 1 else None)
        self.nb_outputs = len(onnx_node.output)

    def choose_act(self, name, alpha, beta):
        """
        Returns the activation function matching *name*
        (*Tanh* or *Affine*), parameterized by *alpha* and *beta*
        for the affine case.
        """
        if name in (b"Tanh", b'tanh', 'tanh', 'Tanh'):
            return self._f_tanh
        if name in (b"Affine", b"affine", 'Affine', 'affine'):
            return lambda x: x * alpha + beta
        raise RuntimeError(  # pragma: no cover
            f"Unknown activation function '{name}'.")

    def _f_tanh(self, x):
        "Hyperbolic tangent activation."
        return numpy.tanh(x)

    def _step(self, X, R, B, W, H_0):
        """
        Unrolls the recurrence over the sequence axis (axis 0 of *X*)
        starting from hidden state *H_0*.
        Returns the stacked hidden states and the last hidden state.
        """
        h_list = []
        H_t = H_0
        # The bias term (sum of the two halves of B) does not depend on
        # the time step: compute it once outside the loop.
        bias = numpy.add(*numpy.split(B, 2))
        for x in numpy.split(X, X.shape[0], axis=0):
            H = self.f1(numpy.dot(x, numpy.transpose(W)) +
                        numpy.dot(H_t, numpy.transpose(R)) +
                        bias)
            h_list.append(H)
            H_t = H
        concatenated = numpy.concatenate(h_list)
        # _run only calls this method with num_directions == 1,
        # so `output` is always bound here.
        if self.num_directions == 1:
            output = numpy.expand_dims(concatenated, 1)
        return output, h_list[-1]

    def _run(self, X, W, R, B=None, sequence_lens=None, initial_h=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # num_directions is re-derived from the weight tensor shape,
        # overriding the value inferred from the direction attribute.
        self.num_directions = W.shape[0]

        if self.num_directions == 1:
            R = numpy.squeeze(R, axis=0)
            W = numpy.squeeze(W, axis=0)
            if B is not None:
                B = numpy.squeeze(B, axis=0)
            if sequence_lens is not None:
                sequence_lens = numpy.squeeze(sequence_lens, axis=0)
            if initial_h is not None:
                initial_h = numpy.squeeze(initial_h, axis=0)

            hidden_size = R.shape[-1]
            batch_size = X.shape[1]

            # Default to zero bias / zero initial hidden state when the
            # optional inputs are absent.
            b = (B if B is not None else
                 numpy.zeros(2 * hidden_size, dtype=X.dtype))
            h_0 = (initial_h if initial_h is not None else
                   numpy.zeros((batch_size, hidden_size), dtype=X.dtype))

            B = b
            H_0 = h_0
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported value %r for num_directions and operator %r." % (
                    self.num_directions, self.__class__.__name__))

        Y, Y_h = self._step(X, R, B, W, H_0)
        # if self.layout == 1:
        #     #Y = numpy.transpose(Y, [2, 0, 1, 3])
        #     Y_h = Y[:, :, -1, :]

        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)
class RNN_7(CommonRNN):
    """ONNX RNN operator for opset 7 (no ``layout`` attribute)."""

    # Default attribute values for opset 7.
    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        'activations': [b'Tanh', b'Tanh'],
        'clip': [],
        'direction': b'forward',
        'hidden_size': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=RNN_7.atts,
                         **options)
class RNN_14(CommonRNN):
    """ONNX RNN operator for opset 14, which adds the ``layout`` attribute."""

    # Default attribute values for opset 14.
    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        'activations': [b'Tanh', b'Tanh'],
        'clip': [],
        'direction': b'forward',
        'hidden_size': None,
        'layout': 0,
    }

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=RNN_14.atts,
                         **options)
# Expose the most recent implementation supported by the installed
# onnx package under the generic name.
RNN = RNN_14 if onnx_opset_version() >= 14 else RNN_7