Coverage for mlprodict/onnx_conv/onnx_ops/onnx_fft.py: 96%

47 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Custom operators for FFT. 

4""" 

5import numpy 

6from skl2onnx.algebra.onnx_operator import OnnxOperator 

7 

8 

class OnnxFFT_1(OnnxOperator):
    """
    Custom ONNX operator *FFT* registered in domain ``mlprodict``
    (version 1).
    """

    # Operator schema: one mandatory input 'A' and an optional
    # 'fft_length' (input_range allows 1 or 2 inputs), one output 'FFT'.
    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    expected_outputs = [('FFT', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'FFT'
    past_version = {}

    def __init__(self, *args, axis=-1,
                 op_version=None, **kwargs):
        """
        :param A: array or OnnxOperatorMixin (first positional argument)
        :param fft_length: (optional) array or OnnxOperatorMixin, given
            as a second positional argument
        :param axis: axis the transform applies to, forwarded as the
            ``axis`` attribute
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # Tuples are forwarded as lists; any other value is passed as is.
        forwarded_axis = list(axis) if isinstance(axis, tuple) else axis
        OnnxOperator.__init__(
            self, *args, axis=forwarded_axis,
            op_version=op_version, **kwargs)

38 

39 

class OnnxFFT2D_1(OnnxOperator):
    """
    Custom ONNX operator *FFT2D* registered in domain ``mlprodict``
    (version 1).
    """

    # Operator schema: one mandatory input 'A' and an optional
    # 'fft_length' (input_range allows 1 or 2 inputs), one output 'FFT2D'.
    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    expected_outputs = [('FFT2D', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'FFT2D'
    past_version = {}

    def __init__(self, *args, axes=(-2, -1),
                 op_version=None, **kwargs):
        """
        :param A: array or OnnxOperatorMixin (first positional argument)
        :param fft_length: (optional) array or OnnxOperatorMixin, given
            as a second positional argument
        :param axes: axes the transform applies to, forwarded as the
            ``axes`` attribute; defaults to the last two axes
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # The default (and any tuple) is forwarded as a list.
        forwarded_axes = list(axes) if isinstance(axes, tuple) else axes
        OnnxOperator.__init__(
            self, *args, axes=forwarded_axes,
            op_version=op_version, **kwargs)

69 

70 

class OnnxRFFT_1(OnnxOperator):
    """
    Defines a custom operator for RFFT (registered as ``RFFT`` in
    domain ``mlprodict``, version 1).
    """

    # Operator schema: one mandatory input 'A' and an optional
    # 'fft_length' (input_range allows 1 or 2 inputs), one output 'RFFT'.
    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    expected_outputs = [('RFFT', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'RFFT'
    past_version = {}

    def __init__(self, *args, axis=-1,
                 op_version=None, **kwargs):
        """
        :param A: array or OnnxOperatorMixin (first positional argument)
        :param fft_length: (optional) array or OnnxOperatorMixin, given
            as a second positional argument
        :param axis: axis the transform applies to, forwarded as the
            ``axis`` attribute
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # Tuples are forwarded as lists; any other value is passed as is.
        if isinstance(axis, tuple):
            axis = list(axis)
        OnnxOperator.__init__(
            self, *args, axis=axis,
            op_version=op_version, **kwargs)

100 

101 

# Unversioned aliases bound to the version-1 operator classes.
OnnxFFT, OnnxFFT2D, OnnxRFFT = OnnxFFT_1, OnnxFFT2D_1, OnnxRFFT_1