microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.cc

400 lines · mode: code

1#include "bert_tokenizer.hpp"
2
3#include <utility>
4
5BertTokenizerVocab::BertTokenizerVocab(std::string_view vocab) : raw_vocab_(vocab) {
6 auto tokens = SplitString(raw_vocab_, "\r\n", true);
7
8 for (size_t i = 0; i < tokens.size(); i++) {
9 (vocab_)[tokens[i]] = static_cast<int32_t>(i);
10 }
11}
12
13bool BertTokenizerVocab::FindToken(const ustring& token) {
14 auto utf8_token = std::string(token);
15
16 return vocab_.find(utf8_token) != vocab_.end();
17}
18
19bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
20 auto utf8_token = std::string(token);
21
22 auto it = vocab_.find(utf8_token);
23 if (it == vocab_.end()) {
24 return false;
25 }
26
27 token_id = it->second;
28 return true;
29}
30
31int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
32 auto utf8_token = std::string(token);
33
34 auto it = vocab_.find(utf8_token);
35 if (it == vocab_.end()) {
36 ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
37 }
38
39 return it->second;
40}
41
42WordpieceTokenizer::WordpieceTokenizer(
43 std::shared_ptr<BertTokenizerVocab> vocab,
44 ustring unk_token,
45 ustring suffix_indicator,
46 int max_input_chars_per_word) : max_input_chars_per_word_(max_input_chars_per_word),
47 suffix_indicator_(std::move(suffix_indicator)),
48 unk_token_(std::move(unk_token)),
49 vocab_(std::move(vocab)) {
50 unk_token_id_ = vocab_->FindTokenId(unk_token_);
51}
52
53std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
54 std::vector<ustring> result;
55 ustring token;
56 for (auto c : text) {
57 if (c == U' ' && !token.empty()) {
58 GreedySearch(token, result);
59 token.clear();
60 continue;
61 }
62
63 token.push_back(c);
64 }
65
66 if (!token.empty()) {
67 GreedySearch(token, result);
68 }
69
70 return result;
71}
72
73std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
74 std::vector<ustring> result;
75 for (const auto& token : tokens) {
76 GreedySearch(token, result);
77 }
78
79 return result;
80}
81
82std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
83 std::vector<int64_t> ids;
84 for (const auto& token : tokens) {
85 int32_t token_id = -1;
86 if (!vocab_->FindTokenId(token, token_id)) {
87 ids.push_back(unk_token_id_);
88 continue;
89 }
90
91 ids.push_back(token_id);
92 }
93 return ids;
94}
95
96void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
97 if (static_cast<int64_t>(token.size()) > max_input_chars_per_word_) {
98 tokenized_result.push_back(unk_token_);
99 return;
100 }
101
102 size_t start = 0;
103 size_t end = 0;
104 ustring substr;
105 for (; start < token.size();) {
106 end = token.size();
107 bool is_found = false;
108 // try to found the longest matched sub-token in vocab
109 for (; start < end;) {
110 substr = static_cast<const ustring>(token.substr(start, end - start));
111 if (start > 0) {
112 substr = static_cast<const ustring>(suffix_indicator_ + substr);
113 }
114 if (vocab_->FindToken(substr)) {
115 is_found = true;
116 break;
117 }
118 end -= 1;
119 }
120 // token not found in vocab
121 if (!is_found) {
122 tokenized_result.push_back(unk_token_);
123 break;
124 }
125
126 tokenized_result.push_back(substr);
127 start = end;
128 }
129}
130
131void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int32_t max_len) {
132 if ((max_len > 0) && (static_cast<size_t>(max_len) < ids.size())) {
133 ids.resize(max_len);
134 }
135}
136
137void TruncateStrategy::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len) {
138 if (max_len < 0 || (ids1.size() + ids2.size() <= static_cast<size_t>(max_len))) {
139 return;
140 }
141
142 auto ids1_keep_len = ids1.size();
143 auto ids2_keep_len = ids2.size();
144 auto half_max_len = max_len / 2;
145
146 switch (strategy_) {
147 case TruncateStrategyType::LONGEST_FIRST:
148 case TruncateStrategyType::LONGEST_FROM_BACK:
149
150 if ((ids1_keep_len > static_cast<size_t>(half_max_len)) && (ids2_keep_len > static_cast<size_t>(half_max_len))) {
151 ids1_keep_len = static_cast<size_t>(max_len) - half_max_len;
152 ids2_keep_len = half_max_len;
153 } else if (ids2_keep_len > ids1_keep_len) {
154 ids2_keep_len = static_cast<size_t>(max_len) - ids1_keep_len;
155 } else {
156 ids1_keep_len = static_cast<size_t>(max_len) - ids2_keep_len;
157 }
158
159 if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
160 ids1.resize(ids1_keep_len);
161 ids2.resize(ids2_keep_len);
162 } else {
163 ids1.erase(ids1.begin(), ids1.end() - ids1_keep_len);
164 ids2.erase(ids2.begin(), ids2.end() - ids2_keep_len);
165 }
166
167 return;
168 case TruncateStrategyType::ONLY_FIRST:
169 return;
170 case TruncateStrategyType::ONLY_SECOND:
171 return;
172 default:
173 return;
174 }
175}
176
177BertTokenizer::BertTokenizer(
178 const std::string& vocab,
179 bool do_lower_case,
180 bool do_basic_tokenize,
181 ustring unk_token,
182 ustring sep_token,
183 ustring pad_token,
184 ustring cls_token,
185 ustring mask_token,
186 bool tokenize_chinese_chars,
187 bool strip_accents,
188 ustring suffix_indicator,
189 int32_t max_len,
190 const std::string& truncation_strategy) : max_length_(max_len),
191 do_basic_tokenize_(do_basic_tokenize),
192 truncate_(std::make_unique<TruncateStrategy>(truncation_strategy)) {
193 vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
194
195 if (do_basic_tokenize) {
196 basic_tokenizer_ = std::make_unique<BasicTokenizer>(
197 do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
198 }
199 wordpiece_tokenizer_ = std::make_unique<WordpieceTokenizer>(
200 vocab_, unk_token, suffix_indicator);
201
202 unk_token_id_ = vocab_->FindTokenId(unk_token);
203 sep_token_id_ = vocab_->FindTokenId(sep_token);
204 pad_token_id_ = vocab_->FindTokenId(pad_token);
205 cls_token_id_ = vocab_->FindTokenId(cls_token);
206 mask_token_id_ = vocab_->FindTokenId(mask_token);
207}
208
209std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
210 if (do_basic_tokenize_) {
211 return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
212 }
213 return wordpiece_tokenizer_->Tokenize(text);
214}
215
216std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
217 return wordpiece_tokenizer_->Encode(tokens);
218}
219
220void BertTokenizer::Truncate(std::vector<int64_t>& ids) {
221 truncate_->Truncate(ids, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
222}
223
224void BertTokenizer::Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2) {
225 truncate_->Truncate(ids1, ids2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
226}
227
228std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
229 std::vector<int64_t> result;
230 result.reserve(ids.size() + 2);
231 result.push_back(cls_token_id_);
232 result.insert(result.end(), ids.begin(), ids.end());
233 result.push_back(sep_token_id_);
234 return result;
235}
236
237std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
238 std::vector<int64_t> result;
239 result.reserve(ids1.size() + ids2.size() + 3);
240 result.push_back(cls_token_id_);
241 result.insert(result.end(), ids1.begin(), ids1.end());
242 result.push_back(sep_token_id_);
243 result.insert(result.end(), ids2.begin(), ids2.end());
244 result.push_back(sep_token_id_);
245 return result;
246}
247
248std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
249 return std::vector<int64_t>(ids.size() + 2, 0);
250}
251
252std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
253 std::vector<int64_t> result;
254 result.reserve(ids1.size() + ids2.size() + 3);
255 result.insert(result.end(), ids1.size() + 2, 0);
256 result.insert(result.end(), ids2.size() + 1, 1);
257 return result;
258}
259
260TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(TruncateStrategyType::LONGEST_FIRST) {
261 if (strategy_name == "longest_first") {
262 strategy_ = TruncateStrategyType::LONGEST_FIRST;
263 } else if (strategy_name == "only_first") {
264 strategy_ = TruncateStrategyType::ONLY_FIRST;
265 } else if (strategy_name == "only_second") {
266 strategy_ = TruncateStrategyType::ONLY_SECOND;
267 } else if (strategy_name == "longest_from_back") {
268 strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
269 }
270}
271
272KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
273 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
274 bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
275 bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
276 std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
277 std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
278 std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
279 std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
280 std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
281 bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
282 bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
283 std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
284 std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name",
285 std::string("longest_first"));
286 int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
287
288 tokenizer_ = std::make_unique<BertTokenizer>(
289 vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
290 ustring(sep_token), ustring(pad_token), ustring(cls_token),
291 ustring(mask_token), tokenize_chinese_chars, strip_accents,
292 ustring(suffix_indicator), max_len, truncation_strategy_name);
293}
294
295void KernelBertTokenizer::Compute(OrtKernelContext* context) {
296 // Setup inputs
297 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
298 std::vector<std::string> input_data;
299 GetTensorMutableDataString(api_, ort_, context, input, input_data);
300
301 if (input_data.size() != 1 && input_data.size() != 2) {
302 ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
303 }
304 std::vector<int64_t> input_ids;
305 std::vector<int64_t> token_type_ids;
306
307 if (input_data.size() == 1) {
308 std::vector<ustring> tokens = tokenizer_->Tokenize(ustring(input_data[0]));
309 std::vector<int64_t> encoded = tokenizer_->Encode(tokens);
310 tokenizer_->Truncate(encoded);
311 input_ids = tokenizer_->AddSpecialToken(encoded);
312 token_type_ids = tokenizer_->GenerateTypeId(encoded);
313 } else {
314 std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
315 std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
316 std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
317 std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
318 input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
319 token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
320 }
321
322 std::vector<int64_t> attention_mask(input_ids.size(), 1);
323
324 std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
325
326 SetOutput(context, 0, output_dim, input_ids);
327 SetOutput(context, 1, output_dim, token_type_ids);
328 SetOutput(context, 2, output_dim, attention_mask);
329}
330
331const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
332
333size_t CustomOpBertTokenizer::GetInputTypeCount() const {
334 return 1;
335}
336
337ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
338 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
339}
340
341size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
342 return 3;
343}
344
345ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
346 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
347}
348
349KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
350 : KernelBertTokenizer(api, info) {}
351
352void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
353 // Setup inputs
354 const OrtValue* const input = ort_.KernelContext_GetInput(context, 0);
355 std::vector<std::string> input_data;
356 GetTensorMutableDataString(api_, ort_, context, input, input_data);
357
358 if (input_data.size() != 2) {
359 ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
360 }
361
362 std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
363 std::vector<ustring> tokens2 = tokenizer_->Tokenize(ustring(input_data[1]));
364 std::vector<int64_t> encoded1 = tokenizer_->Encode(tokens1);
365 std::vector<int64_t> encoded2 = tokenizer_->Encode(tokens2);
366 std::vector<int64_t> input_ids = tokenizer_->AddSpecialToken(encoded1, encoded2);
367 std::vector<int64_t> token_type_ids = tokenizer_->GenerateTypeId(encoded1, encoded2);
368 std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
369
370 const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
371 const std::vector<int64_t> inner_dims{1LL};
372 for (int32_t i = 0; i < 3; ++i) {
373 OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
374 OrtTensorTypeAndShapeInfo* const info = ort_.GetTensorTypeAndShape(value);
375 ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
376 ort_.ReleaseTensorTypeAndShapeInfo(info);
377 }
378
379 SetOutput(context, 0, outer_dims, input_ids);
380 SetOutput(context, 1, outer_dims, attention_mask);
381 SetOutput(context, 2, outer_dims, token_type_ids);
382}
383
384const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
385
386size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
387 return 1;
388}
389
390ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
391 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
392}
393
394size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
395 return 3;
396}
397
398ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
399 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
400}
401