microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.cc

326 lines · mode: code

1#include "bert_tokenizer.hpp"
2
3#include <utility>
4
5BertTokenizerVocab::BertTokenizerVocab(std::string vocab) : raw_vocab_(vocab) {
6 auto tokens = SplitString(raw_vocab_, "\n", true);
7
8 for (int i = 0; i < tokens.size(); i++) {
9 (vocab_)[tokens[i]] = i;
10 }
11}
12
13bool BertTokenizerVocab::FindToken(const ustring& token) {
14 auto utf8_token = std::string(token);
15
16 return vocab_.find(utf8_token) != vocab_.end();
17}
18
19bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
20 auto utf8_token = std::string(token);
21
22 auto it = vocab_.find(utf8_token);
23 if (it == vocab_.end()) {
24 return false;
25 }
26
27 token_id = it->second;
28 return true;
29}
30
31int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
32 auto utf8_token = std::string(token);
33
34 auto it = vocab_.find(utf8_token);
35 if (it == vocab_.end()) {
36 ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
37 }
38
39 return it->second;
40}
41
42WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
43 ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(std::move(unk_token)),
44 suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
45 unk_token_id_ = vocab_->FindTokenId(unk_token_);
46}
47
48std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
49 std::vector<ustring> result;
50 ustring token;
51 for (auto c : text) {
52 if (c == U' ' && !token.empty()) {
53 GreedySearch(token, result);
54 token.clear();
55 continue;
56 }
57
58 token.push_back(c);
59 }
60
61 if (!token.empty()) {
62 GreedySearch(token, result);
63 }
64
65 return result;
66}
67
68std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
69 std::vector<ustring> result;
70 for (const auto& token : tokens) {
71 GreedySearch(token, result);
72 }
73
74 return result;
75}
76
77std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
78 std::vector<int64_t> ids;
79 for (const auto& token : tokens) {
80 int32_t token_id = -1;
81 if (!vocab_->FindTokenId(token, token_id)) {
82 ids.push_back(unk_token_id_);
83 continue;
84 }
85
86 ids.push_back(token_id);
87 }
88 return ids;
89}
90
91void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
92 if (token.size() > max_input_chars_per_word_) {
93 tokenized_result.push_back(unk_token_);
94 return;
95 }
96
97 int start = 0;
98 int end = -1;
99 ustring substr;
100 for (; start < token.size();) {
101 end = token.size();
102 bool is_found = false;
103 // try to found the longest matched sub-token in vocab
104 for (; start < end;) {
105 substr = static_cast<const ustring>(token.substr(start, end - start));
106 if (start > 0) {
107 substr = static_cast<const ustring>(suffix_indicator_ + substr);
108 }
109 if (vocab_->FindToken(substr)) {
110 is_found = true;
111 break;
112 }
113 end -= 1;
114 }
115 // token not found in vocab
116 if (!is_found) {
117 tokenized_result.push_back(unk_token_);
118 break;
119 }
120
121 tokenized_result.push_back(substr);
122 start = end;
123 }
124}
125
126void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
127 if (max_len < 0 || max_len >= ids.size()) {
128 return;
129 }
130
131 ids.resize(max_len);
132}
133
134void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
135 if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
136 return;
137 }
138
139 auto input1_keep_len = input1.size();
140 auto input2_keep_len = input2.size();
141 auto half_max_len = max_len / 2;
142
143 switch (strategy_) {
144 case TruncateStrategyType::LONGEST_FIRST:
145 case TruncateStrategyType::LONGEST_FROM_BACK:
146
147 if ((input1_keep_len > half_max_len) && (input2_keep_len > half_max_len)) {
148 input1_keep_len = max_len - half_max_len;
149 input2_keep_len = half_max_len;
150 } else if (input2_keep_len > input1_keep_len) {
151 input2_keep_len = max_len - input1_keep_len;
152 } else {
153 input1_keep_len = max_len - input2_keep_len;
154 }
155
156 if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
157 input1.resize(input1_keep_len);
158 input2.resize(input2_keep_len);
159 } else {
160 input1.erase(input1.begin(), input1.end() - input1_keep_len);
161 input2.erase(input2.begin(), input2.end() - input2_keep_len);
162 }
163
164 return;
165 case TruncateStrategyType::ONLY_FIRST:
166 return;
167 case TruncateStrategyType::ONLY_SECOND:
168 return;
169 default:
170 return;
171 }
172}
173
174BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
175 ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
176 ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
177 vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
178
179 if (do_basic_tokenize) {
180 basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
181 }
182 wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
183
184 unk_token_id_ = vocab_->FindTokenId(unk_token);
185 sep_token_id_ = vocab_->FindTokenId(sep_token);
186 pad_token_id_ = vocab_->FindTokenId(pad_token);
187 cls_token_id_ = vocab_->FindTokenId(cls_token);
188 mask_token_id_ = vocab_->FindTokenId(mask_token);
189}
190std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
191 if (do_basic_tokenize_) {
192 return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
193 }
194 return wordpiece_tokenizer_->Tokenize(text);
195}
196
197std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
198 return wordpiece_tokenizer_->Encode(tokens);
199}
200
201std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
202 std::vector<int64_t> result;
203 result.reserve(ids.size() + 2);
204 result.push_back(cls_token_id_);
205 result.insert(result.end(), ids.begin(), ids.end());
206 result.push_back(sep_token_id_);
207 return result;
208}
209
210std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
211 std::vector<int64_t> result;
212 result.reserve(ids1.size() + ids2.size() + 3);
213 result.push_back(cls_token_id_);
214 result.insert(result.end(), ids1.begin(), ids1.end());
215 result.push_back(sep_token_id_);
216 result.insert(result.end(), ids2.begin(), ids2.end());
217 result.push_back(sep_token_id_);
218 return result;
219}
220
221std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
222 return std::vector<int64_t>(ids.size() + 2, 0);
223}
224
225std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
226 std::vector<int64_t> result;
227 result.reserve(ids1.size() + ids2.size() + 3);
228 result.insert(result.end(), ids1.size() + 2, 0);
229 result.insert(result.end(), ids2.size() + 1, 1);
230 return result;
231}
232
233TruncateStrategy::TruncateStrategy(std::string strategy_name) {
234 if (strategy_name == "longest_first") {
235 strategy_ = TruncateStrategyType::LONGEST_FIRST;
236 } else if (strategy_name == "only_first") {
237 strategy_ = TruncateStrategyType::ONLY_FIRST;
238 } else if (strategy_name == "only_second") {
239 strategy_ = TruncateStrategyType::ONLY_SECOND;
240 } else if (strategy_name == "longest_from_back") {
241 strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
242 }
243}
244
245KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
246 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
247 bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
248 bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
249 std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
250 std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
251 std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
252 std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
253 std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
254 bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
255 bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
256 std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
257 std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
258 max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
259
260 tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
261 ustring(sep_token), ustring(pad_token), ustring(cls_token),
262 ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
263
264 truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
265}
266
267void KernelBertTokenizer::Compute(OrtKernelContext* context) {
268 // Setup inputs
269 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
270 std::vector<std::string> input_data;
271 GetTensorMutableDataString(api_, ort_, context, input, input_data);
272
273 if (input_data.size() != 1 && input_data.size() != 2) {
274 ORT_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
275 }
276 std::vector<int64_t> input_ids;
277 std::vector<int64_t> token_type_ids;
278
279 if (input_data.size() == 1 || input_data[1].empty()) {
280 std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
281 truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
282 input_ids = tokenizer_->AddSpecialToken(encode);
283 token_type_ids = tokenizer_->GenerateTypeId(encode);
284 } else if (input_data[0].empty()) {
285 std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
286 truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
287 input_ids = tokenizer_->AddSpecialToken(encode);
288 token_type_ids = tokenizer_->GenerateTypeId(encode);
289 } else {
290 std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
291 std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
292 truncate_->Truncate(encode1, encode2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
293 input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
294 token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
295 }
296
297 std::vector<int64_t> attention_mask(input_ids.size(), 1);
298
299 std::vector<int64_t> output_dim({static_cast<int64_t>(input_ids.size())});
300
301 SetOutput(context, 0, output_dim, input_ids);
302 SetOutput(context, 1, output_dim, token_type_ids);
303 SetOutput(context, 2, output_dim, attention_mask);
304}
305
306void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
307 return new KernelBertTokenizer(api, info);
308};
309
310const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; };
311
312size_t CustomOpBertTokenizer::GetInputTypeCount() const {
313 return 1;
314};
315
316ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /*index*/) const {
317 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
318};
319
320size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
321 return 3;
322};
323
324ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
325 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
326};
327