Coverage for mlprodict/onnxrt/ops_cpu/op_stft.py: 99%
112 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_dft import _cfft as _dft
10from .op_slice import _slice
11from .op_concat_from_sequence import _concat_from_sequence
def _concat(*args, axis=0):
    """
    Joins every positional array (or array-like) along *axis*
    with :func:`numpy.concatenate`.
    """
    pieces = list(args)
    return numpy.concatenate(pieces, axis=axis)
def _unsqueeze(a, axis):
    """
    Inserts size-one dimension(s) into *a* at *axis*
    (an int or a sequence of ints) with :func:`numpy.expand_dims`.
    """
    expanded = numpy.expand_dims(a, axis)
    return expanded
def _switch_axes(a, ax1, ax2):
    """
    Returns *a* with axes *ax1* and *ax2* exchanged
    (negative indices count from the last axis).
    """
    return numpy.swapaxes(a, ax1, ax2)
def _stft(x, fft_length, hop_length, n_frames, window, onesided=False):
    """
    Applies one dimensional FFT with window weights.
    torch defines the number of frames as:
    `n_frames = 1 + (len - n_fft) / hop_length`.

    Frames of ``window.shape[0]`` samples are cut along axis -2 of *x*,
    one every *hop_length* samples, zero-padded when the signal ends
    early, stacked along axis -3, multiplied by *window* and sent to
    `_dft`. Axes -3 and -2 of the DFT result are swapped before returning.

    :param x: signal, samples are read along axis -2
    :param fft_length: DFT size given as a one-element list, forwarded to `_dft`
    :param hop_length: number of samples between the starts of two frames
    :param n_frames: number of frames to extract
    :param window: 1D weights, its length is the frame size
    :param onesided: forwarded to `_dft`
    :return: `_dft` output with axes -3 and -2 swapped
    """
    frame_size = window.shape[0]

    def _padded_frame(start):
        # Reads frame_size samples from *start* along axis -2; a frame cut
        # short by the end of the signal is completed with zeros.
        chunk = _slice(x, numpy.array([start]),
                       numpy.array([start + frame_size]), [-2])
        tail_shape = (chunk.shape[:-2] + (frame_size - chunk.shape[-2], ) +
                      chunk.shape[-1:])
        tail = numpy.zeros(tail_shape, dtype=x.dtype)
        # A new axis -3 receives the frame index once frames are joined.
        return _unsqueeze(_concat(chunk, tail, axis=-2), [-3])

    frames = _concat_from_sequence(
        [_padded_frame(k * hop_length) for k in range(n_frames)],
        axis=-3, new_axis=0)

    # NOTE(review): weight_shape has len(frames.shape) + 1 entries, so the
    # product below gains a leading size-1 axis by broadcasting, while the
    # DFT axis passed to _dft is still len(x.shape) - 1 -- confirm this
    # matches the axis convention of _dft.
    weight_shape = (1, ) * (len(frames.shape) - 1) + (frame_size, 1)
    weighted = frames * numpy.reshape(window, weight_shape)

    # normalize is left to its default in _dft
    spectrum = _dft(weighted, fft_length, len(x.shape) - 1,
                    onesided=onesided)

    rank = len(spectrum.shape)
    return _switch_axes(spectrum, rank - 3, rank - 2)
