Coverage for mlprodict/onnxrt/ops_cpu/op_concat_from_sequence.py: 88%
17 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
11def _concat_from_sequence(seq, axis, new_axis=0):
12 if new_axis == 1:
13 seq2 = [s[..., numpy.newaxis] for s in seq]
14 res = numpy.concatenate(seq2, axis=-1)
15 else:
16 res = numpy.concatenate(seq, axis=axis)
17 return res
20class ConcatFromSequence(OpRun):
22 atts = {'axis': 0, 'new_axis': 0}
24 def __init__(self, onnx_node, desc=None, **options):
25 OpRun.__init__(self, onnx_node, desc=desc,
26 expected_attributes=ConcatFromSequence.atts,
27 **options)
29 def _run(self, seq, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221
30 if seq is None:
31 raise RuntimeError( # pragma: no cover
32 "A sequence cannot be null.")
33 res = _concat_from_sequence(seq, self.axis, new_axis=self.new_axis)
34 return (res, )