microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
operators/tokenizer/wordpiece_tokenizer.cc
203lines · modecode
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #include "wordpiece_tokenizer.hpp" |
| 5 | #include "nlohmann/json.hpp" |
| 6 | |
| 7 | KernelWordpieceTokenizer::KernelWordpieceTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) { |
| 8 | // https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md |
| 9 | // https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py |
| 10 | std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab"); |
| 11 | std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(info, "suffix_indicator"); |
| 12 | std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unknown_token"); |
| 13 | max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word") ? ort_.KernelInfoGetAttribute<int64_t>(info, "max_input_chars_per_word") : 200; |
| 14 | suffix_indicator_ = ustring(suffix_indicator); |
| 15 | unk_token_ = ustring(unk); |
| 16 | |
| 17 | std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> cvt; |
| 18 | std::unordered_map<std::string, int32_t> vocab_map; |
| 19 | auto parsed = nlohmann::json::parse(vocab_as_string); |
| 20 | parsed.get_to(vocab_map); |
| 21 | |
| 22 | for (auto it = vocab_map.begin(); it != vocab_map.end(); ++it) { |
| 23 | vocab_[ustring(it->first)] = it->second; |
| 24 | } |
| 25 | } |
| 26 | |
// Splits `text` on the space character (U+0020) into `words`.
//
// `words` is cleared first. Runs of consecutive spaces as well as leading and
// trailing spaces produce no empty entries. `suffix_indicator` is accepted for
// signature compatibility but is not used by the splitting logic.
void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
                                    const std::u32string& text,
                                    std::vector<std::u32string>& words) {
  (void)suffix_indicator;
  const char32_t kSpace = U' ';
  const size_t length = text.size();

  words.clear();
  // `word_begin` is the index of the first character following the most
  // recently seen space (or 0 at the start of the text).
  size_t word_begin = 0;
  for (size_t pos = 0; pos < length; ++pos) {
    if (text[pos] != kSpace) {
      continue;
    }
    if (word_begin < pos) {
      words.push_back(text.substr(word_begin, pos - word_begin));
    }
    word_begin = pos + 1;
  }
  // Flush the final word when the text does not end with a space.
  if (word_begin < length) {
    words.push_back(text.substr(word_begin));
  }
}
| 46 | |
| 47 | void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string, int32_t>& vocab, |
| 48 | const std::u32string& suffix_indicator, |
| 49 | const ustring& unk_token, |
| 50 | const std::vector<ustring>& texts, |
| 51 | std::vector<ustring>& tokens, |
| 52 | std::vector<int32_t>& indices, |
| 53 | std::vector<int64_t>& rows, |
| 54 | const int64_t* existing_rows, |
| 55 | int64_t n_existing_rows, |
| 56 | int64_t max_input_chars_per_word) { |
| 57 | std::vector<std::u32string> words; |
| 58 | bool is_bad; |
| 59 | bool no_existing_rows = n_existing_rows == 0; |
| 60 | int start, end; |
| 61 | std::u32string substr; |
| 62 | int64_t cur_substr; |
| 63 | tokens.clear(); |
| 64 | indices.clear(); |
| 65 | rows.clear(); |
| 66 | std::u32string token; |
| 67 | int64_t row_index = 0; |
| 68 | std::vector<ustring>::const_iterator it; |
| 69 | int64_t text_index; |
| 70 | for (it = texts.begin(), text_index = 0; it != texts.end(); ++it, ++text_index) { |
| 71 | if (no_existing_rows) { |
| 72 | rows.push_back(indices.size()); |
| 73 | } else if (text_index == existing_rows[row_index]) { |
| 74 | if (row_index >= n_existing_rows) |
| 75 | throw std::runtime_error(MakeString( |
| 76 | "row_index=", row_index, " is out of range=", n_existing_rows, ".")); |
| 77 | rows.push_back(indices.size()); |
| 78 | ++row_index; |
| 79 | } |
| 80 | |
| 81 | KernelWordpieceTokenizer_Split(suffix_indicator, *it, words); |
| 82 | |
| 83 | for (auto itk = words.begin(); itk != words.end(); ++itk) { |
| 84 | if (itk->size() > max_input_chars_per_word) { |
| 85 | indices.push_back(-1); |
| 86 | tokens.push_back(unk_token); |
| 87 | continue; |
| 88 | } |
| 89 | is_bad = false; |
| 90 | start = 0; |
| 91 | for (; start < itk->size();) { |
| 92 | end = itk->size(); |
| 93 | cur_substr = -1; |
| 94 | for (; start < end;) { |
| 95 | substr = itk->substr(start, end - start); |
| 96 | if (start > 0) |
| 97 | substr = suffix_indicator + substr; |
| 98 | auto itf = vocab.find(substr); |
| 99 | if (itf != vocab.end()) { |
| 100 | token = substr; |
| 101 | cur_substr = itf->second; |
| 102 | break; |
| 103 | } |
| 104 | end -= 1; |
| 105 | } |
| 106 | if (cur_substr == -1) { |
| 107 | is_bad = true; |
| 108 | break; |
| 109 | } |
| 110 | indices.push_back(cur_substr); |
| 111 | tokens.push_back(ustring(token)); |
| 112 | start = end; |
| 113 | } |
| 114 | if (is_bad) { |
| 115 | indices.push_back(-1); |
| 116 | tokens.push_back(unk_token); |
| 117 | } |
| 118 | } |
| 119 | } |
| 120 | rows.push_back(indices.size()); |
| 121 | } |
| 122 | |
| 123 | void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) { |
| 124 | // Update with the new API |
| 125 | const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0); |
| 126 | std::vector<ustring> str_input; |
| 127 | GetTensorMutableDataString(api_, ort_, context, ort_input, str_input); |
| 128 | const OrtValue* ort_row_indices = ort_.KernelContext_GetInput(context, 1); |
| 129 | OrtTensorDimensions ort_row_indices_dim(ort_, ort_row_indices); |
| 130 | const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices); |
| 131 | |
| 132 | std::vector<ustring> tokens; |
| 133 | std::vector<int32_t> indices; |
| 134 | std::vector<int64_t> row_begins; |
| 135 | |
| 136 | KernelWordpieceTokenizer_Tokenizer(vocab_, suffix_indicator_, unk_token_, str_input, |
| 137 | tokens, indices, row_begins, |
| 138 | p_row_indices, ort_row_indices_dim.Size(), |
| 139 | max_input_chars_per_word_); |
| 140 | |
| 141 | std::vector<int64_t> size_content{(int64_t)indices.size()}; |
| 142 | OrtValue* output = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size()); |
| 143 | FillTensorDataString(api_, ort_, context, tokens, output); |
| 144 | |
| 145 | std::vector<int64_t> size_row_lengths{(int64_t)row_begins.size()}; |
| 146 | OrtValue* output_row_lengths = ort_.KernelContext_GetOutput(context, 1, size_row_lengths.data(), size_row_lengths.size()); |
| 147 | --size_row_lengths[0]; |
| 148 | OrtValue* output_row_begins = ort_.KernelContext_GetOutput(context, 2, size_row_lengths.data(), size_row_lengths.size()); |
| 149 | OrtValue* output_limit_values = ort_.KernelContext_GetOutput(context, 3, size_row_lengths.data(), size_row_lengths.size()); |
| 150 | int64_t* ptr_row_lengths = ort_.GetTensorMutableData<int64_t>(output_row_lengths); |
| 151 | int64_t* ptr_row_begins = ort_.GetTensorMutableData<int64_t>(output_row_begins); |
| 152 | int64_t* ptr_limit_values = ort_.GetTensorMutableData<int64_t>(output_limit_values); |
| 153 | |
| 154 | int64_t i; |
| 155 | for (i = 0; i < size_row_lengths[0]; ++i) { |
| 156 | ptr_row_lengths[i] = row_begins[i]; |
| 157 | ptr_row_begins[i] = row_begins[i]; |
| 158 | ptr_limit_values[i] = row_begins[i + 1]; |
| 159 | } |
| 160 | |
| 161 | i = size_row_lengths[0]; |
| 162 | ptr_row_lengths[i] = row_begins[i]; |
| 163 | } |
| 164 | |
| 165 | void* CustomOpWordpieceTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const { |
| 166 | return new KernelWordpieceTokenizer(api, info); |
| 167 | }; |
| 168 | |
| 169 | const char* CustomOpWordpieceTokenizer::GetName() const { |
| 170 | return "WordpieceTokenizer"; |
| 171 | }; |
| 172 | |
| 173 | size_t CustomOpWordpieceTokenizer::GetInputTypeCount() const { |
| 174 | return 2; |
| 175 | }; |
| 176 | |
| 177 | ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index) const { |
| 178 | switch (index) { |
| 179 | case 0: |
| 180 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 181 | case 1: |
| 182 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; |
| 183 | default: |
| 184 | throw std::runtime_error(MakeString("Unexpected input index ", index)); |
| 185 | } |
| 186 | }; |
| 187 | |
| 188 | size_t CustomOpWordpieceTokenizer::GetOutputTypeCount() const { |
| 189 | return 4; |
| 190 | }; |
| 191 | |
| 192 | ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index) const { |
| 193 | switch (index) { |
| 194 | case 0: |
| 195 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 196 | case 1: |
| 197 | case 2: |
| 198 | case 3: |
| 199 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; |
| 200 | default: |
| 201 | throw std::runtime_error(MakeString("[WordpieceTokenizer] Unexpected output index ", index)); |
| 202 | } |
| 203 | }; |
| 204 | |