microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
e0d48e255f28e5465f63e7fc141df1e1d533cc40

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer_decoder.cc

204lines · modecode

1#include "bert_tokenizer_decoder.hpp"
2
3BertTokenizerDecoder::BertTokenizerDecoder(
4 std::string vocab,
5 std::string unk_token,
6 std::string sep_token,
7 std::string pad_token,
8 std::string cls_token,
9 std::string mask_token,
10 std::string suffix_indicator
11) :
12 unk_token_(unk_token),
13 suffix_indicator_(suffix_indicator),
14 raw_vocab_(vocab)
15{
16 auto tokens = SplitString(raw_vocab_, "\n", true);
17 vocab_.reserve(tokens.size());
18 for (size_t i = 0; i < tokens.size(); i++) {
19 auto& token = tokens[i];
20 if (token == unk_token) {
21 unk_token_id_ = static_cast<int32_t>(i);
22 }
23 if (token == sep_token) {
24 sep_token_id_ = static_cast<int32_t>(i);
25 }
26 if (token == pad_token) {
27 sep_token_id_ = static_cast<int32_t>(i);
28 }
29 if (token == cls_token) {
30 cls_token_id_ = static_cast<int32_t>(i);
31 }
32 if (token == mask_token) {
33 mask_token_id_ = static_cast<int32_t>(i);
34 }
35
36 if (token.rfind(suffix_indicator_, 0) == 0) {
37 vocab_.emplace_back(token.substr(suffix_indicator.size(), token.size() - suffix_indicator.size()));
38 is_substr_.push_back(true);
39 } else {
40 vocab_.push_back(token);
41 is_substr_.push_back(false);
42 }
43 }
44}
45
46std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces) {
47 std::string result;
48 int64_t pre_token = -1;
49
50 for (auto id : ids) {
51 if (skip_special_tokens && (id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_)) {
52 continue;
53 }
54
55 // deal with unk ids
56 if (id < 0 || static_cast<size_t>(id) >= vocab_.size()) {
57 if (!result.empty()) {
58 result.push_back(' ');
59 }
60 result.append(unk_token_);
61 continue;
62 }
63
64 // skip first substr
65 if (result.empty() && is_substr_[static_cast<size_t>(id)]) {
66 continue;
67 }
68
69 // At following situations, we needn't add space
70 // we needn't add a space at the beginning of the output
71 // we needn't add a space when the token is a substr (such as ##ing)
72 // we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
73 if (!(result.empty() || is_substr_[static_cast<size_t>(id)] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
74 result.push_back(' ');
75 }
76
77 result.append(vocab_[static_cast<size_t>(id)]);
78 pre_token = id;
79 }
80
81 return result;
82}
83
84bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id) {
85
86 if (pre_token_id < 0) {
87 return true;
88 }
89
90 auto pre_char = ustring(vocab_[static_cast<size_t>(pre_token_id)]).back();
91 auto cur_char = ustring(vocab_[static_cast<size_t>(new_token_id)])[0];
92
93 // normal punctuation
94 if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
95 return true;
96 }
97
98 // only remove left side space
99 if (cur_char == U'}' || cur_char == U']' || cur_char == U'>' || cur_char == ')') {
100 return true;
101 }
102
103 // only remove right side space
104 if (pre_char == U'{' || pre_char == U'[' || pre_char == U'<' || pre_char == '(' || pre_char == '$') {
105 return true;
106 }
107
108 // remove both side space
109 if (pre_char == U'-' || pre_char == U'\'' || pre_char == U'"' || pre_char == U'/' || pre_char == U'@' || pre_char == U'\\' ||
110 cur_char == U'-' || cur_char == U'\'' || cur_char == U'"' || cur_char == U'/' || cur_char == U'@' || cur_char == U'\\') {
111 return true;
112 }
113
114 // remove both space beside unicode punctuation
115 if (pre_char > 128 && IsPunct(pre_char)) {
116 return true;
117 }
118
119 if (cur_char > 128 && IsPunct(cur_char)) {
120 return true;
121 }
122
123 return false;
124}
125
126KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
127 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
128 std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
129 std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
130 std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
131 std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
132 std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
133 std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
134
135 use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
136 skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
137 clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
138
139 decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, unk_token, sep_token, pad_token,
140 cls_token, mask_token, suffix_indicator);
141}
142
143void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
144 const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
145 const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
146 OrtTensorDimensions ids_dim(ort_, ids);
147
148 if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
149 ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
150 }
151
152 // const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
153 const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
154 OrtTensorDimensions positions_dim(ort_, positions);
155 if (use_indices_ &&
156 (!((positions_dim.Size() == 0) ||
157 (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
158 ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
159 }
160
161 const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
162
163 std::vector<std::string> result;
164 std::vector<int64_t> output_dim(1);
165 if (!use_indices_) {
166 result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
167 output_dim[0] = 1;
168 } else {
169 if (p_positions != nullptr) {
170 for (int i = 0; i < positions_dim[0]; i++) {
171 int64_t start = p_positions[2 * i];
172 int64_t end = p_positions[2 * i + 1];
173
174 result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
175 }
176 output_dim[0] = positions_dim[0];
177 }
178 }
179 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
180
181 FillTensorDataString(api_, ort_, context, result, output);
182}
183
184void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
185 return CreateKernelImpl(api, info);
186};
187
188const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
189
190size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {
191 return 2;
192};
193
194ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetInputType(size_t /*index*/) const {
195 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
196};
197
198size_t CustomOpBertTokenizerDecoder::GetOutputTypeCount() const {
199 return 1;
200};
201
202ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetOutputType(size_t /*index*/) const {
203 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
204};
205