Coverage for mlprodict/onnxrt/ops_cpu/op_one_hot.py: 100%
24 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
def _one_hot(indices, depth, axis=-1, dtype=numpy.float32):
    """
    Computes a one-hot encoding of *indices* following the ONNX
    *OneHot* semantics.

    :param indices: array-like of indices; non-integer values are cast
        to int64 (truncation toward zero), negative values count from
        the end (``-1`` means ``depth - 1``)
    :param depth: number of classes, a scalar or an array holding
        exactly one element; a non-integer value is cast to int
    :param axis: position of the new one-hot dimension in the output,
        negative values count from the end, valid range is
        ``[-rank - 1, rank]`` where *rank* is the rank of *indices*
    :param dtype: dtype of the returned array
    :return: array of shape ``indices.shape[:axis] + (depth,) +
        indices.shape[axis:]`` equal to 1 (True) where the position along
        *axis* matches the (normalized) index and 0 (False) elsewhere;
        indices outside ``[-depth, depth - 1]`` produce an all-zero slice
    :raises ValueError: if *depth* does not hold exactly one element or
        if *axis* is out of range
    """
    # ONNX casts non-integer indices to int64 before use.
    values = numpy.asarray(indices).astype(numpy.int64)

    # depth may be given as a scalar or as a one-element rank-1 tensor.
    depth_array = numpy.asarray(depth)
    if depth_array.size != 1:
        raise ValueError(
            "depth must be a scalar or hold exactly one element, "
            "got shape %r." % (depth_array.shape, ))
    depth = int(depth_array.reshape(-1)[0])

    rank = values.ndim
    new_axis = axis + rank + 1 if axis < 0 else axis
    if not 0 <= new_axis <= rank:
        raise ValueError(
            "axis=%r is out of range for indices of rank %d." % (axis, rank))

    # Negative indices count from the end. Anything still outside
    # [0, depth) after this shift matches no position below, which yields
    # the all-off slice required for out-of-range indices (numpy.mod
    # would wrap it onto a valid position instead).
    values = numpy.where(values < 0, values + depth, values)

    left = values.shape[:new_axis]
    right = values.shape[new_axis:]
    depth_range = numpy.arange(depth)
    # Broadcast the class positions along the new axis against the
    # indices reshaped with a singleton dimension at the same place.
    targets = depth_range.reshape(
        (1,) * len(left) + depth_range.shape + (1,) * len(right))
    values = values.reshape(left + (1,) + right)
    return numpy.asarray(targets == values, dtype=dtype)
class OneHot(OpRun):
    """
    Runtime for the ONNX operator *OneHot*.

    Attribute ``axis`` (default ``-1``) selects where the one-hot
    dimension is inserted in the output.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHot.atts,
                       **options)

    def _run(self, indices, depth, values, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the one-hot output.

        :param indices: indices tensor
        :param depth: number of classes (scalar or one-element tensor)
        :param values: two-element tensor ``[off_value, on_value]``
        :return: tuple holding one array of dtype ``values.dtype``
        """
        off_value, on_value = values
        # Forward the node's 'axis' attribute; assumes OpRun exposes the
        # attributes declared in ``atts`` as instance attributes
        # (self.axis) — confirm against OpRun.
        hot = _one_hot(indices, depth, axis=self.axis, dtype=numpy.bool_)
        # Selecting values instead of computing hot * (on - off) + off
        # keeps them exact (no cancellation for large floats, no nan for
        # infinite values) and also works for boolean values.
        y = numpy.where(hot, on_value, off_value).astype(
            values.dtype, copy=False)
        return (y, )