Coverage for mlprodict/onnxrt/ops_cpu/op_dft.py: 82%
71 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
def _fft(x, fft_length, axis):
    """
    One dimensional FFT returning real and imaginary parts side by side.

    :param x: array to transform (real or complex)
    :param fft_length: sequence whose first element is the transform
        length, or None to use ``x.shape[axis]``
    :param axis: axis of *x* along which the FFT is computed
    :return: array shaped like the spectrum plus one trailing dimension
        of size 2: ``[..., 0]`` is the real part, ``[..., 1]`` the
        imaginary part
    """
    if fft_length is None:
        fft_length = [x.shape[axis]]
    spectrum = numpy.fft.fft(x, fft_length[0], axis=axis)
    # (real, imag) pairs live on a brand-new last axis.
    parts = numpy.stack(
        [numpy.real(spectrum), numpy.imag(spectrum)], axis=-1)
    if parts.shape[-1] != 2:
        raise RuntimeError(
            f"Unexpected shape {parts.shape}, x.shape={x.shape} "
            f"fft_length={fft_length}.")
    return parts
def _cfft(x, fft_length, axis, onesided=False, normalize=False):
    """
    FFT of a signal whose last dimension holds either a real value
    (size 1) or a (real, imaginary) pair (size 2).

    :param x: input array whose last dimension has size 1 or 2
    :param fft_length: sequence whose first element is the FFT length,
        or None to use the signal length along *axis*
    :param axis: axis of the signal (*x* without its last dimension)
        along which the FFT is computed; a negative value counts from the
        end of that signal, as the previous implementation did
    :param onesided: if True, keep only the first ``n // 2 + 1`` entries
        along *axis*
    :param normalize: unused, kept for backward compatibility
    :return: array with a trailing dimension of size 2 holding the real
        and imaginary parts of the spectrum
    :raises ValueError: if the last dimension of *x* is neither 1 nor 2
    """
    n_components = x.shape[-1]
    if n_components == 1:
        signal = x[..., 0]
    elif n_components == 2:
        signal = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            f"Unexpected last dimension {n_components} for x.shape={x.shape}, "
            f"expected 1 (real) or 2 (real, imaginary).")
    res = _fft(signal, fft_length, axis=axis)
    if onesided:
        # res has one more (trailing) dimension than signal: a negative
        # axis must be resolved against signal before indexing res,
        # otherwise the slice lands on a different dimension.
        if axis < 0:
            axis += signal.ndim
        slices = [slice(None)] * res.ndim
        slices[axis] = slice(0, res.shape[axis] // 2 + 1)
        return res[tuple(slices)]
    return res
def _ifft(x, fft_length, axis=-1, onesided=False):
    """
    One dimensional inverse FFT returning real and imaginary parts side
    by side.

    :param x: array to transform (complex or real)
    :param fft_length: sequence whose first element is the transform
        length, or None to use ``x.shape[axis]``
    :param axis: axis of *x* along which the inverse FFT is computed;
        negative values count from the end of *x*
    :param onesided: if True, keep only the first ``n // 2 + 1`` entries
        along *axis*
    :return: array shaped like the transform plus one trailing dimension
        of size 2: ``[..., 0]`` is the real part, ``[..., 1]`` the
        imaginary part
    """
    n = None if fft_length is None else fft_length[0]
    ft = numpy.fft.ifft(x, n, axis=axis)
    tr = numpy.stack([numpy.real(ft), numpy.imag(ft)], axis=-1)
    if onesided:
        # tr has an extra trailing dimension, so a negative axis has to be
        # resolved against ft (same rank as x) before slicing tr.
        if axis < 0:
            axis += ft.ndim
        # slice(None) selects a whole dimension (slice() takes >= 1 arg).
        slices = [slice(None)] * tr.ndim
        slices[axis] = slice(0, tr.shape[axis] // 2 + 1)
        return tr[tuple(slices)]
    return tr
def _cifft(x, fft_length, axis=-1, onesided=False):
    """
    Inverse FFT of a signal whose last dimension holds either a real value
    (size 1) or a (real, imaginary) pair (size 2).

    :param x: input array whose last dimension has size 1 or 2
    :param fft_length: sequence whose first element is the transform
        length, or None to use the signal length along *axis*
    :param axis: axis of the signal (*x* without its last dimension)
        along which the inverse FFT is computed
    :param onesided: forwarded to :func:`_ifft`
    :return: array with a trailing dimension of size 2 holding the real
        and imaginary parts of the result
    :raises ValueError: if the last dimension of *x* is neither 1 nor 2
    """
    n_components = x.shape[-1]
    if n_components == 1:
        signal = x[..., 0]
    elif n_components == 2:
        signal = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            f"Unexpected last dimension {n_components} for x.shape={x.shape}, "
            f"expected 1 (real) or 2 (real, imaginary).")
    return _ifft(signal, fft_length, axis=axis, onesided=onesided)
class DFT(OpRun):
    """
    Runtime implementation of operator DFT.

    Attributes (defaults in ``atts``) are read as ``self.axis``,
    ``self.inverse`` and ``self.onesided``; presumably they are populated
    by ``OpRun.__init__`` from the node.
    """

    atts = {'axis': 1, 'inverse': 0, 'onesided': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DFT.atts,
                       **options)

    def _run(self, x, dft_length=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the (inverse) DFT of *x* along ``self.axis``.

        :param x: input array whose last dimension has size 1 (real) or
            2 (real, imaginary)
        :param dft_length: optional transform length; defaults to the
            size of *x* along ``self.axis``
        :return: 1-tuple with the result cast to ``x.dtype``
        """
        # self.axis indexes x, which includes the trailing real/imaginary
        # dimension, while the helpers index arrays without it. Resolving
        # a negative axis against x keeps both on the same dimension
        # (non-negative axes are identical in both views).
        axis = self.axis + x.ndim if self.axis < 0 else self.axis
        if dft_length is None:
            dft_length = numpy.array([x.shape[axis]], dtype=numpy.int64)
        else:
            # The helpers read fft_length[0]; flattening accepts a 0-d
            # (scalar) tensor as well as a 1-element 1-D one.
            dft_length = numpy.ravel(dft_length)
        if self.inverse:
            res = _cifft(x, dft_length, axis=axis, onesided=self.onesided)
        else:
            res = _cfft(x, dft_length, axis=axis, onesided=self.onesided)
        return (res.astype(x.dtype), )