Coverage for mlprodict/onnxrt/ops_cpu/op_label_encoder.py: 94%
53 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class LabelEncoder(OpRun):
    # Runtime for operator LabelEncoder: maps every element of the input
    # through a key -> value dictionary built from the node attributes and
    # falls back to a default value for keys which are not in the dictionary.
    # Keys and values can independently be floats, int64s or strings.

    # Expected attributes and their default values.
    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the mapping used by the runtime.

        The first (keys, values) pair of attributes for which both arrays
        are non empty defines the mapping. It is stored in ``self.classes_``,
        the value returned for unknown keys in ``self.default_`` and the
        dtype of the output in ``self.dtype_``.

        :param onnx_node: ONNX node, forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises RuntimeError: if no pair of keys and values is defined
            or if the resulting mapping is empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)
        # NOTE(review): attributes keys_*, values_*, default_* are presumably
        # set on the instance by OpRun.__init__ from expected_attributes.

        # Strings are stored in arrays of python objects.
        string_dtype = numpy.dtype(object)

        # (keys, values, default, output dtype), tried in this order.
        # The default always follows the type of the values.
        encodings = (
            ('keys_floats', 'values_floats', 'default_float', numpy.float32),
            ('keys_floats', 'values_int64s', 'default_int64', numpy.int64),
            ('keys_int64s', 'values_int64s', 'default_int64', numpy.int64),
            ('keys_int64s', 'values_floats', 'default_float', numpy.float32),
            ('keys_strings', 'values_floats', 'default_float', numpy.float32),
            ('keys_strings', 'values_int64s', 'default_int64', numpy.int64),
            ('keys_strings', 'values_strings', 'default_string', string_dtype),
            ('keys_floats', 'values_strings', 'default_string', string_dtype),
            ('keys_int64s', 'values_strings', 'default_string', string_dtype),
        )

        for keys_name, values_name, default_name, dtype in encodings:
            if (len(getattr(self, keys_name)) == 0 or
                    len(getattr(self, values_name)) == 0):
                continue
            # String attributes are decoded with utf-8 before
            # being used as keys or values of the mapping.
            decode_keys = keys_name == 'keys_strings'
            decode_values = values_name == 'values_strings'
            self.classes_ = {
                (k.decode('utf-8') if decode_keys else k):
                (v.decode('utf-8') if decode_values else v)
                for k, v in zip(getattr(self, keys_name),
                                getattr(self, values_name))}
            self.default_ = getattr(self, default_name)
            self.dtype_ = dtype
            break
        else:
            if hasattr(self, 'classes_strings'):
                raise RuntimeError(  # pragma: no cover
                    "This runtime does not implement version 1 of "
                    "operator LabelEncoder.")
            raise RuntimeError(
                f"No encoding was defined in {onnx_node}.")

        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Encodes every element of *x* with the mapping built by the
        constructor, unknown keys are replaced by ``self.default_``.

        :param x: array of keys, arrays with more than one dimension
            are squeezed first
        :param attributes: unused
        :param verbose: unused
        :param fLOG: unused
        :return: tuple with one single array of dtype ``self.dtype_``
            and as many elements as the first dimension of the
            (squeezed) input
        """
        if len(x.shape) > 1:
            x = numpy.squeeze(x)
        res = numpy.empty((x.shape[0], ), dtype=self.dtype_)
        for i in range(0, res.shape[0]):
            res[i] = self.classes_.get(x[i], self.default_)
        return (res, )