Coverage for mlprodict/onnxrt/ops_cpu/op_string_normalizer.py: 100%
57 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import unicodedata
8import locale
9import warnings
10import numpy
11from ._op import OpRunUnary, RuntimeTypeError
class StringNormalizer(OpRunUnary):
    """
    Runtime for the ONNX *StringNormalizer* operator.

    The operator is not really threadsafe as python cannot
    play with two locales at the same time. Stop words
    should not be implemented here as the tokenization
    usually happens after this step.
    """

    # Default attribute values (ONNX operator attributes).
    atts = {'case_change_action': b'NONE',  # LOWER UPPER NONE
            'is_case_sensitive': 1,
            'locale': b'',
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=StringNormalizer.atts,
                            **options)
        self.slocale = self.locale.decode('ascii')
        # Decode the stopword set once here instead of once per processed
        # column in _run_column (attributes arrive as bytes).
        self.stops = set(
            w.decode() if isinstance(w, bytes) else w
            for w in self.stopwords)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Normalizes strings.

        :param x: 1D or 2D numpy array of strings
        :return: tuple with one array of the same shape as *x*
        :raises RuntimeTypeError: if *x* has more than two dimensions
        """
        res = numpy.empty(x.shape, dtype=x.dtype)
        if len(x.shape) == 2:
            # Process column by column so both layouts share one code path.
            for i in range(x.shape[1]):
                self._run_column(x[:, i], res[:, i])
        elif len(x.shape) == 1:
            self._run_column(x, res)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "x must be a matrix or a vector.")
        return (res, )

    def _run_column(self, cin, cout):
        """
        Normalizes strings in a column, writing the result into *cout*.

        :param cin: input column (1D array of strings)
        :param cout: output column, same shape as *cin*, modified in place
        :return: *cout*
        """
        # BUG FIX: locale.getlocale() returns a (language, encoding) tuple
        # while self.slocale is a plain string, so the former comparison
        # ``locale.getlocale() != self.slocale`` was always True and
        # setlocale ran for every column.  The query form
        # ``locale.setlocale(locale.LC_ALL)`` returns the current locale as
        # a string, which is comparable with self.slocale.
        if locale.setlocale(locale.LC_ALL) != self.slocale:
            try:
                locale.setlocale(locale.LC_ALL, self.slocale)
            except locale.Error as e:
                # Best effort: an unknown locale must not break inference.
                warnings.warn(
                    "Unknown local setting '{}' (current: '{}') - {}."
                    "".format(self.slocale, locale.getlocale(), e))
        stops = self.stops  # already decoded in __init__

        cout[:] = cin[:]
        for i in range(cin.shape[0]):
            if isinstance(cout[i], float):
                # nan values show up as float in object/string arrays
                cout[i] = ''  # pragma: no cover
            else:
                cout[i] = self.strip_accents_unicode(cout[i])

        # Case-sensitive stopword removal happens before any case change...
        if self.is_case_sensitive and stops:
            for i in range(cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        if self.case_change_action == b'LOWER':
            for i in range(cin.shape[0]):
                cout[i] = cout[i].lower()
        elif self.case_change_action == b'UPPER':
            for i in range(cin.shape[0]):
                cout[i] = cout[i].upper()
        elif self.case_change_action != b'NONE':
            raise RuntimeError(
                f"Unknown option for case_change_action: {self.case_change_action}.")

        # ...while case-insensitive removal happens after it, so the
        # (already lower/upper-cased) tokens are compared to the stopwords.
        if not self.is_case_sensitive and stops:
            for i in range(cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        return cout

    def _remove_stopwords(self, text, stops):
        """
        Removes from *text* every space-separated token present in *stops*.

        :param text: input string
        :param stops: set of stopwords (strings)
        :return: filtered string
        """
        spl = text.split(' ')
        return ' '.join(filter(lambda s: s not in stops, spl))

    def strip_accents_unicode(self, s):
        """
        Transforms accentuated unicode symbols into their simple counterpart.
        Source: `sklearn/feature_extraction/text.py
        <https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/
        feature_extraction/text.py#L115>`_.

        :param s: string
            The string to strip
        :return: the cleaned string
        """
        try:
            # If `s` is ASCII-compatible, then it does not contain any accented
            # characters and we can avoid an expensive list comprehension
            s.encode("ASCII", errors="strict")
            return s
        except UnicodeEncodeError:
            # NFKD decomposition separates base characters from combining
            # marks; dropping the combining marks removes the accents.
            normalized = unicodedata.normalize('NFKD', s)
            s = ''.join(
                [c for c in normalized if not unicodedata.combining(c)])
            return s