Coverage for mlprodict/onnxrt/ops_cpu/op_celu.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

def pycelu(x, alpha=1.):
    """
    Scalar reference implementation of ``celu(x)``.

    .. math::

        celu(x) = \\left \\{\\begin{array}{ll} x \\text{ if } x > 0 \\\\
        \\alpha ( e^{\\frac{x}{\\alpha}} - 1) \\, \\text{ otherwise }
        \\end{array} \\right.

    :param x: a single number; the truth test on ``x > 0`` requires
        a scalar (or a one-element array)
    :param alpha: CELU scale parameter
    :return: *x* itself when it is strictly positive, otherwise
        ``(exp(x / alpha) - 1) * alpha``
    """
    return x if x > 0 else (numpy.exp(x / alpha) - 1) * alpha

24 

25 

26def _vcelu1(x, alpha=1.): 

27 positive_input = numpy.maximum(0, x) 

28 negative_input = numpy.minimum(0, alpha * ( 

29 numpy.exp(x / alpha) - 1)) 

30 return positive_input + negative_input 

31 

32 

class Celu(OpRunUnaryNum):
    """
    Runtime implementation of the ONNX operator *Celu*:
    ``max(0, x) + min(0, alpha * (exp(x / alpha) - 1))``.
    """

    # Default attribute values handed to the base constructor;
    # the code below reads the resulting value as ``self.alpha``.
    atts = {'alpha': numpy.float32(1.0)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Celu.atts,
                               **options)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes CELU of *x*.

        The input is overwritten only when input 0 is flagged in
        ``self.inplaces`` and the array is writeable. Otherwise a new
        array is returned.
        """
        if self.inplaces.get(0, False) and x.flags['WRITEABLE']:
            return self._run_inplace(x)
        return (_vcelu1(x, self.alpha), )

    def _run_inplace(self, x):
        """
        Applies CELU to *x* in place and returns it, preserving its dtype.

        Positive entries, zeros and NaN are already their own image, so
        only strictly negative entries are rewritten. For ``x < 0`` and
        any non-zero *alpha*, ``alpha * (exp(x / alpha) - 1)`` is negative.
        The ``min(0, ...)`` clamp used by the out-of-place formula is
        therefore a no-op on those entries.
        """
        negative = x < 0
        x[negative] = self.alpha * (
            numpy.exp(x[negative] / self.alpha) - 1)
        return (x, )

    def to_python(self, inputs):
        """
        Returns the import line and the expression used when this operator
        is exported to Python code (the generated code calls ``_vcelu1``).
        """
        return ('from mlprodict.onnxrt.ops_cpu.op_celu import _vcelu1',
                "return _vcelu1(X, alpha)")