Coverage for mlprodict/onnxrt/ops_cpu/op_stft.py: 99%

112 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from .op_dft import _cfft as _dft 

10from .op_slice import _slice 

11from .op_concat_from_sequence import _concat_from_sequence 

12 

13 

14def _concat(*args, axis=0): 

15 return numpy.concatenate(tuple(args), axis=axis) 

16 

17 

18def _unsqueeze(a, axis): 

19 return numpy.expand_dims(a, axis=axis) 

20 

21 

22def _switch_axes(a, ax1, ax2): 

23 p = [i for i in range(len(a.shape))] 

24 p[ax1], p[ax2] = p[ax2], p[ax1] 

25 return numpy.transpose(a, p) 

26 

27 

def _stft(x, fft_length, hop_length, n_frames, window, onesided=False):
    """
    Applies one dimensional FFT with window weights.
    torch defines the number of frames as:
    `n_frames = 1 + (len - n_fft) / hop_length`.

    *n_frames* frames of ``window.shape[0]`` samples are sliced from *x*
    along axis -2, frame ``i`` starting at sample ``i * hop_length``.
    A frame whose slice is shorter than the window is right-padded with
    zeros. The frames get a new axis -3, are concatenated along it,
    multiplied by *window* and handed to ``_dft`` (``_cfft`` imported
    from ``op_dft``) with ``axis=len(x.shape) - 1``. Axes -3 and -2 of
    the transform are swapped before returning.

    :param x: input signal, axis -2 is the one cut into frames
        (presumably ``[batch, signal_length, 1 or 2]`` as for the ONNX
        STFT operator - TODO confirm)
    :param fft_length: sequence holding the DFT length, forwarded
        unchanged to ``_dft``
    :param hop_length: number of samples between two frame starts
    :param n_frames: number of frames to extract
    :param window: one dimensional weights, its length is the frame size
    :param onesided: forwarded to ``_dft``
    :return: result of ``_dft`` with axes -3 and -2 swapped
    """
    last_axis = len(x.shape) - 1  # op.Sub(op.Shape(op.Shape(x)), one)
    axis = [-2]  # axis along which frames are sliced
    axis2 = [-3]  # position of the axis added to every frame
    window_size = window.shape[0]

    # building frames
    seq = []
    for fs in range(n_frames):
        begin = fs * hop_length
        end = begin + window_size
        sliced_x = _slice(x, numpy.array([begin]), numpy.array([end]), axis)

        # sliced_x may be smaller (frame running past the end of x,
        # presumably _slice clamps like Python slicing): zeros are appended
        # along axis -2 so that every frame holds window_size samples
        new_dim = sliced_x.shape[-2:-1]
        missing = (window_size - new_dim[0], )
        new_shape = sliced_x.shape[:-2] + missing + sliced_x.shape[-1:]
        cst = numpy.zeros(new_shape, dtype=x.dtype)
        pad_sliced_x = _concat(sliced_x, cst, axis=-2)

        # same size
        un_sliced_x = _unsqueeze(pad_sliced_x, axis2)
        seq.append(un_sliced_x)

    # concatenation
    new_x = _concat_from_sequence(seq, axis=-3, new_axis=0)

    # calling weighted dft with weights=window
    # NOTE(review): window_shape has len(shape_x) + 1 entries while new_x
    # has len(shape_x) dimensions, so broadcasting prepends an axis of
    # size 1 to weighted_new_x; confirm last_axis still designates the
    # axis _dft is meant to transform.
    shape_x = new_x.shape
    shape_x_short = shape_x[:-2]
    shape_x_short_one = tuple(1 for _ in shape_x_short) + (1, )
    window_shape = shape_x_short_one + (window_size, 1)
    weights = numpy.reshape(window, window_shape)
    weighted_new_x = new_x * weights

    result = _dft(weighted_new_x, fft_length, last_axis,
                  onesided=onesided)  # normalize=False

    # final transpose -3, -2
    dim = len(result.shape)
    ax1 = dim - 3
    ax2 = dim - 2
    return _switch_axes(result, ax1, ax2)

76 

77 

