Coverage for mlprodict/onnxrt/ops_cpu/op_fft2d.py: 91%
34 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from numpy.fft import fft2
9from ._op import OpRun
10from ._new_ops import OperatorSchema
class FFT2D(OpRun):
    """
    Runtime for the custom operator *FFT2D*: a two-dimensional
    discrete Fourier transform computed with :func:`numpy.fft.fft2`
    over the two axes given by attribute *axes* (default ``[-2, -1]``).
    """

    atts = {'axes': [-2, -1]}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FFT2D.atts,
                       **options)
        if self.axes is not None:
            self.axes = tuple(self.axes)
            # A 2D transform needs exactly two axes.
            if len(self.axes) != 2:
                raise ValueError(  # pragma: no cover
                    f"axes must be a set of 2 integers not {self.axes!r}.")

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator name, only ``'FFT2D'`` is known
        :return: a @see cl FFT2DSchema instance
        :raises RuntimeError: for any other operator name
        """
        if op_name == "FFT2D":
            return FFT2DSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, a, fft_length=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the 2D FFT of *a* along ``self.axes``.

        :param a: input array, float32, float64, complex64 or complex128
        :param fft_length: optional output lengths along ``self.axes``,
            forwarded to :func:`numpy.fft.fft2` as its *s* argument
        :return: 1-tuple holding the transform, complex64 for
            float32/complex64 inputs, complex128 for float64/complex128
            inputs
        :raises TypeError: for any other input dtype
        """
        # Resolve the output type first so unsupported inputs are
        # rejected before the (possibly expensive) transform runs.
        if a.dtype in (numpy.float32, numpy.complex64):
            out_dtype = numpy.complex64
        elif a.dtype in (numpy.float64, numpy.complex128):
            out_dtype = numpy.complex128
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected input type: {a.dtype!r}.")

        if fft_length is None:
            y = fft2(a, axes=self.axes)
        else:
            y = fft2(a, tuple(fft_length), axes=self.axes)
        # copy=False: skip the copy when fft2 already produced out_dtype.
        return (y.astype(out_dtype, copy=False), )

    def to_python(self, inputs):
        """
        Returns a pair *(import statement, code line)* rendering this
        node as python code.

        :param inputs: names of the inputs, the second one (if any) is
            the FFT length
        :return: tuple of two strings

        NOTE(review): unlike :meth:`_run`, the generated code does not
        cast the result to complex64 for float32 inputs — confirm
        whether callers rely on the dtype of the generated code.
        """
        axes = tuple(self.axes) if self.axes is not None else None
        if len(inputs) == 1:
            return ('from numpy.fft import fft2',
                    f"return fft2({inputs[0]}, axes={axes})")
        return ('from numpy.fft import fft2',
                f"return fft2({inputs[0]}, tuple({inputs[1]}), axes={axes})")
class FFT2DSchema(OperatorSchema):
    """
    Schema describing the custom operator @see cl FFT2D,
    one of the operators added by this package.
    """

    def __init__(self):
        name = 'FFT2D'
        OperatorSchema.__init__(self, name)
        # Attribute defaults are shared with the runtime class:
        # this is the very same dict object, not a copy.
        self.attributes = FFT2D.atts