microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
13d9e27ccd8a0de9a1225756fbf6860a1931484f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/clip_tokenizer.cc

226 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3// Partial code comes from other Microsoft employee.
4
5#include <string>
6#include <vector>
7#include <fstream>
8#include <sstream>
9#include <iostream>
10#include <algorithm>
11#include <list>
12#include <memory>
13#include <regex>
14#include <sstream>
15#include <stdexcept>
16#include <unordered_map>
17#include <functional>
18#include <codecvt>
19#include <mutex>
20
21#include "nlohmann/json.hpp"
22#include "bpetokenizer.hpp"
23#include "string_tensor.h"
24#include "unicode.h"
25
// Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
// Returns true when `ch` is one of the 29 code points that CPython classifies as
// whitespace (the set behind Python's str.isspace()).
bool IsInUnicodeSpace(char32_t ch) {
  // \t \n \v \f \r, the information separators FS/GS/RS/US, and SPACE.
  if ((ch >= 0x0009 && ch <= 0x000D) || (ch >= 0x001C && ch <= 0x0020)) {
    return true;
  }
  // EN QUAD (U+2000) through HAIR SPACE (U+200A).
  if (ch >= 0x2000 && ch <= 0x200A) {
    return true;
  }
  // Isolated code points: NEL, NO-BREAK SPACE, OGHAM SPACE MARK, LINE SEPARATOR,
  // PARAGRAPH SEPARATOR, NARROW NO-BREAK SPACE, MEDIUM MATHEMATICAL SPACE,
  // IDEOGRAPHIC SPACE.
  return ch == 0x0085 || ch == 0x00A0 || ch == 0x1680 || ch == 0x2028 ||
         ch == 0x2029 || ch == 0x202F || ch == 0x205F || ch == 0x3000;
}
62
63bool IsEmptyUstring(const ustring& str) {
64 return std::all_of(str.begin(), str.end(), [](char32_t ch) { return IsInUnicodeSpace(ch); });
65}
66
67KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
68 : BaseKernel(api, info) {
69 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
70 if (vocab.empty()) {
71 ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
72 }
73
74 std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
75 if (merges.empty()) {
76 ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
77 }
78
79 if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
80 padding_length_ = -1;
81 }
82
83 if (padding_length_ != -1 && padding_length_ <= 0) {
84 ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
85 }
86
87 std::stringstream vocabu_stream(vocab);
88 std::stringstream merges_stream(merges);
89 bbpe_tokenizer_ = std::make_shared<VocabData>();
90 bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
91}
92
93std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length) {
94 std::vector<int64_t> res;
95
96 if (IsEmptyUstring(input)) {
97 return res;
98 }
99 // Add <|startoftext|> token to result
100 res.push_back(bbpe_tokenizer_->GetEncoding("<|startoftext|>"));
101
102 // Convert to lowercase
103 std::transform(input.begin(), input.end(), input.begin(), [](char32_t c) { return static_cast<char32_t>(ToLower(c)); });
104
105 // Parse input
106 auto special_token_split_res = bbpe_tokenizer_->SplitBySpecialTokens(input);
107 TokenWithRegularExp regcmp;
108
109 for (auto& seg_id : special_token_split_res) {
110 if (static_cast<int64_t>(res.size()) >= max_length) break;
111
112 if (seg_id.second != -1) {
113 res.push_back(seg_id.second);
114 continue;
115 }
116
117 auto cur_input = std::move(seg_id.first);
118 // Note: keep ptr to make sure the string_view is valid in the following process
119 const char32_t* ptr = cur_input.c_str();
120 regcmp.Set(ptr);
121
122 while (static_cast<int64_t>(res.size()) < max_length) {
123 auto [b, tok] = regcmp.GetNextToken();
124 if (!b) break;
125
126 std::string utf8_token = std::string(ustring(tok));
127
128 // Whitespace clean
129 utf8_token.erase(std::remove(utf8_token.begin(), utf8_token.end(), ' '), utf8_token.end());
130
131 // Get byte encodings prior to performing BPE
132 byte_list_.clear();
133 for (int i = 0; i < utf8_token.length(); i++) {
134 if (i == utf8_token.length() - 1) {
135 std::string boundary(1, utf8_token[i]);
136 byte_list_.push_back(bbpe_tokenizer_->GetEncoding(boundary + "</w>"));
137 } else {
138 byte_list_.push_back(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])]);
139 }
140 }
141
142 // Perform BPE
143 bbpe_tokenizer_->bpe(byte_list_);
144
145 // Add output to result
146 for (auto p : byte_list_) {
147 if (static_cast<int64_t>(res.size()) >= max_length) {
148 break;
149 }
150
151 res.push_back(p);
152 }
153 }
154 }
155 // Add <|endoftext|> token to result
156 res.push_back(bbpe_tokenizer_->GetEncoding("<|endoftext|>"));
157 return res;
158}
159
160void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
161 // Setup inputs
162 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
163 std::vector<std::string> str_input;
164 GetTensorMutableDataString(api_, ort_, context, input, str_input);
165 OrtTensorDimensions input_dim(ort_, input);
166
167 std::vector<std::vector<int64_t>> tokenize_results;
168 for (auto& str : str_input) {
169 ustring ustr = ustring(str);
170 tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_));
171 }
172
173 size_t max_length = 0;
174 if (padding_length_ == -1) {
175 for (auto& res : tokenize_results) {
176 max_length = std::max(max_length, res.size());
177 }
178 } else {
179 max_length = static_cast<size_t>(padding_length_);
180 }
181
182 OrtTensorDimensions output_dim = input_dim;
183 output_dim.push_back(max_length);
184 OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
185 OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
186 auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
187 auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
188
189 int idx = 0;
190 for (auto& res : tokenize_results) {
191 for (int64_t id : res) {
192 token[idx] = id;
193 mask[idx] = 1;
194 idx++;
195 }
196
197 for (size_t i = res.size(); i < max_length; i++) {
198 token[idx] = 0;
199 mask[idx] = 0;
200 idx++;
201 }
202 }
203}
204
205void* CustomOpClipBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
206 return CreateKernelImpl(api, info);
207}
208
209const char* CustomOpClipBpeTokenizer::GetName() const {
210 return "CLIPTokenizer";
211}
212
213size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
214 return 1;
215}
216
217ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
218 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
219}
220size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
221 return 2;
222}
223
224ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
225 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
226}
227