def _istft(x, fft_length, hop_length, window, onesided=False):  # pylint: disable=R0914
    """
    Reverses of `stft`.

    Each of the ``x.shape[-2]`` frames of *x* is transformed by ``_dft``
    (``_cfft`` imported from ``op_dft``) with ``normalize=True``. The
    real and imaginary parts of every frame are placed at offset
    ``frame * hop_length`` inside a zero buffer of
    ``fft_length[0] + hop_length * (n_frames - 1)`` samples, then summed
    over all frames (overlap-add). Both sums are divided by the
    overlap-added *window*, and real and imaginary parts are finally
    stacked on a new last axis of size 2.

    :param x: frames to invert, axis -2 indexes the frames
    :param fft_length: sequence holding the DFT length
    :param hop_length: number of samples between two frame starts
    :param window: one dimensional weights, broadcast against the last
        axis of every transformed frame
    :param onesided: forwarded to ``_dft``
    :return: array whose last axis holds (real, imaginary) parts;
        positions where the summed window is 0 are divided by zero
    """
    zero = [0]
    one = [1]
    two = [2]
    axisf = [-2]  # axis indexing the frames
    n_frames = x.shape[-2]
    # length of the buffer covering every shifted frame
    expected_signal_len = fft_length[0] + hop_length * (n_frames - 1)

    # building frames
    seqr = []  # shifted real parts
    seqi = []  # shifted imaginary parts
    seqc = []  # shifted window weights, used to unweight the sum
    for fs in range(n_frames):
        begin = fs
        end = fs + 1
        frame_x = numpy.squeeze(_slice(x, numpy.array([begin]),
                                       numpy.array([end]), axisf),
                                axis=axisf[0])

        # ifft
        # NOTE(review): this calls the forward transform imported as _dft
        # with normalize=True; unless _cfft conjugates its input this is
        # not an inverse DFT - confirm against op_dft.
        ift = _dft(frame_x, fft_length, axis=-1, onesided=onesided,
                   normalize=True)
        n_dims = len(ift.shape)

        # real part (index 0 of the last axis of ift)
        n_dims_1 = n_dims - 1
        sliced = _slice(ift, numpy.array(zero),
                        numpy.array(one), [n_dims_1])
        ytmp = numpy.squeeze(sliced, axis=n_dims_1)
        ctmp = numpy.full(ytmp.shape, fill_value=1, dtype=x.dtype) * window

        # zeros before and after the frame so that it starts at
        # fs * hop_length in a buffer of expected_signal_len samples
        shape_begin = ytmp.shape[:-1]
        n_left = fs * hop_length
        size = ytmp.shape[-1]
        n_right = expected_signal_len - (n_left + size)

        left_shape = shape_begin + (n_left, )
        right_shape = shape_begin + (n_right, )
        right = numpy.zeros(right_shape, dtype=x.dtype)
        left = numpy.zeros(left_shape, dtype=x.dtype)

        y = _concat(left, ytmp, right, axis=-1)
        yc = _concat(left, ctmp, right, axis=-1)

        # imaginary part (index 1 of the last axis of ift)
        sliced = _slice(ift, numpy.array(one), numpy.array(two), [n_dims_1])
        itmp = numpy.squeeze(sliced, axis=n_dims_1)
        yi = _concat(left, itmp, right, axis=-1)

        # append
        seqr.append(_unsqueeze(y, axis=-1))
        seqi.append(_unsqueeze(yi, axis=-1))
        seqc.append(_unsqueeze(yc, axis=-1))

    # concatenation
    redr = _concat_from_sequence(seqr, axis=-1, new_axis=0)
    redi = _concat_from_sequence(seqi, axis=-1, new_axis=0)
    redc = _concat_from_sequence(seqc, axis=-1, new_axis=0)

    # unweight: overlap-add over the frame axis, then divide by the
    # overlap-added window (division by zero where the window sums to 0)
    resr = redr.sum(axis=-1, keepdims=0)
    resi = redi.sum(axis=-1, keepdims=0)
    resc = redc.sum(axis=-1, keepdims=0)
    rr = resr / resc
    ri = resi / resc

    # Make complex
    rr0 = numpy.expand_dims(rr, axis=0)
    ri0 = numpy.expand_dims(ri, axis=0)
    conc = _concat(rr0, ri0, axis=0)

    # rotation, bring first dimension to the last position
    result_shape = conc.shape
    reshaped_result = conc.reshape((2, -1))
    transposed = numpy.transpose(reshaped_result, (1, 0))
    other_dimensions = result_shape[1:]
    final_shape = _concat(other_dimensions, two, axis=0)
    final = transposed.reshape(final_shape)
    return final

160 

161 

class STFT(OpRun):
    """
    Runtime for operator STFT (short-time Fourier transform).

    Besides attribute ``onesided``, the runtime accepts a custom attribute
    ``inverse``: when it is set, *x* is processed by ``_istft`` instead of
    ``_stft``.
    """

    atts = {'onesided': 1, 'inverse': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=STFT.atts,
                       **options)

    @staticmethod
    def _as_int(value):
        """
        Converts a python integer, a numpy scalar or a one-element array
        into a python integer.
        """
        return int(numpy.asarray(value).reshape(-1)[0])

    def _run(self, x, frame_step, window=None, frame_length=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the short-time Fourier transform of *x* (or its inverse
        when attribute ``inverse`` is set).

        :param x: input tensor, axis -2 is the one cut into frames
        :param frame_step: number of samples between two frame starts
            (hop length), must be strictly positive
        :param window: optional one dimensional weights; defaults to ones
        :param frame_length: optional DFT size; in forward mode it
            defaults to the window length, or to ``x.shape[-2]`` when no
            window is given
        :return: tuple with one tensor cast to ``x.dtype``
        :raises ValueError: if *frame_step* is not strictly positive
        """
        # frame_step is the hop between consecutive frames.
        hop_length = STFT._as_int(frame_step)
        if hop_length <= 0:
            raise ValueError(
                "frame_step must be strictly positive not %r." % hop_length)

        if self.inverse:
            # custom inverse mode: defaults derived from x.shape[-2]
            if frame_length is None:
                frame_length = x.shape[-2]
            else:
                frame_length = STFT._as_int(frame_length)
            if window is None:
                window = numpy.ones(x.shape[-2], dtype=x.dtype)
            res = _istft(x, [frame_length], hop_length, window,
                         onesided=self.onesided)
            return (res.astype(x.dtype), )

        if frame_length is None:
            # the window, when provided, defines the frame size
            frame_length = (x.shape[-2] if window is None
                            else window.shape[0])
        frame_length = STFT._as_int(frame_length)
        if window is None:
            window = numpy.ones(frame_length, dtype=x.dtype)

        # Same count as torch: 1 + (len - n_fft) // hop_length, with at
        # least one (zero-padded) frame when the signal is shorter than
        # a frame.
        n_frames = max(1, 1 + (x.shape[-2] - frame_length) // hop_length)
        res = _stft(x, [frame_length], hop_length, n_frames, window,
                    onesided=self.onesided)
        return (res.astype(x.dtype), )