Coverage for mlprodict/onnxrt/ops_cpu/op_tokenizer.py: 98% (94 statements)
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import re
import numpy
from ._op import OpRunUnary, RuntimeTypeError
from ._new_ops import OperatorSchema


class Tokenizer(OpRunUnary):
    """
    See :epkg:`Tokenizer`.
    """

    atts = {'mark': 0,
            'mincharnum': 1,
            'pad_value': b'#',
            'separators': [],
            'tokenexp': b'[a-zA-Z0-9_]+',
            'tokenexpsplit': 0,
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=Tokenizer.atts,
                            **options)
        # Character-level tokenization is requested either with
        # tokenexp == b'.' or with a single empty separator.
        self.char_tokenization_ = (
            self.tokenexp == b'.' or list(self.separators) == [b''])
        self.stops_ = set(_.decode() for _ in self.stopwords)
        try:
            self.str_separators_ = set(_.decode('utf-8')
                                       for _ in self.separators)
        except AttributeError as e:  # pragma: no cover
            raise RuntimeTypeError(
                f"Unable to interpret separators {self.separators}.") from e
        if self.tokenexp not in (None, b''):
            self.tokenexp_ = re.compile(self.tokenexp.decode('utf-8'))
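
    # Dispatch summary, a sketch with illustrative attribute values
    # (not from the ONNX spec) mirroring the checks above and in _run:
    #   tokenexp=b'.' or separators=[b'']        -> character tokenization
    #   separators=[b',', b';']                  -> separator tokenization
    #   separators=[], tokenexp=b'[a-zA-Z0-9_]+' -> regex tokenization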

    def _find_custom_operator_schema(self, op_name):
        if op_name == "Tokenizer":
            return TokenizerSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, text, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        if self.char_tokenization_:
            return self._run_char_tokenization(text, self.stops_)
        if self.str_separators_ is not None and len(self.str_separators_) > 0:
            return self._run_sep_tokenization(
                text, self.stops_, self.str_separators_)
        # tokenexp is stored as bytes, so the comparison must use b''
        # (as in the constructor) for self.tokenexp_ to be defined.
        if self.tokenexp not in (None, b''):
            return self._run_regex_tokenization(
                text, self.stops_, self.tokenexp_)
        raise RuntimeError(  # pragma: no cover
            "Unable to guess which tokenization to use, sep={}, "
            "tokenexp='{}'.".format(self.separators, self.tokenexp))

    def _run_tokenization(self, text, stops, split):
        """
        Tokenizes the input with the given *split* function,
        padding every row to the longest one.
        """
        max_len = max(map(len, text.flatten()))
        if self.mark:
            # One extra slot on each side for the padding markers.
            max_len += 2
            begin = 1
        else:
            begin = 0
        shape = text.shape + (max_len, )
        max_pos = 0
        res = numpy.empty(shape, dtype=text.dtype)
        if len(text.shape) == 1:
            res[:] = self.pad_value
            for i in range(text.shape[0]):
                pos = begin
                for c in split(text[i]):
                    if c not in stops:
                        res[i, pos] = c
                        pos += 1
                if self.mark:
                    res[i, 0] = self.pad_value
                    max_pos = max(pos + 1, max_pos)
                else:
                    max_pos = max(pos, max_pos)
            res = res[:, :max_pos]
        elif len(text.shape) == 2:
            res[:, :] = self.pad_value
            for i in range(text.shape[0]):
                for ii in range(text.shape[1]):
                    pos = begin
                    for c in split(text[i, ii]):
                        if c not in stops:
                            res[i, ii, pos] = c
                            pos += 1
                    if self.mark:
                        res[i, ii, 0] = self.pad_value
                        max_pos = max(pos + 1, max_pos)
                    else:
                        max_pos = max(pos, max_pos)
            res = res[:, :, :max_pos]
        else:
            raise RuntimeError(  # pragma: no cover
                f"Only vectors or matrices are supported, not shape {text.shape}.")
        return (res, )
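
    # Worked example (a sketch): text=['ab cd', 'e'], split on a single
    # space, pad_value=b'#'. The output is truncated to the widest row:
    #   mark=0 -> [['ab', 'cd'], ['e', '#']]
    #   mark=1 -> [['#', 'ab', 'cd', '#'], ['#', 'e', '#', '#']]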

    def _run_char_tokenization(self, text, stops):
        """
        Tokenizes by characters.
        """
        def split(t):
            for c in t:
                yield c
        return self._run_tokenization(text, stops, split)

    def _run_sep_tokenization(self, text, stops, separators):
        """
        Tokenizes using separators.
        A trie would make the separator lookup more efficient.
        """
        def split(t):
            begin = 0
            pos = 0
            while pos < len(t):
                for sep in separators:
                    if (pos + len(sep) <= len(t) and
                            sep == t[pos: pos + len(sep)]):
                        word = t[begin: pos]
                        yield word
                        begin = pos + len(sep)
                        break
                pos += 1
            if begin < pos:
                word = t[begin: pos]
                yield word

        return self._run_tokenization(text, stops, split)
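
    # Example of the generator above (a sketch): with separators={','},
    # list(split('ab,cd,')) -> ['ab', 'cd']; a trailing separator leaves
    # begin == pos, so no empty final token is produced. 'a,,b' yields
    # ['a', '', 'b']: empty tokens are only dropped if they are stopwords.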

    def _run_regex_tokenization(self, text, stops, exp):
        """
        Tokenizes using a regular expression.
        """
        if self.tokenexpsplit:
            def split(t):
                return filter(lambda x: x, exp.split(t))
        else:
            def split(t):
                return filter(lambda x: x, exp.findall(t))
        return self._run_tokenization(text, stops, split)
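
    # Example (a sketch): with exp = re.compile('[a-zA-Z0-9_]+') and
    # t = 'ab-cd':
    #   tokenexpsplit=0 -> exp.findall(t) keeps the matches: ['ab', 'cd']
    #   tokenexpsplit=1 -> exp.split(t) keeps what lies between matches,
    #                      filtered of empty strings: ['-']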


class TokenizerSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl TreeEnsembleClassifierDouble.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'Tokenizer')
        self.attributes = Tokenizer.atts
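

# A minimal usage sketch (assumptions: OpRunUnary parses the node's
# attributes and fills missing ones from ``Tokenizer.atts``, as elsewhere
# in this runtime; the Tokenizer operator lives in the com.microsoft
# domain). It is guarded so importing the module stays side-effect free.
if __name__ == "__main__":
    from onnx import helper
    node = helper.make_node(
        'Tokenizer', ['text'], ['tokens'], domain='com.microsoft',
        mark=0, mincharnum=1, pad_value='#',
        separators=[','], tokenexp='', tokenexpsplit=0)
    op = Tokenizer(node)
    # expected: (array([['ab', 'cd', '#'], ['e', 'f', 'g']], ...),)
    print(op._run(numpy.array(['ab,cd', 'e,f,g'])))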