Coverage for mlprodict/onnxrt/ops_cpu/op_fft2d.py: 91%

34 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from numpy.fft import fft2 

9from ._op import OpRun 

10from ._new_ops import OperatorSchema 

11 

12 

class FFT2D(OpRun):
    """
    Runtime for the custom operator *FFT2D*: computes a two-dimensional
    discrete Fourier transform with :func:`numpy.fft.fft2`.

    Attribute *axes* selects the two axes the transform is applied to.
    It is declared in ``atts`` with value ``[-2, -1]`` (presumably the
    default when the node does not set it - depends on @see cl OpRun).
    """

    atts = {'axes': [-2, -1]}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FFT2D.atts,
                       **options)
        if self.axes is not None:
            # Stored as a tuple so it can be forwarded as is to fft2;
            # a 2D transform needs exactly two axes.
            self.axes = tuple(self.axes)
            if len(self.axes) != 2:
                raise ValueError(  # pragma: no cover
                    f"axes must be a set of 2 integers not {self.axes!r}.")

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator name, only ``'FFT2D'`` is supported
        :return: instance of @see cl FFT2DSchema
        :raises RuntimeError: for any other operator name
        """
        if op_name == "FFT2D":
            return FFT2DSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, a, fft_length=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the 2D FFT of *a* along ``self.axes``.

        :param a: input array, float32, float64, complex64 or complex128
        :param fft_length: optional lengths of the transform along the
            two axes, forwarded as the ``s`` argument of fft2
        :return: 1-tuple holding the complex result, complex64 for
            float32/complex64 inputs, complex128 for float64/complex128
        :raises TypeError: for any other input dtype
        """
        # Validate the dtype before doing the (possibly expensive)
        # transform rather than after it.
        if a.dtype in (numpy.float32, numpy.complex64):
            out_dtype = numpy.complex64
        elif a.dtype in (numpy.float64, numpy.complex128):
            out_dtype = numpy.complex128
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected input type: {a.dtype!r}.")
        if fft_length is None:
            y = fft2(a, axes=self.axes)
        else:
            y = fft2(a, tuple(fft_length), axes=self.axes)
        # fft2 returns a fresh array, so no defensive copy is needed when
        # it already has the requested dtype.
        return (y.astype(out_dtype, copy=False), )

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator.

        :param inputs: names of the inputs, the first one is the signal,
            the optional second one holds the lengths of the transform
        :return: tuple (import statement, code)

        NOTE(review): unlike :meth:`_run`, the generated code does not
        cast the result to complex64 for float32 inputs.
        """
        axes = None if self.axes is None else tuple(self.axes)
        if len(inputs) == 1:
            return ('from numpy.fft import fft2',
                    f"return fft2({inputs[0]}, axes={axes})")
        return ('from numpy.fft import fft2',
                f"return fft2({inputs[0]}, tuple({inputs[1]}), axes={axes})")

55 

56 

class FFT2DSchema(OperatorSchema):
    """
    Schema of the custom operator @see cl FFT2D implemented
    in this package. It exposes the same attributes as the runtime.
    """

    def __init__(self):
        super().__init__('FFT2D')
        # Shares the attribute declaration with the runtime class.
        self.attributes = FFT2D.atts