Coverage for mlprodict/onnxrt/ops_cpu/op_normalizer.py: 91%
46 statements
Report generated by coverage.py v7.1.0 on 2023-02-04 02:28 +0100.
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRunUnaryNum
class Normalizer(OpRunUnaryNum):
    """
    Runtime for the ONNX-ML operator *Normalizer*.

    Every row of the input (reduction along axis 1) is divided by its
    largest absolute value (``norm='MAX'``), its L1 norm (``norm='L1'``)
    or its L2 norm (``norm='L2'``). Divisors are clipped to ``1e-30``
    so that an all-zero row stays all zeros instead of becoming NaN.
    Integer inputs are converted to float64 before being normalized.
    """

    atts = {'norm': 'MAX'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Normalizer.atts,
                               **options)
        # The attribute is compared as bytes; presumably the base class
        # stores string attributes as bytes (see OpRunUnaryNum).
        norms = {
            b'MAX': Normalizer.norm_max,
            b'L1': Normalizer.norm_l1,
            b'L2': Normalizer.norm_l2,
        }
        if self.norm not in norms:  # pylint: disable=E1101
            raise ValueError(  # pragma: no cover
                f"Unexpected value for norm='{self.norm}'.")  # pylint: disable=E1101
        self._norm = norms[self.norm]  # pylint: disable=E1101

    @staticmethod
    def _as_inexact(x):
        """
        Returns *x* itself if its dtype is floating or complex, otherwise
        a float64 copy. numpy refuses to store a float quotient or square
        root into an integer buffer, and squaring integers may overflow.
        """
        if numpy.issubdtype(x.dtype, numpy.inexact):
            return x
        return x.astype(numpy.float64)

    @staticmethod
    def _divide(x, div, inplace):
        """
        Divides *x* by *div* (one value per row, broadcast along axis 1).
        The divisor is clipped to 1e-30 so that an all-zero row stays
        all zeros. The result is written into *x* when *inplace* is True.
        """
        div = numpy.maximum(div, 1e-30)
        if inplace:
            return numpy.divide(x, div, out=x)
        return x / div

    @staticmethod
    def norm_max(x, inplace=False):
        "max normalization: divides each row by its largest absolute value"
        x = Normalizer._as_inexact(x)
        # keepdims keeps one divisor per row with shape (N, 1), which also
        # works for an empty batch; initial=0 supports rows with no column
        div = numpy.abs(x).max(axis=1, keepdims=True, initial=0)
        return Normalizer._divide(x, div, inplace)

    @staticmethod
    def _norm_max_inplace(x):
        return Normalizer.norm_max(x, True)

    @staticmethod
    def norm_l1(x, inplace=False):
        "L1 normalization: divides each row by the sum of its absolute values"
        x = Normalizer._as_inexact(x)
        div = numpy.abs(x).sum(axis=1, keepdims=True)
        return Normalizer._divide(x, div, inplace)

    @staticmethod
    def _norm_L1_inplace(x):
        return Normalizer.norm_l1(x, True)

    @staticmethod
    def norm_l2(x, inplace=False):
        "L2 normalization: divides each row by its euclidean norm"
        # x is floating after this call, so the in-place sqrt below is valid
        x = Normalizer._as_inexact(x)
        xn = numpy.square(x).sum(axis=1, keepdims=True)
        numpy.sqrt(xn, out=xn)
        return Normalizer._divide(x, xn, inplace)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # In-place computation is attempted only when the runtime allows it
        # for input 0 and the input buffer is writeable.
        inplace = self.inplaces.get(0, False) and x.flags['WRITEABLE']
        return (self._norm(x, inplace=inplace), )