Coverage for mlprodict/onnxrt/ops_cpu/op_lstm.py: 96%

78 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ._op import OpRun


class CommonLSTM(OpRun):

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        self.nb_outputs = len(onnx_node.output)
        self.number_of_gates = 3

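    # Gate activations: f is the logistic sigmoid used for the input,
    # output and forget gates, g the tanh applied to the cell update,
    # h the tanh applied to the cell state when computing the output.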

    def f(self, x):
        return 1 / (1 + numpy.exp(-x))

    def g(self, x):
        return numpy.tanh(x)

    def h(self, x):
        return numpy.tanh(x)

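    # _step runs the recurrence over the whole sequence for one direction:
    # X is (seq_length, batch_size, input_size), W and R are the squeezed
    # input and recurrence weights, B the concatenated biases, H_0 / C_0 the
    # initial hidden and cell states and P the peephole weights. Gates are
    # split in the ONNX order: input, output, forget, cell.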

    def _step(self, X, R, B, W, H_0, C_0, P):
        seq_length = X.shape[0]
        hidden_size = H_0.shape[-1]
        batch_size = X.shape[1]

        Y = numpy.empty(
            [seq_length, self.num_directions, batch_size, hidden_size])
        h_list = []

        [p_i, p_o, p_f] = numpy.split(P, 3)  # pylint: disable=W0632
        H_t = H_0
        C_t = C_0
        for x in numpy.split(X, X.shape[0], axis=0):
            gates = (numpy.dot(x, numpy.transpose(W)) +
                     numpy.dot(H_t, numpy.transpose(R)) +
                     numpy.add(*numpy.split(B, 2)))
            i, o, f, c = numpy.split(gates, 4, -1)  # pylint: disable=W0632
            i = self.f(i + p_i * C_t)
            f = self.f(f + p_f * C_t)
            c = self.g(c)
            C = f * C_t + i * c
            o = self.f(o + p_o * C)
            H = o * self.h(C)
            h_list.append(H)
            H_t = H
            C_t = C

        concatenated = numpy.concatenate(h_list)
        if self.num_directions == 1:
            Y[:, 0, :, :] = concatenated

        if self.layout == 0:
            Y_h = Y[-1]
        else:
            Y = numpy.transpose(Y, [2, 0, 1, 3])
            Y_h = Y[:, :, -1, :]

        return Y, Y_h

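    # _run squeezes the num_directions axis (only a single direction is
    # supported), replaces missing optional inputs B, P, initial_h and
    # initial_c with zeros, then delegates to _step.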

    def _run(self, X, W, R, B=None, sequence_lens=None,  # pylint: disable=W0221
             initial_h=None, initial_c=None, P=None,
             attributes=None, verbose=0, fLOG=None):
        number_of_gates = 4
        number_of_peepholes = 3

        self.num_directions = W.shape[0]

        if self.num_directions == 1:
            R = numpy.squeeze(R, axis=0)
            W = numpy.squeeze(W, axis=0)
            if B is not None:
                B = numpy.squeeze(B, axis=0)
            if sequence_lens is not None:
                sequence_lens = numpy.squeeze(sequence_lens, axis=0)
            if initial_h is not None:
                initial_h = numpy.squeeze(initial_h, axis=0)
            if initial_c is not None:
                initial_c = numpy.squeeze(initial_c, axis=0)
            if P is not None:
                P = numpy.squeeze(P, axis=0)

            hidden_size = R.shape[-1]
            batch_size = X.shape[1]

            if self.layout != 0:
                X = numpy.swapaxes(X, 0, 1)
            if B is None:
                B = numpy.zeros(2 * number_of_gates * hidden_size,
                                dtype=numpy.float32)
            if P is None:
                P = numpy.zeros(number_of_peepholes * hidden_size,
                                dtype=numpy.float32)
            if initial_h is None:
                initial_h = numpy.zeros(
                    (batch_size, hidden_size), dtype=numpy.float32)
            if initial_c is None:
                initial_c = numpy.zeros(
                    (batch_size, hidden_size), dtype=numpy.float32)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported value %r for num_directions and operator %r." % (
                    self.num_directions, self.__class__.__name__))

        Y, Y_h = self._step(X, R, B, W, initial_h, initial_c, P)

        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)


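# LSTM binds the ONNX node attributes and their default values to the
# generic implementation in CommonLSTM.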

class LSTM(CommonLSTM):

    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        'activations': [b'Tanh', b'Tanh'],
        'clip': [],
        'direction': b'forward',
        'hidden_size': None,
        'layout': 0,
        'input_forget': 0,
    }

    def __init__(self, onnx_node, desc=None, **options):
        CommonLSTM.__init__(self, onnx_node, desc=desc,
                            expected_attributes=LSTM.atts,
                            **options)
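A minimal standalone sketch, not part of op_lstm.py, of the recurrence that _step
implements, assuming a single forward direction, zero peepholes, zero biases and
zero initial states; the shapes, the i, o, f, c gate order and the bias layout
(Wb and Rb concatenated) follow the code above.

import numpy

seq_length, batch_size, input_size, hidden_size = 5, 2, 3, 4
X = numpy.random.randn(seq_length, batch_size, input_size).astype(numpy.float32)
W = numpy.random.randn(4 * hidden_size, input_size).astype(numpy.float32)
R = numpy.random.randn(4 * hidden_size, hidden_size).astype(numpy.float32)
B = numpy.zeros(8 * hidden_size, dtype=numpy.float32)  # Wb and Rb concatenated

H_t = numpy.zeros((batch_size, hidden_size), dtype=numpy.float32)
C_t = numpy.zeros((batch_size, hidden_size), dtype=numpy.float32)


def sigmoid(v):
    return 1 / (1 + numpy.exp(-v))


for x in X:                                  # x: (batch_size, input_size)
    gates = x @ W.T + H_t @ R.T + numpy.add(*numpy.split(B, 2))
    i, o, f, c = numpy.split(gates, 4, -1)   # same gate order as _step
    C_t = sigmoid(f) * C_t + sigmoid(i) * numpy.tanh(c)
    H_t = sigmoid(o) * numpy.tanh(C_t)

print(H_t.shape)  # (batch_size, hidden_size): should match the last hidden
                  # state Y_h returned by the operator for the same inputs.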