Coverage for mlprodict/onnxrt/ops_cpu/op_hardmax.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

class Hardmax(OpRunUnaryNum):
    """
    Runtime for the ONNX *Hardmax* operator.

    The output has the same shape and dtype as the input. It holds a 1 at
    the position of the maximum along ``axis`` and 0 everywhere else.
    """

    # Attributes accepted by this operator, mapped to their default values.
    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Hardmax.atts, **options)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # ``attributes``, ``verbose`` and ``fLOG`` are part of the runtime
        # signature but are not used by this operator.
        axis = self.axis
        # numpy.argmax returns the first occurrence when several values tie,
        # so exactly one element per slice along ``axis`` becomes 1.
        winners = numpy.expand_dims(numpy.argmax(x, axis=axis), axis=axis)
        onehot = numpy.zeros_like(x)
        numpy.put_along_axis(onehot, winners, 1, axis=axis)
        return (onehot, )

    def to_python(self, inputs):
        """
        Returns python code equivalent to :meth:`_run`.

        :param inputs: names of the inputs; only the first one is used
        :return: tuple ``(import statement, code)``; the generated code
            reads a variable ``axis`` it does not define itself
            (presumably bound by the caller — TODO confirm)
        """
        name = inputs[0]
        code = [
            f"{name}_argmax = numpy.argmax({name}, axis=axis)",
            f"{name}y = numpy.zeros_like({name})",
            f"numpy.put_along_axis({name}y,",
            " numpy.expand_dims(",
            f" {name}_argmax, axis=axis),",
            " 1, axis=axis)",
            f"return {name}y",
        ]
        return "import numpy", "\n".join(code)