Coverage for mlprodict/onnxrt/ops_cpu/op_fft.py: 93%
29 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from numpy.fft import fft
9from ._op import OpRun
10from ._new_ops import OperatorSchema
class FFT(OpRun):
    """
    Runtime for the custom ``FFT`` operator: a one-dimensional discrete
    Fourier transform along attribute ``axis``, computed with
    :func:`numpy.fft.fft`.

    The first input is the signal, the optional second input
    ``fft_length`` holds the transform length in its first element.
    float32/complex64 inputs produce complex64 outputs,
    float64/complex128 inputs produce complex128 outputs, any other
    input type raises :class:`TypeError`.
    """

    # Attributes expected by the operator with their default values;
    # ``axis=-1`` is forwarded to numpy.fft.fft (last dimension).
    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FFT.atts,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.

        :param op_name: operator type, only ``"FFT"`` is supported
        :return: an instance of @see cl FFTSchema
        :raises RuntimeError: if *op_name* is not ``"FFT"``
        """
        if op_name == "FFT":
            return FFTSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, a, fft_length=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the FFT of *a* along ``self.axis``.

        :param a: input array (float32, float64, complex64 or complex128)
        :param fft_length: optional tensor whose first element is the
            transform length (the input is cropped or zero-padded by
            numpy.fft.fft); a scalar or 0-d tensor is accepted as well
        :return: one-element tuple holding the complex result
        :raises TypeError: if *a* has an unsupported dtype
        """
        # Resolve the output type before computing anything so that an
        # unsupported input fails fast instead of after a full transform.
        if a.dtype in (numpy.float32, numpy.complex64):
            out_dtype = numpy.complex64
        elif a.dtype in (numpy.float64, numpy.complex128):
            out_dtype = numpy.complex128
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected input type: {a.dtype!r}.")

        if fft_length is None:
            y = fft(a, axis=self.axis)
        else:
            # Only the first element is used; flattening first also
            # accepts a 0-d tensor or plain scalar (``fft_length[0]``
            # failed on those).
            length = numpy.asarray(fft_length).reshape(-1)[0]
            y = fft(a, length, axis=self.axis)
        # copy=False avoids a useless copy when numpy already returned the
        # requested complex type (e.g. complex128 for a float64 input).
        return (y.astype(out_dtype, copy=False), )

    def to_python(self, inputs):
        """
        Returns Python code equivalent to this operator.

        :param inputs: names of the inputs; when a second name is given,
            it is ``fft_length`` and its first element is the length
        :return: tuple ``(import statement, function body)``
        """
        # NOTE(review): unlike ``_run``, the generated code does not cast
        # the result, so a float32 input may yield complex128 here.
        if len(inputs) == 1:
            return ('from numpy.fft import fft',
                    f"return fft({inputs[0]}, axis={self.axis})")
        return ('from numpy.fft import fft',
                f"return fft({inputs[0]}, {inputs[1]}[0], axis={self.axis})")
class FFTSchema(OperatorSchema):
    """
    Operator schema describing the custom operator ``FFT``
    implemented in this package, see @see cl FFT.
    """

    def __init__(self):
        super().__init__('FFT')
        # Reuses the attribute dictionary declared on FFT
        # (the very same object, not a copy).
        self.attributes = FFT.atts