microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer_decoder.cc

195lines · modecode

1#include "bert_tokenizer_decoder.hpp"
2
3BertTokenizerDecoder::BertTokenizerDecoder(std::string vocab, std::string unk_token, std::string sep_token, std::string pad_token,
4 std::string cls_token,std::string mask_token,std::string suffix_indicator) : raw_vocab_(vocab), unk_token_(unk_token), suffix_indicator_(suffix_indicator) {
5 auto tokens = SplitString(raw_vocab_, "\n", true);
6 vocab_.reserve(tokens.size());
7 for (int i = 0; i < tokens.size(); i++) {
8 auto& token = tokens[i];
9 if (token == unk_token) {
10 unk_token_id_ = i;
11 }
12 if (token == sep_token) {
13 sep_token_id_ = i;
14 }
15 if (token == pad_token) {
16 sep_token_id_ = i;
17 }
18 if (token == cls_token) {
19 cls_token_id_ = i;
20 }
21 if (token == mask_token) {
22 mask_token_id_ = i;
23 }
24
25 if (token.rfind(suffix_indicator_, 0) == 0) {
26 vocab_.emplace_back(token.substr(suffix_indicator.size(), token.size() - suffix_indicator.size()));
27 is_substr_.push_back(true);
28 } else {
29 vocab_.push_back(token);
30 is_substr_.push_back(false);
31 }
32 }
33}
34
35std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces) {
36 std::string result;
37 int64_t pre_token = -1;
38
39 for (auto id : ids) {
40 if (skip_special_tokens && (id == unk_token_id_ || id == sep_token_id_ || id == pad_token_id_ || id == cls_token_id_ || id == mask_token_id_)) {
41 continue;
42 }
43
44 // deal with unk ids
45 if (id >= vocab_.size() || id < 0) {
46 if (result.empty()) {
47 result.push_back(' ');
48 }
49 result.append(unk_token_);
50 continue;
51 }
52
53 // skip first substr
54 if (result.empty() && is_substr_[id]) {
55 continue;
56 }
57
58 // At following situations, we needn't add space
59 // we needn't add a space at the beginning of the output
60 // we needn't add a space when the token is a substr (such as ##ing)
61 // we needn't add a space at the left or right of punctuation (such as client-side shouldn't be client - side), when clean_up_tokenization_spaces is true
62 if (!(result.empty() || is_substr_[id] || (clean_up_tokenization_spaces && RemoveTokenizeSpace(pre_token, id)))) {
63 result.push_back(' ');
64 }
65
66 result.append(vocab_[id]);
67 pre_token = id;
68 }
69
70 return result;
71}
72
73bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id) {
74
75 if (pre_token_id < 0) {
76 return true;
77 }
78
79 auto pre_char = ustring(vocab_[pre_token_id]).back();
80 auto cur_char = ustring(vocab_[new_token_id])[0];
81
82 // normal punctuation
83 if (cur_char == U'!' || cur_char == U'.' || cur_char == U'?' || cur_char == U',' || cur_char == '~' || cur_char == ':') {
84 return true;
85 }
86
87 // only remove left side space
88 if (cur_char == U'}' || cur_char == U']' || cur_char == U'>' || cur_char == ')') {
89 return true;
90 }
91
92 // only remove right side space
93 if (pre_char == U'{' || pre_char == U'[' || pre_char == U'<' || pre_char == '(' || pre_char == '$') {
94 return true;
95 }
96
97 // remove both side space
98 if (pre_char == U'-' || pre_char == U'\'' || pre_char == U'"' || pre_char == U'/' || pre_char == U'@' || pre_char == U'\\' ||
99 cur_char == U'-' || cur_char == U'\'' || cur_char == U'"' || cur_char == U'/' || cur_char == U'@' || cur_char == U'\\') {
100 return true;
101 }
102
103 // remove both space beside unicode punctuation
104 if (pre_char > 128 && iswpunct(pre_char)) {
105 return true;
106 }
107
108 if (cur_char > 128 && iswpunct(cur_char)) {
109 return true;
110 }
111
112 return false;
113}
114
115KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
116 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
117 std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
118 std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
119 std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
120 std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
121 std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
122 std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
123
124 use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
125 skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
126 clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
127
128 decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, unk_token, sep_token, pad_token,
129 cls_token, mask_token, suffix_indicator);
130}
131
132void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
133 const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
134 const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
135 OrtTensorDimensions ids_dim(ort_, ids);
136
137 if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
138 ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
139 }
140
141 // const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
142 const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
143 OrtTensorDimensions positions_dim(ort_, positions);
144 if (use_indices_ &&
145 (!(positions_dim.empty() ||
146 (positions_dim.Size() == 0) ||
147 (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
148 ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
149 }
150
151 const int64_t* p_positions =
152 positions_dim.empty() || positions_dim.Size() == 0? nullptr : ort_.GetTensorData<int64_t>(positions);
153
154 std::vector<std::string> result;
155 std::vector<int64_t> output_dim(1);
156 if (!use_indices_) {
157 result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
158 output_dim[0] = 1;
159 } else {
160 if (p_positions != nullptr) {
161 for (int i = 0; i < positions_dim[0]; i++) {
162 int64_t start = p_positions[2 * i];
163 int64_t end = p_positions[2 * i + 1];
164
165 result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids + start, p_ids + end), skip_special_tokens_, clean_up_tokenization_spaces_));
166 }
167 output_dim[0] = positions_dim[0];
168 }
169 }
170 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
171
172 FillTensorDataString(api_, ort_, context, result, output);
173}
174
175void* CustomOpBertTokenizerDecoder::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
176 return new KernelBertTokenizerDecoder(api, info);
177};
178
179const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
180
181size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {
182 return 2;
183};
184
185ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetInputType(size_t /*index*/) const {
186 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
187};
188
189size_t CustomOpBertTokenizerDecoder::GetOutputTypeCount() const {
190 return 1;
191};
192
193ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetOutputType(size_t /*index*/) const {
194 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
195};
196