Coverage for mlprodict/onnxrt/ops_cpu/op_gather.py: 90%
20 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from .op_gather_ import ( # pylint: disable=E0611,E0401
10 GatherFloat, GatherDouble, GatherInt64)
13class Gather(OpRun):
15 atts = {'axis': 0}
17 def __init__(self, onnx_node, desc=None, **options):
18 OpRun.__init__(self, onnx_node, desc=desc,
19 expected_attributes=Gather.atts,
20 **options)
21 self.rt_ = {
22 'float32': GatherFloat(self.axis),
23 'float64': GatherDouble(self.axis),
24 'int64': GatherInt64(self.axis)}
26 def _run(self, x, indices, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221
27 if not x.flags['C_CONTIGUOUS']:
28 x = numpy.ascontiguousarray(x)
29 if not indices.flags['C_CONTIGUOUS']:
30 indices = indices.ascontiguousarray()
31 if indices.size == 0:
32 return (numpy.empty((0, ), dtype=x.dtype), )
33 try:
34 return (self.rt_[str(x.dtype)].compute(x, indices), )
35 except (KeyError, ValueError):
36 return (numpy.take(x, indices, axis=self.axis), )