Coverage for mlprodict/onnxrt/ops_cpu/op_dft.py: 82%

71 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

def _fft(x, fft_length, axis):
    """
    Applies :epkg:`numpy:fft:fft` along one axis and returns the result
    as a real array whose last dimension holds the real part (index 0)
    and the imaginary part (index 1).

    @param      x               array to transform
    @param      fft_length      sequence whose first element is the
                                transform length, None to use the
                                size of *x* along *axis*
    @param      axis            axis to transform
    @return                     array of shape ``transformed.shape + (2, )``
    """
    if fft_length is None:
        fft_length = [x.shape[axis]]
    spectrum = numpy.fft.fft(x, fft_length[0], axis=axis)
    # Real and imaginary parts are stacked on a new leading axis
    # which is then moved to the last position.
    parts = numpy.stack([numpy.real(spectrum), numpy.imag(spectrum)])
    tr = numpy.moveaxis(parts, 0, -1)
    if tr.shape[-1] != 2:
        raise RuntimeError(
            f"Unexpected shape {tr.shape}, x.shape={x.shape} "
            f"fft_length={fft_length}.")
    return tr

27 

28 

def _cfft(x, fft_length, axis, onesided=False, normalize=False):
    """
    Computes the discrete Fourier transform of an array whose last
    dimension encodes a real signal (size 1) or a complex signal
    (size 2: real part, imaginary part).

    @param      x               array, last dimension of size 1 or 2
    @param      fft_length      sequence whose first element is the
                                transform length, None to use the
                                size of the transformed axis
    @param      axis            axis to transform, expressed on the array
                                obtained once the last dimension is
                                removed; a negative value counts from the
                                end of that array
    @param      onesided        if True, only the first ``n // 2 + 1``
                                coefficients along *axis* are kept
    @param      normalize       unused, kept for interface compatibility
    @return                     real array, last dimension of size 2
                                (real part, imaginary part)
    """
    if x.shape[-1] == 1:
        tmp = x
    else:
        # Even positions of the last dimension hold the real part,
        # odd positions hold the imaginary part.
        tmp = x[..., 0::2] + 1j * x[..., 1::2]
    c = numpy.squeeze(tmp, -1)
    res = _fft(c, fft_length, axis=axis)
    if onesided:
        # *res* has one more trailing dimension than *c*: a negative axis
        # must be resolved against *c* so that the truncated axis is the
        # one which was transformed.
        # NOTE(review): callers following the ONNX convention (negative
        # axis counted on the input including its last dimension) should
        # convert the axis before calling - confirm.
        pos_axis = axis if axis >= 0 else axis + c.ndim
        slices = [slice(None)] * res.ndim
        slices[pos_axis] = slice(0, res.shape[pos_axis] // 2 + 1)
        return res[tuple(slices)]
    return res

48 

49 

def _ifft(x, fft_length, axis=-1, onesided=False):
    """
    Applies :epkg:`numpy:fft:ifft` along one axis and returns the result
    as a real array whose last dimension holds the real part (index 0)
    and the imaginary part (index 1).

    @param      x               array to transform
    @param      fft_length      sequence whose first element is the
                                transform length, None to use the
                                size of *x* along *axis*
    @param      axis            axis of *x* to transform, a negative
                                value counts from the end of *x*
    @param      onesided        if True, only the first ``n // 2 + 1``
                                values along *axis* are kept
    @return                     array of shape ``transformed.shape + (2, )``,
                                truncated along *axis* if *onesided*
    """
    if fft_length is None:
        fft_length = [x.shape[axis]]
    ft = numpy.fft.ifft(x, fft_length[0], axis=axis)
    tr = numpy.stack([numpy.real(ft), numpy.imag(ft)], axis=-1)
    if onesided:
        # *tr* has one more trailing dimension than *ft*: a negative axis
        # must be resolved against *ft*, otherwise the real/imaginary
        # dimension would be selected instead of the transformed one.
        pos_axis = axis if axis >= 0 else axis + ft.ndim
        slices = [slice(None)] * tr.ndim
        slices[pos_axis] = slice(0, tr.shape[pos_axis] // 2 + 1)
        return tr[tuple(slices)]
    return tr

68 

69 

def _cifft(x, fft_length, axis=-1, onesided=False):
    """
    Computes the inverse discrete Fourier transform of an array whose
    last dimension has size 1 (real values) or holds interleaved real
    and imaginary parts.

    @param      x               array, last dimension encodes the values
    @param      fft_length      sequence whose first element is the
                                transform length
    @param      axis            axis given to @see fn _ifft
    @param      onesided        forwarded to @see fn _ifft
    @return                     output of @see fn _ifft
    """
    if x.shape[-1] == 1:
        values = x
    else:
        # Even positions of the last dimension are the real part,
        # odd positions the imaginary part.
        real_part = x[..., 0::2]
        imag_part = x[..., 1::2]
        values = real_part + 1j * imag_part
    squeezed = numpy.squeeze(values, -1)
    return _ifft(squeezed, fft_length, axis=axis, onesided=onesided)

82 

83 

class DFT(OpRun):
    """
    Runtime for operator *DFT*: runs the direct or the inverse
    transform depending on attribute *inverse*.
    """

    # Expected attributes and their default values.
    atts = {'axis': 1, 'inverse': 0, 'onesided': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=DFT.atts, **options)

    def _run(self, x, dft_length=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the transform of *x*.

        @param      x           input array
        @param      dft_length  transform length, None to use the size
                                of *x* along attribute *axis*
        @return                 tuple with one array cast to the type of *x*
        """
        if dft_length is None:
            dft_length = numpy.array(
                [x.shape[self.axis]], dtype=numpy.int64)
        # Attribute *inverse* selects the inverse transform.
        transform = _cifft if self.inverse else _cfft
        res = transform(
            x, dft_length, axis=self.axis, onesided=self.onesided)
        return (res.astype(x.dtype), )