def _istft(x, fft_length, hop_length, window, onesided=False):  # pylint: disable=R0914
    """
    Reverses `_stft`.

    Every frame of *x* (frames are read along axis -2) goes through `_dft`
    with ``normalize=True``. The real and imaginary parts of each result are
    placed at offset ``frame_index * hop_length`` inside a zero buffer of
    ``fft_length[0] + hop_length * (n_frames - 1)`` samples. The buffers are
    summed over all frames (overlap-add) and divided by the summed window
    contributions.

    :param x: spectrum with frames along axis -2; assumes the `_dft` output
        ends with an axis holding (real, imaginary) -- TODO confirm
    :param fft_length: one-element list, ``fft_length[0]`` is the DFT size
    :param hop_length: number of samples between two successive frames
    :param window: window weights, must broadcast to the shape of one
        real inverse frame
    :param onesided: forwarded to `_dft`
    :return: array whose last two axes are ``(expected_signal_len, 2)``,
        the final axis holding real and imaginary parts

    NOTE(review): the inverse relies on ``_dft(..., normalize=True)``;
    confirm that helper computes an inverse DFT and not only a forward
    transform scaled by ``1 / fft_length``.
    NOTE(review): samples where the summed window is 0 are divided by 0
    and become inf/nan.
    """
    zero = [0]
    one = [1]
    two = [2]
    axisf = [-2]
    n_frames = x.shape[-2]
    # length of the overlap-add buffer covering every shifted frame
    expected_signal_len = fft_length[0] + hop_length * (n_frames - 1)

    # building frames
    seqr = []  # real parts, one padded buffer per frame
    seqi = []  # imaginary parts, one padded buffer per frame
    seqc = []  # window contributions, one padded buffer per frame
    for fs in range(n_frames):
        begin = fs
        end = fs + 1
        # frame fs, with its axis -2 removed
        frame_x = numpy.squeeze(_slice(x, numpy.array([begin]),
                                       numpy.array([end]), axisf),
                                axis=axisf[0])

        # ifft
        ift = _dft(frame_x, fft_length, axis=-1, onesided=onesided,
                   normalize=True)
        n_dims = len(ift.shape)

        # real part: index 0 of the last axis
        n_dims_1 = n_dims - 1
        sliced = _slice(ift, numpy.array(zero),
                        numpy.array(one), [n_dims_1])
        ytmp = numpy.squeeze(sliced, axis=n_dims_1)
        # window weights broadcast to the frame shape, summed later to
        # undo the weighting applied by the forward transform
        ctmp = numpy.full(ytmp.shape, fill_value=1, dtype=x.dtype) * window

        # zero padding placing this frame at offset fs * hop_length
        shape_begin = ytmp.shape[:-1]
        n_left = fs * hop_length
        size = ytmp.shape[-1]
        n_right = expected_signal_len - (n_left + size)

        left_shape = shape_begin + (n_left, )
        right_shape = shape_begin + (n_right, )
        right = numpy.zeros(right_shape, dtype=x.dtype)
        left = numpy.zeros(left_shape, dtype=x.dtype)

        y = _concat(left, ytmp, right, axis=-1)
        yc = _concat(left, ctmp, right, axis=-1)

        # imaginary part: index 1 of the last axis
        sliced = _slice(ift, numpy.array(one), numpy.array(two), [n_dims_1])
        itmp = numpy.squeeze(sliced, axis=n_dims_1)
        yi = _concat(left, itmp, right, axis=-1)

        # append, a new last axis indexes the frame
        seqr.append(_unsqueeze(y, axis=-1))
        seqi.append(_unsqueeze(yi, axis=-1))
        seqc.append(_unsqueeze(yc, axis=-1))

    # concatenation
    redr = _concat_from_sequence(seqr, axis=-1, new_axis=0)
    redi = _concat_from_sequence(seqi, axis=-1, new_axis=0)
    redc = _concat_from_sequence(seqc, axis=-1, new_axis=0)

    # unweight: overlap-add over frames, then divide by the summed window
    resr = redr.sum(axis=-1, keepdims=0)
    resi = redi.sum(axis=-1, keepdims=0)
    resc = redc.sum(axis=-1, keepdims=0)
    rr = resr / resc
    ri = resi / resc

    # Make complex: stack real and imaginary parts on a new first axis
    rr0 = numpy.expand_dims(rr, axis=0)
    ri0 = numpy.expand_dims(ri, axis=0)
    conc = _concat(rr0, ri0, axis=0)

    # rotation, bring first dimension to the last position
    result_shape = conc.shape
    reshaped_result = conc.reshape((2, -1))
    transposed = numpy.transpose(reshaped_result, (1, 0))
    other_dimensions = result_shape[1:]
    final_shape = _concat(other_dimensions, two, axis=0)
    final = transposed.reshape(final_shape)
    return final
class STFT(OpRun):
    """
    Runtime for operator *STFT* (short-time Fourier transform).

    Attribute *onesided* (default 1) is forwarded to the DFT helpers,
    attribute *inverse* (default 0) selects :func:`_istft` instead of
    :func:`_stft`.
    """

    atts = {'onesided': 1, 'inverse': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=STFT.atts,
                       **options)

    def _run(self, x, frame_step, window=None, frame_length=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the (inverse) short-time Fourier transform of *x*.

        :param x: signal, samples along axis -2
        :param frame_step: number of samples between two successive
            frames (hop length), scalar or one-element array
        :param window: optional window, defaults to ones of length
            *frame_length*
        :param frame_length: optional DFT size, defaults to the window
            length, or to ``x.shape[-2]`` when no window is given
        :return: tuple with one array cast to ``x.dtype``
        :raises ValueError: if *frame_step* is not strictly positive
        """
        if frame_length is None:
            # The window defines the frame size when one is provided,
            # otherwise the whole signal is a single frame.
            frame_length = x.shape[-2] if window is None else window.shape[0]
        frame_length = int(numpy.asarray(frame_length).item())
        # frame_step is the hop between frames; it used to be ignored and
        # replaced by frame_length // 4.
        hop_length = int(numpy.asarray(frame_step).item())
        if hop_length <= 0:
            raise ValueError(
                f"frame_step must be strictly positive not {hop_length}.")
        if window is None:
            window = numpy.ones((frame_length, ), dtype=x.dtype)
        if self.inverse:
            res = _istft(x, [frame_length], hop_length, window,
                         onesided=self.onesided)
        else:
            # frames = (signal_length - frame_length) // frame_step + 1,
            # keeping at least one (zero-padded) frame for short signals.
            n_frames = max(1, 1 + (x.shape[-2] - frame_length) // hop_length)
            res = _stft(x, [frame_length], hop_length, n_frames, window,
                        onesided=self.onesided)
        return (res.astype(x.dtype), )