Coverage for mlprodict/onnxrt/ops_cpu/op_normalizer.py: 91%

46 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

class Normalizer(OpRunUnaryNum):
    """
    Runtime implementation of the ONNX-ML operator *Normalizer*.

    Each row of a 2D input ``[N, C]`` (or the whole vector for a 1D
    input ``[C]``) is divided by the norm selected by attribute *norm*:

    * ``MAX``: largest absolute value of the row,
    * ``L1``: sum of the absolute values of the row,
    * ``L2``: euclidean norm of the row.

    Every norm is clipped to at least ``1e-30`` so that an all-zero row
    does not trigger a division by zero. Inputs of higher rank (not
    defined by the ONNX specification) are normalized along their last
    axis.
    """

    atts = {'norm': 'MAX'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Normalizer.atts,
                               **options)
        # The value read from the graph is compared as bytes, but the
        # default declared in ``atts`` is a str: accept both spellings so
        # the default (if the base class applies it) is not rejected.
        norm = self.norm  # pylint: disable=E1101
        if isinstance(norm, bytes):
            norm = norm.decode('ascii', errors='replace')
        dispatch = {
            'MAX': Normalizer.norm_max,
            'L1': Normalizer.norm_l1,
            'L2': Normalizer.norm_l2,
        }
        if norm not in dispatch:
            raise ValueError(  # pragma: no cover
                f"Unexpected value for norm='{self.norm}'.")  # pylint: disable=E1101
        self._norm = dispatch[norm]

    @staticmethod
    def _divide(x, div, inplace):
        """
        Divides *x* by *div* clipped to at least 1e-30.

        :param x: array to normalize
        :param div: per-row norms, broadcastable against *x*
        :param inplace: overwrite *x* when its dtype can hold the result
        :return: normalized array (*x* itself when done in place)
        """
        div = numpy.maximum(div, 1e-30)
        # numpy.divide cannot store a floating result into an integer
        # buffer, so in-place division is restricted to inexact dtypes.
        if inplace and numpy.issubdtype(x.dtype, numpy.inexact):
            numpy.divide(x, div, out=x)
            return x
        return x / div

    @staticmethod
    def norm_max(x, inplace):
        "max normalization: divides every row by its largest absolute value"
        # keepdims=True keeps a trailing axis of size 1, which broadcasts
        # for both [N, C] and [C] inputs (and for an empty batch).
        div = numpy.abs(x).max(axis=-1, keepdims=True)
        return Normalizer._divide(x, div, inplace)

    @staticmethod
    def _norm_max_inplace(x):
        # kept for backward compatibility, see norm_max
        return Normalizer.norm_max(x, True)

    @staticmethod
    def norm_l1(x, inplace):
        "L1 normalization: divides every row by the sum of its absolute values"
        div = numpy.abs(x).sum(axis=-1, keepdims=True)
        return Normalizer._divide(x, div, inplace)

    @staticmethod
    def _norm_L1_inplace(x):
        # kept for backward compatibility, see norm_l1
        return Normalizer.norm_l1(x, True)

    @staticmethod
    def norm_l2(x, inplace):
        "L2 normalization: divides every row by its euclidean norm"
        # The square root is not computed in place: for an integer input
        # the sum of squares is an integer array which cannot store it.
        div = numpy.sqrt(numpy.square(x).sum(axis=-1, keepdims=True))
        return Normalizer._divide(x, div, inplace)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # The input is overwritten only when the runtime marked input 0 as
        # reusable and numpy reports its buffer as writeable.
        inplace = self.inplaces.get(0, False) and x.flags['WRITEABLE']
        return (self._norm(x, inplace=inplace), )