Coverage for mlprodict/onnxrt/ops_cpu/op_non_max_suppression.py: 85%
20 statements
Report generated by coverage.py v7.1.0 on 2023-02-04 02:28 +0100.
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_non_max_suppression_ import RuntimeNonMaxSuppression # pylint: disable=E0611
# Python runtime for the ONNX operator NonMaxSuppression. The box
# selection itself is delegated to the compiled RuntimeNonMaxSuppression
# extension; this wrapper only fills in the optional inputs and
# normalizes the shape of the result.
class NonMaxSuppression(OpRun):

    # Attributes handed to OpRun with their default values; OpRun
    # presumably exposes each one as an instance attribute (the
    # constructor reads ``self.center_point_box``) — confirm in OpRun.
    atts = {'center_point_box': 0}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and configures the compiled runtime.

        :param onnx_node: node definition, forwarded to OpRun
        :param desc: optional description, forwarded to OpRun
        :param options: extra keyword options, forwarded to OpRun
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=NonMaxSuppression.atts,
                       **options)
        # The runtime object is configured once with the
        # center_point_box attribute (per the ONNX spec it selects the
        # box coordinate layout) and then reused by every _run call.
        self.inst = RuntimeNonMaxSuppression()
        self.inst.init(self.center_point_box)

    def _run(self, boxes, scores, max_output_boxes_per_class=None,  # pylint: disable=W0221
             iou_threshold=None, score_threshold=None,
             attributes=None, verbose=0, fLOG=None):
        """
        Executes the operator through the compiled runtime.

        Optional inputs left as None are replaced with empty arrays
        (int64 for *max_output_boxes_per_class*, float32 for both
        thresholds) before calling the runtime. The runtime presumably
        treats an empty array as "input not provided" — confirm in the
        C++ implementation.

        :param boxes: ONNX input *boxes*, passed through unchanged
        :param scores: ONNX input *scores*, passed through unchanged
        :param max_output_boxes_per_class: optional ONNX input
        :param iou_threshold: optional ONNX input
        :param score_threshold: optional ONNX input
        :param attributes: not used by this operator
        :param verbose: not used by this operator
        :param fLOG: not used by this operator
        :return: one-element tuple holding the runtime result reshaped
            to ``(-1, 3)``, so that an empty selection still has
            three columns
        """
        max_boxes = (
            numpy.array([], dtype=numpy.int64)
            if max_output_boxes_per_class is None
            else max_output_boxes_per_class)
        iou = (
            numpy.array([], dtype=numpy.float32)
            if iou_threshold is None
            else iou_threshold)
        min_score = (
            numpy.array([], dtype=numpy.float32)
            if score_threshold is None
            else score_threshold)
        selected = self.inst.compute(
            boxes, scores, max_boxes, iou, min_score)
        # NOTE(review): per the ONNX spec each row should be
        # (batch_index, class_index, box_index); not checked here.
        return (selected.reshape((-1, 3)), )