microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f74770feed077546874ed7e66d1aba9e2509fea9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/wordpiece_tokenizer.cc

203 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#include "wordpiece_tokenizer.hpp"
5#include "nlohmann/json.hpp"
6
7KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
8 // https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
9 // https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
10 std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
11 std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(info, "suffix_indicator");
12 std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unknown_token");
13 max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word") ? ort_.KernelInfoGetAttribute<int64_t>(info, "max_input_chars_per_word") : 200;
14 suffix_indicator_ = ustring(suffix_indicator);
15 unk_token_ = ustring(unk);
16
17 std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> cvt;
18 std::unordered_map<std::string, int32_t> vocab_map;
19 auto parsed = nlohmann::json::parse(vocab_as_string);
20 parsed.get_to(vocab_map);
21
22 for (auto it = vocab_map.begin(); it != vocab_map.end(); ++it) {
23 vocab_[ustring(it->first)] = it->second;
24 }
25}
26
// Splits `text` into words separated by the space character U+0020.
//
// Leading, trailing and repeated spaces never produce empty words. Only U' ' is a separator;
// any other whitespace stays inside the word. `words` is cleared before being filled.
// `suffix_indicator` is not used; it is kept so the signature stays unchanged for callers.
void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
                                    const std::u32string& text,
                                    std::vector<std::u32string>& words) {
  (void)suffix_indicator;
  words.clear();
  std::size_t word_begin = 0;
  // Iterate one past the end so the final word is flushed by the same code path.
  for (std::size_t pos = 0; pos <= text.size(); ++pos) {
    if (pos == text.size() || text[pos] == U' ') {
      if (pos > word_begin) {
        words.push_back(text.substr(word_begin, pos - word_begin));
      }
      word_begin = pos + 1;
    }
  }
}
46
47void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string, int32_t>& vocab,
48 const std::u32string& suffix_indicator,
49 const ustring& unk_token,
50 const std::vector<ustring>& texts,
51 std::vector<ustring>& tokens,
52 std::vector<int32_t>& indices,
53 std::vector<int64_t>& rows,
54 const int64_t* existing_rows,
55 int64_t n_existing_rows,
56 int64_t max_input_chars_per_word) {
57 std::vector<std::u32string> words;
58 bool is_bad;
59 bool no_existing_rows = n_existing_rows == 0;
60 int start, end;
61 std::u32string substr;
62 int64_t cur_substr;
63 tokens.clear();
64 indices.clear();
65 rows.clear();
66 std::u32string token;
67 int64_t row_index = 0;
68 std::vector<ustring>::const_iterator it;
69 int64_t text_index;
70 for (it = texts.begin(), text_index = 0; it != texts.end(); ++it, ++text_index) {
71 if (no_existing_rows) {
72 rows.push_back(indices.size());
73 } else if (text_index == existing_rows[row_index]) {
74 if (row_index >= n_existing_rows)
75 throw std::runtime_error(MakeString(
76 "row_index=", row_index, " is out of range=", n_existing_rows, "."));
77 rows.push_back(indices.size());
78 ++row_index;
79 }
80
81 KernelWordpieceTokenizer_Split(suffix_indicator, *it, words);
82
83 for (auto itk = words.begin(); itk != words.end(); ++itk) {
84 if (itk->size() > max_input_chars_per_word) {
85 indices.push_back(-1);
86 tokens.push_back(unk_token);
87 continue;
88 }
89 is_bad = false;
90 start = 0;
91 for (; start < itk->size();) {
92 end = itk->size();
93 cur_substr = -1;
94 for (; start < end;) {
95 substr = itk->substr(start, end - start);
96 if (start > 0)
97 substr = suffix_indicator + substr;
98 auto itf = vocab.find(substr);
99 if (itf != vocab.end()) {
100 token = substr;
101 cur_substr = itf->second;
102 break;
103 }
104 end -= 1;
105 }
106 if (cur_substr == -1) {
107 is_bad = true;
108 break;
109 }
110 indices.push_back(cur_substr);
111 tokens.push_back(ustring(token));
112 start = end;
113 }
114 if (is_bad) {
115 indices.push_back(-1);
116 tokens.push_back(unk_token);
117 }
118 }
119 }
120 rows.push_back(indices.size());
121}
122
123void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
124 // Update with the new API
125 const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0);
126 std::vector<ustring> str_input;
127 GetTensorMutableDataString(api_, ort_, context, ort_input, str_input);
128 const OrtValue* ort_row_indices = ort_.KernelContext_GetInput(context, 1);
129 OrtTensorDimensions ort_row_indices_dim(ort_, ort_row_indices);
130 const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
131
132 std::vector<ustring> tokens;
133 std::vector<int32_t> indices;
134 std::vector<int64_t> row_begins;
135
136 KernelWordpieceTokenizer_Tokenizer(vocab_, suffix_indicator_, unk_token_, str_input,
137 tokens, indices, row_begins,
138 p_row_indices, ort_row_indices_dim.Size(),
139 max_input_chars_per_word_);
140
141 std::vector<int64_t> size_content{(int64_t)indices.size()};
142 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size());
143 FillTensorDataString(api_, ort_, context, tokens, output);
144
145 std::vector<int64_t> size_row_lengths{(int64_t)row_begins.size()};
146 OrtValue* output_row_lengths = ort_.KernelContext_GetOutput(context, 1, size_row_lengths.data(), size_row_lengths.size());
147 --size_row_lengths[0];
148 OrtValue* output_row_begins = ort_.KernelContext_GetOutput(context, 2, size_row_lengths.data(), size_row_lengths.size());
149 OrtValue* output_limit_values = ort_.KernelContext_GetOutput(context, 3, size_row_lengths.data(), size_row_lengths.size());
150 int64_t* ptr_row_lengths = ort_.GetTensorMutableData<int64_t>(output_row_lengths);
151 int64_t* ptr_row_begins = ort_.GetTensorMutableData<int64_t>(output_row_begins);
152 int64_t* ptr_limit_values = ort_.GetTensorMutableData<int64_t>(output_limit_values);
153
154 int64_t i;
155 for (i = 0; i < size_row_lengths[0]; ++i) {
156 ptr_row_lengths[i] = row_begins[i];
157 ptr_row_begins[i] = row_begins[i];
158 ptr_limit_values[i] = row_begins[i + 1];
159 }
160
161 i = size_row_lengths[0];
162 ptr_row_lengths[i] = row_begins[i];
163}
164
165void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
166 return new KernelWordpieceTokenizer(api, info);
167};
168
169const char* CustomOpWordpieceTokenizer::GetName() const {
170 return "WordpieceTokenizer";
171};
172
173size_t CustomOpWordpieceTokenizer::GetInputTypeCount() const {
174 return 2;
175};
176
177ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index) const {
178 switch (index) {
179 case 0:
180 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
181 case 1:
182 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
183 default:
184 throw std::runtime_error(MakeString("Unexpected input index ", index));
185 }
186};
187
188size_t CustomOpWordpieceTokenizer::GetOutputTypeCount() const {
189 return 4;
190};
191
192ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index) const {
193 switch (index) {
194 case 0:
195 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
196 case 1:
197 case 2:
198 case 3:
199 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
200 default:
201 throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index));
202 }
203};
204