Coverage for mlprodict/onnxrt/ops_cpu/op_one_hot_encoder.py: 72%
36 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class OneHotEncoder(OpRun):
    """
    Runtime for the :epkg:`ONNX` operator *OneHotEncoder*.

    Every element of the input is mapped to a vector holding one slot
    per known category: the slot of the matching category is set
    to 1, every other slot is 0. The output shape is therefore the
    input shape extended with one trailing axis whose length is the
    number of categories.

    The :epkg:`ONNX` specifications do not mention
    the possibility to change the output type
    (sparse, dense, float, double): this runtime always returns
    a dense ``float32`` array.
    """

    # Expected attributes and their default values.
    # * cats_int64s: integer categories (used first if not empty)
    # * cats_strings: string categories (used if cats_int64s is empty)
    # * zeros: if false, an element matching no category raises
    #   an exception instead of being encoded as a null vector
    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'zeros': 1,
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds ``classes_``, the mapping *category -> column index*.

        :param onnx_node: forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises RuntimeError: if both *cats_int64s* and
            *cats_strings* are empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHotEncoder.atts,
                       **options)
        if len(self.cats_int64s) > 0:
            self.classes_ = {v: i for i, v in enumerate(self.cats_int64s)}
        elif len(self.cats_strings) > 0:
            # Categories may be stored as bytes or already be decoded
            # strings: only decode the values which can be decoded.
            self.classes_ = {
                (v.decode('utf-8') if hasattr(v, 'decode') else v): i
                for i, v in enumerate(self.cats_strings)}
        else:
            raise RuntimeError("No encoding was defined.")  # pragma: no cover

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        One-hot encodes every element of *x*.

        :param x: array of any shape, its elements are looked up
            in ``classes_``
        :param attributes: not used by this implementation
        :param verbose: not used by this implementation
        :param fLOG: not used by this implementation
        :return: tuple with a single ``float32`` array of shape
            ``x.shape + (number of categories, )``
        :raises RuntimeError: if attribute *zeros* is false and one
            element of *x* does not belong to any known category
        """
        n_classes = len(self.classes_)
        res = numpy.zeros(x.shape + (n_classes, ), dtype=numpy.float32)

        # *res* was just allocated, it is contiguous and this reshape
        # returns a view on it: one row per element of *x*, whatever
        # the number of dimensions of *x* is.
        flat = res.reshape((x.size, n_classes))
        find = self.classes_.get
        for i, v in enumerate(x.ravel()):
            j = find(v, -1)
            if j >= 0:
                flat[i, j] = 1.

        if not self.zeros:
            red = res.sum(axis=-1)
            # numpy.min fails on an empty array: nothing to check then.
            if red.size > 0 and numpy.min(red) == 0:
                rows = []
                for idx in numpy.ndindex(*red.shape):
                    if red[idx] == 0:
                        # Plain integer for 1D inputs, full index otherwise.
                        row = idx[0] if len(idx) == 1 else idx
                        rows.append(dict(row=row, value=x[idx]))
                        if len(rows) > 5:
                            break
                # A 0-d array cannot be sliced.
                head = x[:5] if x.ndim > 0 else x
                raise RuntimeError(  # pragma no cover
                    "One observation did not have any defined category.\n"
                    "classes: {}\nfirst rows:\n{}\nres:\n{}\nx:\n{}".format(
                        self.classes_, "\n".join(str(_) for _ in rows),
                        res[:5], head))

        return (res, )