Coverage for mlprodict/onnxrt/ops_cpu/op_array_feature_extractor.py: 100%
41 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401
10 array_feature_extractor_double,
11 array_feature_extractor_int64,
12 array_feature_extractor_float)
def _array_feature_extrator(data, indices):
    """
    Implementation of operator *ArrayFeatureExtractor*
    with :epkg:`numpy`, used for element types without a C++ kernel.

    :param data: array indexed along its last axis
    :param indices: array of positions along the last axis of *data*;
        whatever its shape, it is flattened into a single list
    :return: ``data[..., positions]`` reshaped so that its last axis
        holds every selected element; a one-dimensional *data*
        gives a ``(1, k)`` matrix where *k* is the number of indices
    """
    # Row vector, plain vector or any other shape: every layout of
    # *indices* boils down to the same flat list of positions.
    positions = indices.ravel().tolist()
    count = indices.size

    # A vector input still produces a matrix with a single row.
    if data.ndim == 1:
        target_shape = (1, count)
    else:
        target_shape = tuple(data.shape[:-1]) + (count,)

    selected = data[..., positions]
    return selected.reshape(target_shape)
def sizeof_dtype(dty):
    """
    Returns the size in bytes of one element of a supported type.

    :param dty: type to look up; it is compared with ``==`` against
        ``numpy.float64``, ``numpy.float32`` and ``numpy.int64``, so a
        :epkg:`numpy` scalar type or a ``numpy.dtype`` both work
    :return: 8 for float64 and int64, 4 for float32
    :raises ValueError: for any other type
    """
    if dty == numpy.float64:
        return 8
    if dty == numpy.float32:
        return 4
    if dty == numpy.int64:
        return 8
    # Interpolate the received value (not the ``numpy.dtype`` class)
    # so the error actually names the unsupported type.
    raise ValueError(
        f"Unable to get bytes size for type {dty!r}.")
class ArrayFeatureExtractor(OpRun):
    """
    Runtime for operator *ArrayFeatureExtractor*. The computation is
    dispatched on the element type of the first input: a C++ kernel
    for float64, float32 and int64, a :epkg:`numpy` fallback otherwise.
    """

    def __init__(self, onnx_node, desc=None, **options):
        # No operator-specific state: everything is handled by the base class.
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *ArrayFeatureExtractor*.

        :param data: input array; its dtype selects the implementation
        :param indices: positions to extract, forwarded unchanged
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: a tuple holding the single output array

        .. warning::
            ONNX specifications may be imprecise in some cases.
            When the input data is a vector (one dimension),
            the output still has two dimensions, like a matrix
            with one row. The implementation follows what
            :epkg:`onnxruntime` does in
            `array_feature_extractor.cc
            <https://github.com/microsoft/onnxruntime/blob/master/
            onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc#L84>`_.
        """
        # Checked in order with ``==``, exactly like an if/elif chain.
        compiled_kernels = (
            (numpy.float64, array_feature_extractor_double),
            (numpy.float32, array_feature_extractor_float),
            (numpy.int64, array_feature_extractor_int64),
        )
        for expected_type, kernel in compiled_kernels:
            if data.dtype == expected_type:
                return (kernel(data, indices), )
        # Any other type (strings for instance) has no C++ kernel yet.
        return (_array_feature_extrator(data, indices), )