#include "bert_tokenizer.hpp"
#include <utility>
BertTokenizerVocab::BertTokenizerVocab(std::string vocab) : raw_vocab_(vocab) {
auto tokens = SplitString(raw_vocab_, "\n", true);
for (int i = 0; i < tokens.size(); i++) {
(vocab_)[tokens[i]] = i;
}
}
bool BertTokenizerVocab::FindToken(const ustring& token) {
auto utf8_token = std::string(token);
return vocab_.find(utf8_token) != vocab_.end();
}
bool BertTokenizerVocab::FindTokenId(const ustring& token, int32_t& token_id) {
auto utf8_token = std::string(token);
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
return false;
}
token_id = it->second;
return true;
}
int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
auto utf8_token = std::string(token);
auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) {
ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
}
return it->second;
}
WordpieceTokenizer::WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
ustring suffix_indicator, int max_input_chars_per_word): vocab_(std::move(vocab)), unk_token_(std::move(unk_token)),
suffix_indicator_(std::move(suffix_indicator)), max_input_chars_per_word_(max_input_chars_per_word) {
unk_token_id_ = vocab_->FindTokenId(unk_token_);
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const ustring& text) {
std::vector<ustring> result;
ustring token;
for (auto c : text) {
if (c == U' ' && !token.empty()) {
GreedySearch(token, result);
token.clear();
continue;
}
token.push_back(c);
}
if (!token.empty()) {
GreedySearch(token, result);
}
return result;
}
std::vector<ustring> WordpieceTokenizer::Tokenize(const std::vector<ustring>& tokens) {
std::vector<ustring> result;
for (const auto& token : tokens) {
GreedySearch(token, result);
}
return result;
}
std::vector<int64_t> WordpieceTokenizer::Encode(const std::vector<ustring>& tokens) {
std::vector<int64_t> ids;
for (const auto& token : tokens) {
int32_t token_id = -1;
if (!vocab_->FindTokenId(token, token_id)) {
ids.push_back(unk_token_id_);
continue;
}
ids.push_back(token_id);
}
return ids;
}
void WordpieceTokenizer::GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result) {
if (token.size() > max_input_chars_per_word_) {
tokenized_result.push_back(unk_token_);
return;
}
int start = 0;
int end = -1;
ustring substr;
for (; start < token.size();) {
end = token.size();
bool is_found = false;
// try to found the longest matched sub-token in vocab
for (; start < end;) {
substr = static_cast<const ustring>(token.substr(start, end - start));
if (start > 0) {
substr = static_cast<const ustring>(suffix_indicator_ + substr);
}
if (vocab_->FindToken(substr)) {
is_found = true;
break;
}
end -= 1;
}
// token not found in vocab
if (!is_found) {
tokenized_result.push_back(unk_token_);
break;
}
tokenized_result.push_back(substr);
start = end;
}
}
void TruncateStrategy::Truncate(std::vector<int64_t>& ids, int64_t max_len) {
if (max_len < 0 || max_len >= ids.size()) {
return;
}
ids.resize(max_len);
}
void TruncateStrategy::Truncate(std::vector<int64_t>& input1, std::vector<int64_t>& input2, int64_t max_len) {
if (max_len < 0 || (input1.size() + input2.size() <= max_len)) {
return;
}
auto input1_keep_len = input1.size();
auto input2_keep_len = input2.size();
auto half_max_len = max_len / 2;
switch (strategy_) {
case TruncateStrategyType::LONGEST_FIRST:
case TruncateStrategyType::LONGEST_FROM_BACK:
if ((input1_keep_len > half_max_len) && (input2_keep_len > half_max_len)) {
input1_keep_len = max_len - half_max_len;
input2_keep_len = half_max_len;
} else if (input2_keep_len > input1_keep_len) {
input2_keep_len = max_len - input1_keep_len;
} else {
input1_keep_len = max_len - input2_keep_len;
}
if (strategy_ == TruncateStrategyType::LONGEST_FIRST) {
input1.resize(input1_keep_len);
input2.resize(input2_keep_len);
} else {
input1.erase(input1.begin(), input1.end() - input1_keep_len);
input2.erase(input2.begin(), input2.end() - input2_keep_len);
}
return;
case TruncateStrategyType::ONLY_FIRST:
return;
case TruncateStrategyType::ONLY_SECOND:
return;
default:
return;
}
}
BertTokenizer::BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize, ustring unk_token, ustring sep_token,
ustring pad_token, ustring cls_token, ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
ustring suffix_indicator) : do_basic_tokenize_(do_basic_tokenize) {
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
if (do_basic_tokenize) {
basic_tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
}
wordpiece_tokenizer_ = std::make_shared<WordpieceTokenizer>(vocab_, unk_token, suffix_indicator);
unk_token_id_ = vocab_->FindTokenId(unk_token);
sep_token_id_ = vocab_->FindTokenId(sep_token);
pad_token_id_ = vocab_->FindTokenId(pad_token);
cls_token_id_ = vocab_->FindTokenId(cls_token);
mask_token_id_ = vocab_->FindTokenId(mask_token);
}
std::vector<ustring> BertTokenizer::Tokenize(const ustring& text) {
if (do_basic_tokenize_) {
return wordpiece_tokenizer_->Tokenize(basic_tokenizer_->Tokenize(text));
}
return wordpiece_tokenizer_->Tokenize(text);
}
std::vector<int64_t> BertTokenizer::Encode(const std::vector<ustring>& tokens) {
return wordpiece_tokenizer_->Encode(tokens);
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids) {
std::vector<int64_t> result;
result.reserve(ids.size() + 2);
result.push_back(cls_token_id_);
result.insert(result.end(), ids.begin(), ids.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.push_back(cls_token_id_);
result.insert(result.end(), ids1.begin(), ids1.end());
result.push_back(sep_token_id_);
result.insert(result.end(), ids2.begin(), ids2.end());
result.push_back(sep_token_id_);
return result;
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids) {
return std::vector<int64_t>(ids.size() + 2, 0);
}
std::vector<int64_t> BertTokenizer::GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2) {
std::vector<int64_t> result;
result.reserve(ids1.size() + ids2.size() + 3);
result.insert(result.end(), ids1.size() + 2, 0);
result.insert(result.end(), ids2.size() + 1, 1);
return result;
}
TruncateStrategy::TruncateStrategy(std::string strategy_name) {
if (strategy_name == "longest_first") {
strategy_ = TruncateStrategyType::LONGEST_FIRST;
} else if (strategy_name == "only_first") {
strategy_ = TruncateStrategyType::ONLY_FIRST;
} else if (strategy_name == "only_second") {
strategy_ = TruncateStrategyType::ONLY_SECOND;
} else if (strategy_name == "longest_from_back") {
strategy_ = TruncateStrategyType::LONGEST_FROM_BACK;
}
}
KernelBertTokenizer::KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
max_length_ = TryToGetAttributeWithDefault("max_length", int64_t(-1));
tokenizer_ = std::make_shared<BertTokenizer>(vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
ustring(sep_token), ustring(pad_token), ustring(cls_token),
ustring(mask_token), tokenize_chinese_chars, strip_accents, ustring(suffix_indicator));
truncate_ = std::make_shared<TruncateStrategy>(truncation_strategy_name);
}
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
if (input_data.size() != 1 && input_data.size() != 2) {
ORT_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
}
std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids;
if (input_data.size() == 1 || input_data[1].empty()) {
std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
input_ids = tokenizer_->AddSpecialToken(encode);
token_type_ids = tokenizer_->GenerateTypeId(encode);
} else if (input_data[0].empty()) {
std::vector<int64_t> encode = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
truncate_->Truncate(encode, (max_length_ > 0 && max_length_ <= 2) ? 0 : max_length_ - 2);
input_ids = tokenizer_->AddSpecialToken(encode);
token_type_ids = tokenizer_->GenerateTypeId(encode);
} else {
std::vector<int64_t> encode1 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[0])));
std::vector<int64_t> encode2 = tokenizer_->Encode(tokenizer_->Tokenize(ustring(input_data[1])));
truncate_->Truncate(encode1, encode2, (max_length_ > 0 && max_length_ <= 3) ? 0 : max_length_ - 3);
input_ids = tokenizer_->AddSpecialToken(encode1, encode2);
token_type_ids = tokenizer_->GenerateTypeId(encode1, encode2);
}
std::vector<int64_t> attention_mask(input_ids.size(), 1);
std::vector<int64_t> output_dim({static_cast<int64_t>(input_ids.size())});
SetOutput(context, 0, output_dim, input_ids);
SetOutput(context, 1, output_dim, token_type_ids);
SetOutput(context, 2, output_dim, attention_mask);
}
void* CustomOpBertTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
return new KernelBertTokenizer(api, info);
};
const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; };
size_t CustomOpBertTokenizer::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};microsoft/onnxruntime-extensions
Publicmirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable
operators/tokenizer/bert_tokenizer.cc
326lines · modepreview