Coverage for mlprodict/onnxrt/ops_cpu/op_grid_sample.py: 74%
27 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_grid_sample_ import GridSampleFloat, GridSampleDouble # pylint: disable=E0611
class GridSample(OpRun):
    """
    Runtime operator *GridSample*.

    The computation is delegated to the runtime objects imported from
    ``op_grid_sample_``: ``GridSampleFloat`` when the input is float32
    and ``GridSampleDouble`` when it is float64. Both are configured
    with the operator attributes ``align_corners``, ``mode`` and
    ``padding_mode``.
    """

    # Expected attributes and their default values, handed to OpRun.
    atts = {'align_corners': 0,
            'mode': b'bilinear',
            'padding_mode': b'zeros'}

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to :class:`OpRun`
        :param desc: forwarded to :class:`OpRun`
        :param options: additional options forwarded to :class:`OpRun`
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GridSample.atts,
                       **options)
        # NOTE(review): self.align_corners, self.mode and self.padding_mode
        # are presumably set by OpRun.__init__ from expected_attributes.
        self.rt32_ = self._new_runtime(GridSampleFloat)
        self.rt64_ = self._new_runtime(GridSampleDouble)

    def _new_runtime(self, runtime_cls):
        """
        Creates a runtime object and initializes it with the
        attributes of this operator.

        :param runtime_cls: ``GridSampleFloat`` or ``GridSampleDouble``
        :return: the initialized runtime object
        """
        rt = runtime_cls()
        rt.init(self.align_corners, self.mode, self.padding_mode)
        return rt

    def _run(self, X, grid, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Picks the runtime matching ``X.dtype`` and runs it.

        :param X: input array, its dtype must be float32 or float64
        :param grid: second input, passed to the runtime unchanged
        :param attributes: unused in this method
        :param verbose: unused in this method
        :param fLOG: unused in this method
        :return: tuple with one element, the result of ``compute``
        :raises TypeError: if ``X.dtype`` is neither float32 nor float64
        """
        if X.dtype == numpy.float32:
            # The runtime is rebuilt if it was reset to None.
            if self.rt32_ is None:
                self.rt32_ = self._new_runtime(GridSampleFloat)
            rt = self.rt32_
        elif X.dtype == numpy.float64:
            # This test used to compare against float32 a second time,
            # which made the double precision branch unreachable.
            if self.rt64_ is None:
                self.rt64_ = self._new_runtime(GridSampleDouble)
            rt = self.rt64_
        else:
            raise TypeError(  # pragma: no cover
                f"Unsupported type {X.dtype!r} for GridSample.")

        res = rt.compute(X, grid)
        return (res, )