Coverage for mlprodict/onnxrt/ops_cpu/op_one_hot.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

11def _one_hot(indices, depth, axis=-1, dtype=numpy.float32): 

12 values = numpy.asarray(indices) 

13 rank = len(values.shape) 

14 depth_range = numpy.arange(depth) 

15 if axis < 0: 

16 axis += (rank + 1) 

17 ls = values.shape[0:axis] 

18 rs = values.shape[axis:rank] 

19 new_shape = (1,) * len(ls) + depth_range.shape + (1,) * len(rs) 

20 targets = numpy.reshape(depth_range, new_shape) 

21 values = numpy.reshape(numpy.mod(values, depth), ls + (1,) + rs) 

22 return numpy.asarray(targets == values, dtype=dtype) 

23 

24 

class OneHot(OpRun):
    """
    Runtime for ONNX operator *OneHot*. The output has shape
    ``indices.shape`` with a dimension of size *depth* inserted at
    attribute *axis* (default ``-1``). It holds ``values[1]`` (on value)
    where the position along *axis* matches the index and ``values[0]``
    (off value) elsewhere.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHot.atts,
                       **options)

    def _run(self, indices, depth, values, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # values is expected to be a two-element array [off_value, on_value].
        off_value, on_value = values
        # The axis attribute must be forwarded; it was previously ignored,
        # so the one-hot dimension was always appended last.
        # NOTE(review): relies on OpRun exposing each entry of `atts` as an
        # instance attribute (self.axis).
        mask = _one_hot(indices, depth, axis=self.axis, dtype=numpy.bool_)
        # Selecting instead of computing mask * (on - off) + off keeps the
        # on/off values exact. It also avoids integer overflow and works
        # for bool or string values.
        y = numpy.where(mask, on_value, off_value).astype(
            values.dtype, copy=False)
        return (y, )