Coverage for mlprodict/onnxrt/ops_cpu/op_roi_align.py: 78%
23 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_roi_align_ import RoiAlignFloat, RoiAlignDouble # pylint: disable=E0611
class RoiAlign(OpRun):
    """
    Runtime operator *RoiAlign*.

    The computation is delegated to a compiled helper:
    :class:`RoiAlignFloat` for float32 inputs and
    :class:`RoiAlignDouble` for float64 inputs. Each helper is created
    and initialised with the node attributes the first time an input of
    its dtype is seen. It is then cached on the instance and reused.
    """

    atts = {'coordinate_transformation_mode': b'half_pixel',
            'mode': b'avg',
            'output_height': 1,
            'output_width': 1,
            'sampling_ratio': 0,
            'spatial_scale': 1.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=RoiAlign.atts,
                       **options)
        # Lazily created compiled runtimes, one per supported dtype
        # (float32 -> rt32_, float64 -> rt64_).
        self.rt32_ = None
        self.rt64_ = None

    def _make_runtime(self, cls):
        """
        Instantiates *cls* and initialises it with this node's attributes.

        :param cls: compiled runtime class (:class:`RoiAlignFloat` or
            :class:`RoiAlignDouble`)
        :return: the initialised runtime instance
        """
        rt = cls()
        # String attributes are stored as bytes and the compiled
        # runtime receives them decoded as ASCII.
        rt.init(
            self.coordinate_transformation_mode.decode('ascii'),
            self.mode.decode('ascii'), self.output_height,
            self.output_width, self.sampling_ratio, self.spatial_scale)
        return rt

    def _run(self, X, rois, batch_indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes RoiAlign on *X* for the regions given by *rois*.

        :param X: input array; only float32 and float64 are supported
        :param rois: region definitions, passed unchanged to the
            compiled runtime
        :param batch_indices: batch index for each region, passed
            unchanged to the compiled runtime
        :param attributes: unused, kept for compatibility with
            :meth:`OpRun._run`
        :param verbose: unused
        :param fLOG: unused
        :return: a 1-tuple holding the result of the compiled runtime
        :raises TypeError: if *X* is neither float32 nor float64
        """
        if X.dtype == numpy.float32:
            # Cache the runtime only after init succeeded. A failing
            # init must not leave an uninitialised runtime behind for
            # later calls.
            if self.rt32_ is None:
                self.rt32_ = self._make_runtime(RoiAlignFloat)
            rt = self.rt32_
        elif X.dtype == numpy.float64:
            if self.rt64_ is None:
                self.rt64_ = self._make_runtime(RoiAlignDouble)
            rt = self.rt64_
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type {X.dtype!r} for X.")

        res = rt.compute(X, rois, batch_indices)
        return (res, )