Coverage for mlprodict/onnx_conv/onnx_ops/onnx_fft.py: 96%
47 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Custom operators for FFT.
4"""
5import numpy
6from skl2onnx.algebra.onnx_operator import OnnxOperator
class OnnxFFT_1(OnnxOperator):
    """
    Custom operator ``FFT`` (domain ``mlprodict``, version 1).

    It accepts an input ``A`` plus an optional ``fft_length``
    (int64) and returns a single output ``FFT`` sharing the
    type constraint ``T`` of ``A``.
    """

    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    expected_outputs = [('FFT', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'FFT'
    past_version = {}

    def __init__(self, *args, axis=-1,
                 op_version=None, **kwargs):
        """
        :param args: positional inputs, ``A`` (array or
            OnnxOperatorMixin) optionally followed by ``fft_length``
            (array or OnnxOperatorMixin)
        :param axis: value of the ``axis`` attribute; a tuple is
            forwarded as a list
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # Tuples are forwarded as lists, presumably because the
        # attribute serializer expects lists -- TODO confirm.
        axis_attr = list(axis) if isinstance(axis, tuple) else axis
        OnnxOperator.__init__(
            self, *args, axis=axis_attr,
            op_version=op_version, **kwargs)
class OnnxFFT2D_1(OnnxOperator):
    """
    Custom operator ``FFT2D`` (domain ``mlprodict``, version 1).

    It accepts an input ``A`` plus an optional ``fft_length``
    (int64) and returns a single output ``FFT2D`` sharing the
    type constraint ``T`` of ``A``.
    """

    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    expected_outputs = [('FFT2D', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'FFT2D'
    past_version = {}

    def __init__(self, *args, axes=(-2, -1),
                 op_version=None, **kwargs):
        """
        :param args: positional inputs, ``A`` (array or
            OnnxOperatorMixin) optionally followed by ``fft_length``
            (array or OnnxOperatorMixin)
        :param axes: value of the ``axes`` attribute, ``(-2, -1)`` by
            default; a tuple is forwarded as a list
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # The default (and any caller-supplied) tuple is turned into
        # a list before reaching OnnxOperator, presumably because the
        # attribute serializer expects lists -- TODO confirm.
        axes_attr = list(axes) if isinstance(axes, tuple) else axes
        OnnxOperator.__init__(
            self, *args, axes=axes_attr,
            op_version=op_version, **kwargs)
class OnnxRFFT_1(OnnxOperator):
    """
    Defines a custom operator for RFFT (domain ``mlprodict``,
    version 1).

    It accepts an input ``A`` plus an optional ``fft_length``
    (int64) and returns a single output ``RFFT``.
    """

    since_version = 1
    expected_inputs = [('A', 'T'), ('fft_length', numpy.int64)]
    # NOTE(review): the output shares the type constraint 'T' with the
    # input; confirm this is intended for this transform.
    expected_outputs = [('RFFT', 'T')]
    input_range = [1, 2]
    output_range = [1, 1]
    is_deprecated = False
    domain = 'mlprodict'
    operator_name = 'RFFT'
    past_version = {}

    def __init__(self, *args, axis=-1,
                 op_version=None, **kwargs):
        """
        :param args: positional inputs, ``A`` (array or
            OnnxOperatorMixin) optionally followed by ``fft_length``
            (array or OnnxOperatorMixin)
        :param axis: value of the ``axis`` attribute; a tuple is
            forwarded as a list
        :param op_version: opset version
        :param kwargs: additional parameters forwarded to
            :class:`OnnxOperator`
        """
        # Tuples are forwarded as lists, presumably because the
        # attribute serializer expects lists -- TODO confirm.
        if isinstance(axis, tuple):
            axis = list(axis)
        OnnxOperator.__init__(
            self, *args, axis=axis,
            op_version=op_version, **kwargs)
# Unversioned public names, each bound to version 1 of the
# corresponding operator class.
OnnxFFT, OnnxFFT2D, OnnxRFFT = OnnxFFT_1, OnnxFFT2D_1, OnnxRFFT_1