Coverage for mlprodict/onnxrt/ops_cpu/op_string_normalizer.py: 100%

57 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import unicodedata
import locale
import warnings
import numpy
from ._op import OpRunUnary, RuntimeTypeError


class StringNormalizer(OpRunUnary):
    """
    The operator is not really thread-safe as Python cannot
    use two locales at the same time. Stop words should not
    be implemented here as the tokenization usually happens
    after this step.
    """

    atts = {'case_change_action': b'NONE',  # LOWER, UPPER or NONE
            'is_case_sensitive': 1,
            'locale': b'',
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=StringNormalizer.atts,
                            **options)
        self.slocale = self.locale.decode('ascii')
        self.stops = set(self.stopwords)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Normalizes strings.
        """
        res = numpy.empty(x.shape, dtype=x.dtype)
        if len(x.shape) == 2:
            for i in range(x.shape[1]):
                self._run_column(x[:, i], res[:, i])
        elif len(x.shape) == 1:
            self._run_column(x, res)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "x must be a matrix or a vector.")
        return (res, )

    def _run_column(self, cin, cout):
        """
        Normalizes the strings of one column.
        """
        # locale.getlocale() returns a tuple while self.slocale is a
        # string, so query the current locale as a string via setlocale
        # before comparing; this avoids switching the locale on every call.
        if locale.setlocale(locale.LC_ALL) != self.slocale:
            try:
                locale.setlocale(locale.LC_ALL, self.slocale)
            except locale.Error as e:
                warnings.warn(
                    f"Unknown locale setting '{self.slocale}' "
                    f"(current: '{locale.getlocale()}') - {e}.")
        stops = set(_.decode() for _ in self.stops)
        cout[:] = cin[:]

        for i in range(cin.shape[0]):
            if isinstance(cout[i], float):
                # nan
                cout[i] = ''  # pragma: no cover
            else:
                cout[i] = self.strip_accents_unicode(cout[i])

        # Stop words are removed before the case change when the
        # comparison is case-sensitive, after it otherwise.
        if self.is_case_sensitive and len(stops) > 0:
            for i in range(cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        if self.case_change_action == b'LOWER':
            for i in range(cin.shape[0]):
                cout[i] = cout[i].lower()
        elif self.case_change_action == b'UPPER':
            for i in range(cin.shape[0]):
                cout[i] = cout[i].upper()
        elif self.case_change_action != b'NONE':
            raise RuntimeError(
                f"Unknown option for case_change_action: {self.case_change_action}.")

        if not self.is_case_sensitive and len(stops) > 0:
            for i in range(cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        return cout

    def _remove_stopwords(self, text, stops):
        # Tokenizes on single spaces and drops tokens found in the stop set.
        spl = text.split(' ')
        return ' '.join(filter(lambda s: s not in stops, spl))

    def strip_accents_unicode(self, s):
        """
        Transforms accented unicode symbols into their simple counterparts.
        Source: `sklearn/feature_extraction/text.py
        <https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/
        feature_extraction/text.py#L115>`_.

        :param s: the string to strip
        :return: the cleaned string
        """
        try:
            # If `s` is ASCII-compatible, then it does not contain any
            # accented characters and we can avoid an expensive list
            # comprehension.
            s.encode("ASCII", errors="strict")
            return s
        except UnicodeEncodeError:
            normalized = unicodedata.normalize('NFKD', s)
            s = ''.join(
                [c for c in normalized if not unicodedata.combining(c)])
            return s
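
The snippet below is a minimal sketch of how this runtime operator can be exercised end to end: it builds a one-node ONNX graph with onnx.helper and runs it through mlprodict's OnnxInference, assuming the default python runtime dispatches StringNormalizer to the class above. The stopword list, input strings, and the expected output in the comments are illustrative assumptions, not values taken from this file.

import numpy
from onnx import TensorProto, helper
from mlprodict.onnxrt import OnnxInference

# One-node model Y = StringNormalizer(X): lowercase, then remove the
# (illustrative) stop words 'the' and 'a' case-insensitively.
node = helper.make_node(
    'StringNormalizer', ['X'], ['Y'],
    case_change_action='LOWER',
    is_case_sensitive=0,
    stopwords=['the', 'a'])
graph = helper.make_graph(
    [node], 'string_normalizer_example',
    [helper.make_tensor_value_info('X', TensorProto.STRING, [None])],
    [helper.make_tensor_value_info('Y', TensorProto.STRING, [None])])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('', 10)])

oinf = OnnxInference(model)  # python runtime by default
res = oinf.run({'X': numpy.array(['The Café', 'A naïve DOG'])})
print(res['Y'])  # expected: ['cafe', 'naive dog']

Since is_case_sensitive is 0 here, stop words are removed after the accents are stripped and the strings lowercased, which is why 'The' and 'A' match the lowercase stop words.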