microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
e0d48e255f28e5465f63e7fc141df1e1d533cc40

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.cc

413lines · modecode

1#include "bert_tokenizer.hpp"
2
3#include <utility>
4
5BertTokenizerVocab::BertTokenizerVocab(std::string_view vocab) : raw_vocab_(vocab) {
6 auto tokens = SplitString(raw_vocab_, "\r\n", true);
7
8 for (size_t i = 0; i < tokens.size(); i++) {
9 (vocab_)[tokens[i]] = static_cast<int32_t>(i);
10 }
11}
12
13bool BertTokenizerVocab::FindToken(const ustring& token) {
14 auto utf8_token = std::string(token);
15
16 return vocab_.find(utf8_token) != vocab_.end();
17}
18
19bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
20 auto utf8_token = std::string(token);
21
22 auto it = vocab_.find(utf8_token);
23 if (it == vocab_.end()) {
24 return false;
25 }
26
27 token_id = it->second;
28 return true;
29}
30
31int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
32 auto utf8_token = std::string(token);
33
34 auto it = vocab_.find(utf8_token);
35 if (it == vocab_.end()) {
36 ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
37 }
38
39 return it->second;
40}
41
42WordpieceTokenizer::WordpieceTokenizer(
43 std::shared_ptr<BertTokenizerVocab> vocab,
44 ustring unk_token,
45 ustring suffix_indicator,
46 int max_input_chars_per_word
47) :
48 max_input_chars_per_word_(max_input_chars_per_word),
49 suffix_indicator_(std::move(suffix_indicator)),
50 unk_token_(std::move(unk_token)),
51 vocab_(std::move(vocab))
52{
53 unk_token_id_ = vocab_->FindTokenId(unk_token_);
54}
55
56std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
57 std::vector<ustring> result;
58 ustring token;
59 for (auto c : text) {
60 if (c == U' ' && !token.empty()) {
61 GreedySearch(token, result);
62 token.clear();
63 continue;
64 }
65
66 token.push_back(c);
67 }
68
69 if (!token.empty()) {
70 GreedySearch(token, result);
71 }
72
73 return result;
74}
75
76std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
77 std::vector<ustring> result;
78 for (const auto& token : tokens) {
79 GreedySearch(token, result);
80 }
81
82 return result;
83}
84
85std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
86 std::vector<int64_t> ids;
87 for (const auto& token : tokens) {
88 int32_t token_id = -1;
89 if (!vocab_->FindTokenId(token, token_id)) {
90 ids.push_back(unk_token_id_);
91 continue;
92 }
93
94 ids.push_back(token_id);
95 }
96 return ids;
97}
98
99void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
100 if (static_cast<int64_t>(token.size()) > max_input_chars_per_word_) {
101 tokenized_result.push_back(unk_token_);
102 return;
103 }
104
105 size_t start = 0;
106 size_t end = 0;
107 ustring substr;
108 for (; start < token.size();) {
109 end = token.size();
110 bool is_found = false;
111 // try to found the longest matched sub-token in vocab
112 for (; start < end;) {
113 substr = static_cast<const ustring>(token.substr(start, end - start));
114 if (start > 0) {
115 substr = static_cast<const ustring>(suffix_indicator_ + substr);
116 }
117 if (vocab_->FindToken(substr)) {
118 is_found = true;
119 break;
120 }
121 end -= 1;
122 }
123 // token not found in vocab
124 if (!is_found) {
125 tokenized_result.push_back(unk_token_);
126 break;
127 }
128
129 tokenized_result.push_back(substr);
130 start = end;
131 }
132}
133
134void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int32_t max_len) {
135 if ((max_len > 0) && (static_cast<size_t>(max_len) < ids.size())) {
136 ids.resize(max_len);
137 }
138}
139
140void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len) {
141 if (max_len < 0 || (ids1.size() + ids2.size() <= static_cast<size_t>(max_len))) {
142 return;
143 }
144
145 auto ids1_keep_len = ids1.size();
146 auto ids2_keep_len = ids2.size();
147 auto half_max_len = max_len / 2;
148
149 switch (strategy_) {
150 case TruncateStrategyType::LONGEST_FIRST:
151 case TruncateStrategyType::LONGEST_FROM_BACK:
152
153 if ((ids1_keep_len > static_cast<size_t>(half_max_len)) && (ids2_keep_len > static_cast<size_t>(half_max_len))) {
154 ids1_keep_len = static_cast<size_t>(max_len) - half_max_len;
155 ids2_keep_len = half_max_len;
156 } else if (ids2_keep_len > ids1_keep_len) {
157 ids2_keep_len = static_cast<size_t>(max_len) - ids1_keep_len;
158 } else {
159 ids1_keep_len = static_cast<size_t>(max_len) - ids2_keep_len;
160 }
161
162 if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
163 ids1.resize(ids1_keep_len);
164 ids2.resize(ids2_keep_len);
165 } else {
166 ids1.erase(ids1.begin(), ids1.end() - ids1_keep_len);
167 ids2.erase(ids2.begin(), ids2.end() - ids2_keep_len);
168 }
169
170 return;
171 case TruncateStrategyType::ONLY_FIRST:
172 return;
173 case TruncateStrategyType::ONLY_SECOND:
174 return;
175 default:
176 return;
177 }
178}
179
180BertTokenizer::BertTokenizer(
181 const std::string& vocab,
182 bool do_lower_case,
183 bool do_basic_tokenize,
184 ustring unk_token,
185 ustring sep_token,
186 ustring pad_token,
187 ustring cls_token,
188 ustring mask_token,
189 bool tokenize_chinese_chars,
190 bool strip_accents,
191 ustring suffix_indicator,
192 int32_t max_len,
193 const std::string& truncation_strategy
194) :
195 max_length_(max_len),
196 do_basic_tokenize_(do_basic_tokenize),
197 truncate_(std::make_unique<TruncateStrategy>(truncation_strategy))
198{
199
200 vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
201
202 if (do_basic_tokenize) {
203 basic_tokenizer_ = std::make_unique<BasicTokenizer>(
204 do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
205 }
206 wordpiece_tokenizer_ = std::make_unique<WordpieceTokenizer>(
207 vocab_, unk_token, suffix_indicator);
208
209 unk_token_id_ = vocab_->FindTokenId(unk_token);
210 sep_token_id_ = vocab_->FindTokenId(sep_token);
211 pad_token_id_ = vocab_->FindTokenId(pad_token);
212 cls_token_id_ = vocab_->FindTokenId(cls_token);
213 mask_token_id_ = vocab_->FindTokenId(mask_token);
214}
215
216std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
217 if (do_basic_tokenize_) {
218 return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
219 }
220 return wordpiece_tokenizer_->Tokenize(text);
221}
222
223std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
224 return wordpiece_tokenizer_->Encode(tokens);
225}
226
227void BertTokenizer::Truncate(std::vector<int64_t>& ids) {
228 truncate_->Truncate(ids, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
229}
230
231void BertTokenizer::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2) {
232 truncate_->Truncate(ids1, ids2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
233}
234
235std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
236 std::vector<int64_t> result;
237 result.reserve(ids.size() + 2);
238 result.push_back(cls_token_id_);
239 result.insert(result.end(), ids.begin(), ids.end());
240 result.push_back(sep_token_id_);
241 return result;
242}
243
244std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
245 std::vector<int64_t> result;
246 result.reserve(ids1.size() + ids2.size() + 3);
247 result.push_back(cls_token_id_);
248 result.insert(result.end(), ids1.begin(), ids1.end());
249 result.push_back(sep_token_id_);
250 result.insert(result.end(), ids2.begin(), ids2.end());
251 result.push_back(sep_token_id_);
252 return result;
253}
254
255std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
256 return std::vector<int64_t>(ids.size() + 2, 0);
257}
258
259std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
260 std::vector<int64_t> result;
261 result.reserve(ids1.size() + ids2.size() + 3);
262 result.insert(result.end(), ids1.size() + 2, 0);
263 result.insert(result.end(), ids2.size() + 1, 1);
264 return result;
265}
266
267TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(TruncateStrategyType::LONGEST_FIRST) {
268 if (strategy_name == "longest_first") {
269 strategy_ = TruncateStrategyType::LONGEST_FIRST;
270 } else if (strategy_name == "only_first") {
271 strategy_ = TruncateStrategyType::ONLY_FIRST;
272 } else if (strategy_name == "only_second") {
273 strategy_ = TruncateStrategyType::ONLY_SECOND;
274 } else if (strategy_name == "longest_from_back") {
275 strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
276 }
277}
278
279KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
280 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
281 bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
282 bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
283 std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
284 std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
285 std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
286 std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
287 std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
288 bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
289 bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
290 std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
291 std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
292 int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
293
294 tokenizer_ = std::make_unique<BertTokenizer>(
295 vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
296 ustring(sep_token), ustring(pad_token), ustring(cls_token),
297 ustring(mask_token), tokenize_chinese_chars, strip_accents,
298 ustring(suffix_indicator), max_len, truncation_strategy_name);
299}
300
301void KernelBertTokenizer::Compute(OrtKernelContext* context) {
302 // Setup inputs
303 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
304 std::vector<std::string> input_data;
305 GetTensorMutableDataString(api_, ort_, context, input, input_data);
306
307 if (input_data.size() != 1 && input_data.size() != 2) {
308 ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
309 }
310 std::vector<int64_t> input_ids;
311 std::vector<int64_t> token_type_ids;
312
313 if (input_data.size() == 1) {
314 std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
315 std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
316 tokenizer_->Truncate(encoded);
317 input_ids = tokenizer_->AddSpecialToken(encoded);
318 token_type_ids = tokenizer_->GenerateTypeId(encoded);
319 } else {
320 std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
321 std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
322 std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
323 std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
324 input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
325 token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
326 }
327
328 std::vector<int64_t> attention_mask(input_ids.size(), 1);
329
330 std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
331
332 SetOutput(context, 0, output_dim, input_ids);
333 SetOutput(context, 1, output_dim, token_type_ids);
334 SetOutput(context, 2, output_dim, attention_mask);
335}
336
337void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
338 return CreateKernelImpl(api, info);
339}
340
341const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
342
343size_t CustomOpBertTokenizer::GetInputTypeCount() const {
344 return 1;
345}
346
347ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
348 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
349}
350
351size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
352 return 3;
353}
354
355ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
356 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
357}
358
359KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
360
361void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
362 // Setup inputs
363 const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
364 std::vector<std::string> input_data;
365 GetTensorMutableDataString(api_, ort_, context, input, input_data);
366
367 if (input_data.size() != 2) {
368 ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
369 }
370
371 std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
372 std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
373 std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
374 std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
375 std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
376 std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
377 std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
378
379 const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
380 const std::vector<int64_t> inner_dims{1LL};
381 for (int32_t i = 0; i < 3; ++i) {
382 OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
383 OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
384 ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
385 ort_.ReleaseTensorTypeAndShapeInfo(info);
386 }
387
388 SetOutput(context, 0, outer_dims, input_ids);
389 SetOutput(context, 1, outer_dims, attention_mask);
390 SetOutput(context, 2, outer_dims, token_type_ids);
391}
392
393void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
394 return CreateKernelImpl(api, info);
395}
396
397const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
398
399size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
400 return 1;
401}
402
403ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
404 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
405}
406
407size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
408 return 3;
409}
410
411ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
412 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
413}
414