Coverage for mlprodict/onnxrt/ops_cpu/op_rnn.py: 92%

64 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10 

11 

class CommonRNN(OpRun):
    """
    Shared runtime for the ONNX *RNN* operator (simple recurrent network).

    For every time step *t* the recurrence is::

        H_t = f1(X_t W^T + H_{t-1} R^T + Wb + Rb)

    where *f1* is the first function of attribute *activations* and
    *Wb*, *Rb* are the two halves of input *B*.

    Limitations: only one direction (``forward`` or ``reverse``) can be
    computed, and input *sequence_lens* is accepted but ignored (every
    sequence is assumed to span the whole time axis of *X*).
    """

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

        # Default attribute values are bytes (b'forward') but values may
        # also be str: compare a normalized str so every branch is reachable.
        direction = self._to_str(self.direction)
        if direction in ("forward", "reverse"):
            self.num_directions = 1
        elif direction == "bidirectional":
            self.num_directions = 2
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unknown direction '{self.direction}'.")
        # When True, _run walks the sequence from the last time step.
        self._reverse = direction == "reverse"

        if len(self.activation_alpha) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_alpha must have the same size as num_directions={}".format(
                    self.num_directions))
        if len(self.activation_beta) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_beta must have the same size as num_directions={}".format(
                    self.num_directions))

        self.f1 = self.choose_act(
            self.activations[0],
            self.activation_alpha[0] if len(self.activation_alpha) > 0 else None,
            self.activation_beta[0] if len(self.activation_beta) > 0 else None)
        if len(self.activations) > 1:
            self.f2 = self.choose_act(
                self.activations[1],
                self.activation_alpha[1] if len(self.activation_alpha) > 1 else None,
                self.activation_beta[1] if len(self.activation_beta) > 1 else None)
        self.nb_outputs = len(onnx_node.output)

    @staticmethod
    def _to_str(value):
        """
        Converts an attribute value stored as bytes into str,
        any other value is returned unchanged.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def choose_act(self, name, alpha, beta):
        """
        Returns the activation function called *name*.

        :param name: ``'Tanh'``/``'tanh'`` or ``'Affine'``/``'affine'``,
            as str or bytes
        :param alpha: multiplicative coefficient, only used by ``Affine``
        :param beta: additive coefficient, only used by ``Affine``
        :return: a callable applied to a numpy array
        :raises RuntimeError: for any other activation name
        """
        key = self._to_str(name)
        if key in ('Tanh', 'tanh'):
            return self._f_tanh
        if key in ('Affine', 'affine'):
            return lambda x: x * alpha + beta
        raise RuntimeError(  # pragma: no cover
            f"Unknown activation function '{name}'.")

    def _f_tanh(self, x):
        return numpy.tanh(x)

    def _step(self, X, R, B, W, H_0):
        """
        Runs the recurrence along the first axis of *X*, one direction.

        :param X: inputs, shape ``(seq_length, batch_size, input_size)``
        :param R: recurrence weights, shape ``(hidden_size, hidden_size)``
        :param B: ``[Wb, Rb]`` concatenated, shape ``(2 * hidden_size,)``
        :param W: input weights, shape ``(hidden_size, input_size)``
        :param H_0: initial hidden state, shape ``(batch_size, hidden_size)``
        :return: ``Y`` with shape ``(seq_length, 1, batch_size, hidden_size)``
            and ``Y_h`` (last hidden state) with shape
            ``(1, batch_size, hidden_size)``
        """
        # Loop invariants: computed once instead of at every time step.
        W_t = numpy.transpose(W)
        R_t = numpy.transpose(R)
        bias = numpy.add(*numpy.split(B, 2))

        h_list = []
        H_t = H_0
        # Each x keeps a leading axis of size 1: (1, batch_size, input_size).
        for x in numpy.split(X, X.shape[0], axis=0):
            H = self.f1(numpy.dot(x, W_t) + numpy.dot(H_t, R_t) + bias)
            h_list.append(H)
            H_t = H
        concatenated = numpy.concatenate(h_list)
        # Insert the num_directions axis (always 1 here).
        output = numpy.expand_dims(concatenated, 1)
        return output, h_list[-1]

    def _run(self, X, W, R, B=None, sequence_lens=None, initial_h=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the operator.

        Inputs follow the ONNX RNN specification: with ``layout == 0``
        (the only behaviour of opset 7) time is the first axis of *X*,
        with ``layout == 1`` batch is the first axis of *X*, *initial_h*,
        *Y* and *Y_h*.

        :return: ``(Y, )`` or ``(Y, Y_h)`` depending on the number of
            outputs of the node
        :raises NotImplementedError: if *W* holds more than one direction
        """
        self.num_directions = W.shape[0]
        # RNN_7 has no layout attribute: it behaves like layout == 0.
        layout = getattr(self, 'layout', 0)

        if self.num_directions != 1:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported value %r for num_directions and operator %r." % (
                    self.num_directions, self.__class__.__name__))

        # Drop the num_directions axis (of size 1) from weights and biases.
        R = numpy.squeeze(R, axis=0)
        W = numpy.squeeze(W, axis=0)
        if B is not None:
            B = numpy.squeeze(B, axis=0)
        # sequence_lens has shape (batch_size,): there is no num_directions
        # axis to squeeze (doing so failed as soon as batch_size > 1).
        # It is still ignored by this runtime.

        if layout == 1:
            # (batch, seq, input) -> (seq, batch, input)
            X = numpy.swapaxes(X, 0, 1)
        if initial_h is not None:
            # num_directions is axis 0 with layout 0 and axis 1 with layout 1.
            initial_h = numpy.squeeze(initial_h, axis=1 if layout == 1 else 0)

        hidden_size = R.shape[-1]
        batch_size = X.shape[1]

        if B is None:
            B = numpy.zeros(2 * hidden_size, dtype=X.dtype)
        H_0 = (initial_h if initial_h is not None else
               numpy.zeros((batch_size, hidden_size), dtype=X.dtype))

        if self._reverse:
            # Process the sequence from the last time step and put the
            # outputs back in the original time order.
            Y, Y_h = self._step(X[::-1], R, B, W, H_0)
            Y = Y[::-1]
        else:
            Y, Y_h = self._step(X, R, B, W, H_0)

        if layout == 1:
            # (seq, num_dir, batch, hidden) -> (batch, seq, num_dir, hidden)
            Y = numpy.transpose(Y, [2, 0, 1, 3])
            # (num_dir, batch, hidden) -> (batch, num_dir, hidden)
            Y_h = numpy.swapaxes(Y_h, 0, 1)

        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)

109 

110 

class RNN_7(CommonRNN):
    """
    RNN runtime whose attribute set has no *layout* entry
    (used when the installed onnx opset is below 14).
    """

    atts = dict(
        activation_alpha=[0.],
        activation_beta=[0.],
        activations=[b'Tanh', b'Tanh'],
        clip=[],
        direction=b'forward',
        hidden_size=None,
    )

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=RNN_7.atts, **options)

126 

127 

class RNN_14(CommonRNN):
    """
    RNN runtime whose attribute set adds *layout* (default 0)
    to the attributes of the older version.
    """

    atts = dict(
        activation_alpha=[0.],
        activation_beta=[0.],
        activations=[b'Tanh', b'Tanh'],
        clip=[],
        direction=b'forward',
        hidden_size=None,
        layout=0,
    )

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=RNN_14.atts, **options)

144 

145 

# Public alias: the newest implementation the installed onnx package supports.
RNN = RNN_14 if onnx_opset_version() >= 14 else RNN_7