Coverage for mlprodict/onnxrt/ops_cpu/op_non_max_suppression.py: 85%

20 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from .op_non_max_suppression_ import RuntimeNonMaxSuppression # pylint: disable=E0611 

10 

11 

class NonMaxSuppression(OpRun):
    """
    Runtime for the ONNX *NonMaxSuppression* operator.

    The actual computation is delegated to the compiled
    ``RuntimeNonMaxSuppression`` object built in the constructor; this class
    only adapts the operator inputs and output to it.
    """

    # Single attribute with its default value. Per the ONNX specification,
    # 0 means boxes are given as corners and 1 means center/width/height.
    atts = {'center_point_box': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=NonMaxSuppression.atts,
                       **options)
        # The compiled runtime receives the box-encoding attribute once,
        # at construction time, rather than on every call.
        runtime = RuntimeNonMaxSuppression()
        self.inst = runtime
        runtime.init(self.center_point_box)

    def _run(self, boxes, scores, max_output_boxes_per_class=None,  # pylint: disable=W0221
             iou_threshold=None, score_threshold=None,
             attributes=None, verbose=0, fLOG=None):
        """
        Run the suppression and return the selected indices.

        Optional inputs that are missing (``None``) are forwarded as empty
        arrays (int64 for the box count, float32 for both thresholds).
        Presumably the compiled runtime interprets an empty array as
        "use the operator default" — confirm against the C++ side.
        ``attributes``, ``verbose`` and ``fLOG`` are accepted for interface
        compatibility and are not used here.

        :return: a one-element tuple holding the runtime result reshaped to
            rows of three indices (per the ONNX specification:
            batch index, class index, box index)
        """
        max_boxes = (numpy.array([], dtype=numpy.int64)
                     if max_output_boxes_per_class is None
                     else max_output_boxes_per_class)
        iou = (numpy.array([], dtype=numpy.float32)
               if iou_threshold is None
               else iou_threshold)
        min_score = (numpy.array([], dtype=numpy.float32)
                     if score_threshold is None
                     else score_threshold)
        selected = self.inst.compute(boxes, scores, max_boxes, iou, min_score)
        return (selected.reshape((-1, 3)), )