Coverage for mlprodict/onnxrt/ops_cpu/op_tokenizer.py: 98%


# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import re
import numpy
from ._op import OpRunUnary, RuntimeTypeError
from ._new_ops import OperatorSchema


class Tokenizer(OpRunUnary):
    """
    See :epkg:`Tokenizer`.
    """

    atts = {'mark': 0,
            'mincharnum': 1,
            'pad_value': b'#',
            'separators': [],
            'tokenexp': b'[a-zA-Z0-9_]+',
            'tokenexpsplit': 0,
            'stopwords': []}
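    # Informal attribute summary (these mirror the onnxruntime Tokenizer
    # contrib operator; best-effort descriptions, not the official spec):
    #   mark          - write a pad_value marker at the start of every row
    #   mincharnum    - minimum token length (accepted but unused by this
    #                   Python runtime)
    #   pad_value     - string used to pad shorter token rows in the output
    #   separators    - explicit separator strings (separator-based mode)
    #   tokenexp      - regular expression used when no separators are given
    #   tokenexpsplit - if non-zero, tokenexp splits instead of matching
    #   stopwords     - tokens removed from the output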

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=Tokenizer.atts,
                            **options)
        self.char_tokenization_ = (
            self.tokenexp == b'.' or list(self.separators) == [b''])
        self.stops_ = set(_.decode() for _ in self.stopwords)
        try:
            self.str_separators_ = set(_.decode('utf-8')
                                       for _ in self.separators)
        except AttributeError as e:  # pragma: no cover
            raise RuntimeTypeError(
                f"Unable to interpret separators {self.separators}.") from e
        if self.tokenexp not in (None, b''):
            self.tokenexp_ = re.compile(self.tokenexp.decode('utf-8'))

    def _find_custom_operator_schema(self, op_name):
        if op_name == "Tokenizer":
            return TokenizerSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, text, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        if self.char_tokenization_:
            return self._run_char_tokenization(text, self.stops_)
        if self.str_separators_ is not None and len(self.str_separators_) > 0:
            return self._run_sep_tokenization(
                text, self.stops_, self.str_separators_)
        if self.tokenexp not in (None, b''):
            return self._run_regex_tokenization(
                text, self.stops_, self.tokenexp_)
        raise RuntimeError(  # pragma: no cover
            "Unable to guess which tokenization to use, sep={}, "
            "tokenexp='{}'.".format(self.separators, self.tokenexp))

    def _run_tokenization(self, text, stops, split):
        """
        Tokenizes every element of *text* with the *split* function,
        skips tokens found in *stops* and pads shorter rows with
        *pad_value*.
        """

        max_len = max(map(len, text.flatten()))
        if self.mark:
            max_len += 2
            begin = 1
        else:
            begin = 0
        shape = text.shape + (max_len, )
        max_pos = 0
        res = numpy.empty(shape, dtype=text.dtype)
        if len(text.shape) == 1:
            res[:] = self.pad_value
            for i in range(text.shape[0]):
                pos = begin
                for c in split(text[i]):
                    if c not in stops:
                        res[i, pos] = c
                        pos += 1
                if self.mark:
                    res[i, 0] = self.pad_value
                    max_pos = max(pos + 1, max_pos)
                else:
                    max_pos = max(pos, max_pos)
            res = res[:, :max_pos]
        elif len(text.shape) == 2:
            res[:, :] = self.pad_value
            for i in range(text.shape[0]):
                for ii in range(text.shape[1]):
                    pos = begin
                    for c in split(text[i, ii]):
                        if c not in stops:
                            res[i, ii, pos] = c
                            pos += 1
                    if self.mark:
                        res[i, ii, 0] = self.pad_value
                        max_pos = max(pos + 1, max_pos)
                    else:
                        max_pos = max(pos, max_pos)
            res = res[:, :, :max_pos]
        else:
            raise RuntimeError(  # pragma: no cover
                f"Only vectors or matrices are supported, not shape {text.shape}.")
        return (res, )
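    # Padding behaviour (informal example): with mark=0 and pad_value='#',
    # a 1-D input whose two elements split into ['ab', 'cd'] and ['e']
    # produces
    #   [['ab', 'cd'],
    #    ['e',  '#']]
    # i.e. one row per input string, shorter rows padded on the right.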

    def _run_char_tokenization(self, text, stops):
        """
        Tokenizes by characters.
        """
        def split(t):
            for c in t:
                yield c
        return self._run_tokenization(text, stops, split)
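    # Example (informal): char-level mode turns 'ab' into the two tokens
    # 'a' and 'b'.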

    def _run_sep_tokenization(self, text, stops, separators):
        """
        Tokenizes using separators.
        The function should use a trie to look up the separators.
        """
        def split(t):
            begin = 0
            pos = 0
            while pos < len(t):
                for sep in separators:
                    if (pos + len(sep) <= len(t) and
                            sep == t[pos: pos + len(sep)]):
                        word = t[begin: pos]
                        yield word
                        begin = pos + len(sep)
                        break
                pos += 1
            if begin < pos:
                word = t[begin: pos]
                yield word

        return self._run_tokenization(text, stops, split)
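    # Example (informal): with separators {',', ';'}, the inner ``split``
    # yields 'a', 'b', 'c' for the input 'a,b;c'.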

    def _run_regex_tokenization(self, text, stops, exp):
        """
        Tokenizes using a regular expression.
        """
        if self.tokenexpsplit:
            def split(t):
                return filter(lambda x: x, exp.split(t))
        else:
            def split(t):
                return filter(lambda x: x, exp.findall(t))
        return self._run_tokenization(text, stops, split)
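    # Example (informal): with tokenexp=b'[a-zA-Z0-9_]+' and tokenexpsplit=0,
    # ``findall`` keeps the matches, so 'a-b c' gives 'a', 'b', 'c';
    # with tokenexpsplit=1, ``split`` keeps what lies between the matches,
    # so the same input gives '-', ' '.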


class TokenizerSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl TreeEnsembleClassifierDouble.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'Tokenizer')
        self.attributes = Tokenizer.atts
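
# Usage note (not part of the original module): within mlprodict this class is
# not instantiated directly; the runtime selects it by operator name when an
# ONNX graph containing a Tokenizer node is loaded. A hedged sketch, assuming
# ``model_onnx`` is such a graph with a string input named 'text':
#
#     from mlprodict.onnxrt import OnnxInference
#     oinf = OnnxInference(model_onnx)
#     res = oinf.run({'text': numpy.array(['a,b', 'c'])})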