microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
4a0f8929494fa301baa6c59f617cce7872a7c4c8

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tokenizer/gpt2tok.cc

671lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3// Partial code comes from other Microsoft employee.
4
5#include <string>
6#include <vector>
7#include <fstream>
8#include <iostream>
9#include <list>
10#include <memory>
11#include <regex>
12#include <sstream>
13#include <stdexcept>
14#include <unordered_map>
15#include <functional>
16#include <codecvt>
17#include <mutex>
18
19#include "nlohmann/json.hpp"
20#include "kernels/kernels.h"
21#include "unicode.h"
22
23namespace {
// Registry of special tokens (for example "<|endoftext|>") and the ids assigned
// to them. Besides the lookup table it keeps the tokens in registration order,
// which is the order SplitBySpeicalTokens() applies them in.
class SpecialTokenMap {
 public:
  // Registers `p_str` as a special token with id `p_id`.
  // Registering the same token again with the same id is a no-op.
  // Throws std::runtime_error when the token is already registered with a
  // different id, or when the token is empty.
  void Add(std::u32string p_str, int p_id) {
    const auto it = token_map_.find(p_str);
    if (it != token_map_.end()) {
      if (it->second != p_id) {
        throw std::runtime_error("Duplicate special tokens");
      }
      return;
    }
    // Create the list entry first: its constructor rejects empty tokens, so a
    // rejected token never ends up in token_map_ (which would make a later
    // Add() of the same empty token silently succeed).
    token_list_.emplace_back(p_str, p_id);
    token_map_.emplace(std::move(p_str), p_id);
  }

  // Splits `input` around every occurrence of every registered special token.
  // Returns the pieces in their original order as (text, id) pairs: id is the
  // special token id for a special-token piece and -1 for ordinary text.
  // Empty text pieces are never emitted.
  std::list<std::pair<std::u32string, int>> SplitBySpeicalTokens(std::u32string input) const {
    using Segment = std::pair<std::u32string, int>;

    std::list<Segment> res;
    res.push_back({std::move(input), -1});
    for (const auto& st : token_list_) {
      std::list<Segment> new_split_res;
      for (auto& seg : res) {
        // Pieces already identified as special tokens are passed through as-is.
        if (seg.second != -1) {
          new_split_res.push_back(std::move(seg));
          continue;
        }

        const std::u32string& text = seg.first;
        auto it = text.begin();  // start of the part of `text` not emitted yet
        while (it != text.end()) {
#if defined(__APPLE__)
          auto search_it = std::search(it, text.end(), st.str.begin(), st.str.end());
#else
          auto search_it = std::search(it, text.end(),
                                       std::boyer_moore_searcher(st.str.begin(), st.str.end()));
#endif
          if (search_it == text.end()) {
            // No further occurrence: the remainder is ordinary text.
            new_split_res.push_back({std::u32string(it, text.end()), -1});
            break;
          }
          if (search_it != it) {
            new_split_res.push_back({std::u32string(it, search_it), -1});
          }
          new_split_res.push_back({st.str, st.id});
          it = search_it + st.str.size();
        }
      }
      std::swap(new_split_res, res);
    }
    return res;
  }

 private:
  struct SpecialTokenInfo {
    std::u32string str;
    int id;

    // Throws std::runtime_error for an empty token: an empty pattern matches
    // at every position and would keep the split loop from advancing.
    SpecialTokenInfo(std::u32string p_str, int p_id)
        : str(std::move(p_str)),
          id(p_id) {
      if (str.empty()) {
        throw std::runtime_error("Empty special token.");
      }
    }
  };

  std::list<SpecialTokenInfo> token_list_;             // registration order
  std::unordered_map<std::u32string, int> token_map_;  // token -> id
};
94
95using json = nlohmann::json;
96class VocabData {
97 public:
98 VocabData()
99 : unk_id_(-1) {
100 }
101
102 struct BpeNode {
103 int id;
104 int value;
105 };
106
107 void Load(const char* p_vocab_file, const char* p_bpe_file, const char* unk_token, const char* special_tokens) {
108 std::ifstream json_stream(p_vocab_file);
109 if (json_stream.fail()) {
110 throw std::runtime_error(std::string("Fail to open vocab file: ") + p_vocab_file);
111 }
112
113 json tok_json;
114 json_stream >> tok_json;
115 vocab_map_ = std::move(tok_json.get<std::unordered_map<std::string, int>>());
116
117 auto it = vocab_map_.find(unk_token);
118 if (it != vocab_map_.end()) {
119 unk_id_ = it->second;
120 } else {
121 int id = (int)vocab_map_.size();
122 vocab_map_[unk_token] = id;
123 std::cerr << "Special token (" << unk_token << ") have been added in the vocabulary." << std::endl;
124 }
125
126 std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> str_convert;
127 for (auto i = 33; i <= 126; ++i)
128 byte_encoder_[i] = GetVocabIndex(str_convert.to_bytes((char32_t)i));
129 for (auto i = 161; i <= 172; ++i)
130 byte_encoder_[i] = GetVocabIndex(str_convert.to_bytes((char32_t)i));
131 for (auto i = 174; i <= 255; ++i)
132 byte_encoder_[i] = GetVocabIndex(str_convert.to_bytes((char32_t)i));
133
134 int index = 256;
135 for (auto i = 0; i < 33; ++i)
136 byte_encoder_[i] = GetVocabIndex(str_convert.to_bytes((char32_t)(index++)));
137 for (auto i = 127; i < 161; ++i)
138 byte_encoder_[i] = GetVocabIndex(str_convert.to_bytes((char32_t)(index++)));
139 byte_encoder_[173] = GetVocabIndex(str_convert.to_bytes((char32_t)(index++)));
140
141 std::ifstream bpe_file(p_bpe_file);
142 if (bpe_file.fail()) {
143 throw std::runtime_error(std::string("Fail to open vocab file: ") + p_bpe_file);
144 }
145
146 index = 0;
147 std::string line;
148 while (std::getline(bpe_file, line)) {
149 line = std::regex_replace(line, std::regex("\r"), "");
150 if (line.empty()) continue;
151 if ((line[0] == '#') && (index == 0)) continue;
152 auto pos = line.find(' ');
153 if (pos == std::string::npos) {
154 throw std::runtime_error("Cannot know how to parse line: " + line);
155 }
156 std::string w1 = line.substr(0, pos);
157 std::string w2 = line.substr(pos + 1);
158 int iw1 = GetVocabIndex(w1);
159 int iw2 = GetVocabIndex(w2);
160 int iww = GetVocabIndex(w1 + w2);
161 std::pair<int, int> key{iw1, iw2};
162 BpeNode value{iww, index++};
163 bpe_map_[key] = value;
164 }
165
166 if (special_tokens) {
167 std::istringstream istrea(special_tokens);
168 std::string line;
169 while (istrea >> line) {
170 if (line.empty()) continue;
171 line = std::regex_replace(line, std::regex("\r"), "");
172 std::u32string line_32 = str_convert.from_bytes(line);
173 int id = (int)vocab_map_.size();
174 if (auto it = vocab_map_.find(line); it != vocab_map_.end())
175 id = it->second;
176 else
177 vocab_map_[line] = id;
178 special_tokens_.Add(std::move(line_32), id);
179 }
180 }
181
182 id2token_map_.resize(vocab_map_.size());
183 for (const auto& [t, i] : vocab_map_) {
184 id2token_map_[i] = t;
185 }
186 }
187
188 public:
189 void bpe(std::list<int>& vals) const {
190 while (vals.size() >= 2) {
191 auto pos_it = vals.end();
192 int minval = std::numeric_limits<int>::max();
193 int ori_id1 = 0, ori_id2 = 0;
194 int aim_id = 0;
195 for (auto it = vals.begin(); it != vals.end(); ++it) {
196 auto it2 = it;
197 ++it2;
198 if (it2 == vals.end()) break;
199 auto map_it = bpe_map_.find({*it, *it2});
200 if (map_it == bpe_map_.end()) continue;
201 if (minval > map_it->second.value) {
202 ori_id1 = *it;
203 ori_id2 = *it2;
204 minval = map_it->second.value;
205 pos_it = it;
206 aim_id = map_it->second.id;
207 }
208 }
209 if (pos_it == vals.end()) break;
210
211 pos_it = vals.erase(pos_it);
212 *pos_it = aim_id;
213 for (++pos_it; pos_it != vals.end(); ++pos_it) {
214 if (*pos_it != ori_id1) continue;
215 auto it2 = pos_it;
216 ++it2;
217 if (it2 == vals.end()) break;
218 if (*it2 != ori_id2) continue;
219 pos_it = vals.erase(pos_it);
220 *pos_it = aim_id;
221 }
222 }
223 }
224
225 const auto& ByteEncoder() const {
226 return byte_encoder_;
227 }
228
229 auto SplitBySpeicalTokens(const std::u32string& input) const {
230 return special_tokens_.SplitBySpeicalTokens(input);
231 }
232
233 size_t VocabSize() const { return vocab_map_.size(); }
234
235 int TokenToID(const std::string& input) const {
236 auto it = vocab_map_.find(input);
237 if (it == vocab_map_.end()) {
238 throw std::runtime_error("Token not found: " + input);
239 }
240 return it->second;
241 }
242
243 const std::string& IdToToken(int id) const {
244 if ((id < 0) || (id >= id2token_map_.size())) {
245 throw std::runtime_error("Invalid ID: " + std::to_string(id));
246 }
247 return id2token_map_[id];
248 }
249
250 private:
251 int GetVocabIndex(const std::string& str) {
252 auto it = vocab_map_.find(str);
253 if (it == vocab_map_.end()) {
254 throw std::runtime_error("Cannot find word in vocabulary: " + str);
255 }
256 return it->second;
257 }
258
259 private:
260 struct hash_pair {
261 template <class T1, class T2>
262 size_t operator()(const std::pair<T1, T2>& p) const {
263 auto hash1 = std::hash<T1>{}(p.first);
264 auto hash2 = std::hash<T2>{}(p.second);
265 return hash1 ^ (hash2 << 16);
266 }
267 };
268 std::unordered_map<std::pair<int, int>, BpeNode, hash_pair> bpe_map_;
269
270 int byte_encoder_[256] = {};
271 std::unordered_map<std::string, int> vocab_map_;
272 std::vector<std::string> id2token_map_;
273
274 int unk_id_;
275 SpecialTokenMap special_tokens_;
276};
277
278class TokenWithRegularExp {
279 public:
280 void Set(std::u32string_view val) {
281 m_text = val;
282 }
283
284 std::pair<bool, std::u32string_view> GetNextToken() {
285 while (!m_text.empty()) {
286 auto res = TryMatch();
287 if (res.empty()) {
288 m_text = m_text.substr(1);
289 continue;
290 }
291 return {true, res};
292 }
293 return {false, {}};
294 }
295
296 private:
297 std::u32string_view TryMatch() {
298 // python pattern:
299 // 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
300
301 // 's|'t|'re|'ve|'m|'ll|'d|
302 // Note: the sequencial of the following if should not be switched, which follows the python regex's syntax
303 if ((m_text[0] == U'\'') && (m_text.size() > 1)) {
304 if ((m_text[1] == U's') || (m_text[1] == U't') ||
305 (m_text[1] == U'm') || (m_text[1] == U'd')) {
306 std::u32string_view res = m_text.substr(0, 2);
307 m_text = m_text.substr(2);
308 return res;
309 } else if (m_text.size() > 2) {
310 if (((m_text[1] == U'r') && (m_text[2] == U'e')) ||
311 ((m_text[1] == U'v') && (m_text[2] == U'e')) ||
312 ((m_text[1] == U'l') && (m_text[2] == U'l'))) {
313 std::u32string_view res = m_text.substr(0, 3);
314 m_text = m_text.substr(3);
315 return res;
316 }
317 }
318 }
319
320 // ?\p{L}+
321 if ((m_text[0] == U' ') && (m_text.size() > 1) && (ufal::unilib::unicode::category(m_text[1]) & ufal::unilib::unicode::L)) {
322 size_t i = 2;
323 for (; i < m_text.size(); ++i) {
324 if ((ufal::unilib::unicode::category(m_text[i]) & ufal::unilib::unicode::L) == 0)
325 break;
326 }
327 std::u32string_view res = m_text.substr(0, i);
328 m_text = m_text.substr(i);
329 return res;
330 }
331 if (ufal::unilib::unicode::category(m_text[0]) & ufal::unilib::unicode::L) {
332 size_t i = 1;
333 for (; i < m_text.size(); ++i) {
334 if ((ufal::unilib::unicode::category(m_text[i]) & ufal::unilib::unicode::L) == 0)
335 break;
336 }
337 std::u32string_view res = m_text.substr(0, i);
338 m_text = m_text.substr(i);
339 return res;
340 }
341
342 // ?\p{N}+
343 if ((m_text[0] == U' ') && (m_text.size() > 1) && (ufal::unilib::unicode::category(m_text[1]) & ufal::unilib::unicode::N)) {
344 size_t i = 2;
345 for (; i < m_text.size(); ++i) {
346 if ((ufal::unilib::unicode::category(m_text[i]) & ufal::unilib::unicode::N) == 0)
347 break;
348 }
349 std::u32string_view res = m_text.substr(0, i);
350 m_text = m_text.substr(i);
351 return res;
352 }
353 if (ufal::unilib::unicode::category(m_text[0]) & ufal::unilib::unicode::N) {
354 size_t i = 1;
355 for (; i < m_text.size(); ++i) {
356 if ((ufal::unilib::unicode::category(m_text[i]) & ufal::unilib::unicode::N) == 0)
357 break;
358 }
359 std::u32string_view res = m_text.substr(0, i);
360 m_text = m_text.substr(i);
361 return res;
362 }
363
364 // ?[^\s\p{L}\p{N}]+
365 if ((m_text[0] == U' ') && (m_text.size() > 1) && (NotLNZ(m_text[1]))) {
366 size_t i = 2;
367 for (; i < m_text.size(); ++i) {
368 if (!NotLNZ(m_text[i]))
369 break;
370 }
371 std::u32string_view res = m_text.substr(0, i);
372 m_text = m_text.substr(i);
373 return res;
374 }
375 if (NotLNZ(m_text[0])) {
376 size_t i = 1;
377 for (; i < m_text.size(); ++i) {
378 if (!NotLNZ(m_text[i]))
379 break;
380 }
381 std::u32string_view res = m_text.substr(0, i);
382 m_text = m_text.substr(i);
383 return res;
384 }
385
386 // \s+(?!\S)|\s+
387 if ((m_text.size() >= 1) && (IsZ(m_text[0]))) {
388 size_t i = 1;
389 for (; i < m_text.size(); ++i) {
390 if (!IsZ(m_text[i])) break;
391 }
392 if ((i > 1) && (i != m_text.size())) //\s+(?!\S)
393 {
394 i--;
395 std::u32string_view res = m_text.substr(0, i);
396 m_text = m_text.substr(i);
397 return res;
398 } else // \s+
399 {
400 std::u32string_view res = m_text.substr(0, i);
401 m_text = m_text.substr(i);
402 return res;
403 }
404 }
405
406 return std::u32string_view{};
407 }
408
409 static bool IsZ(char32_t ch) {
410 auto category = ufal::unilib::unicode::category(ch);
411 return (category & ufal::unilib::unicode::Z) != 0;
412 }
413
414 static bool NotLNZ(char32_t ch) {
415 auto category = ufal::unilib::unicode::category(ch);
416 if (category & ufal::unilib::unicode::L) return false;
417 if (category & ufal::unilib::unicode::N) return false;
418 if (category & ufal::unilib::unicode::Z) return false;
419 return true;
420 }
421
422 private:
423 std::u32string_view m_text;
424};
425
// Least-recently-used cache holding at most `capacity` entries.
// Entries live in a list ordered from most to least recently used; the map
// points each key at its list node. A capacity of 0 disables caching.
template <typename TKey, typename TVal>
class LruCache {
 public:
  LruCache(size_t capacity)
      : m_capacity(capacity) {}

  // Returns a pointer to the value cached for `key` and marks the entry as the
  // most recently used one, or returns nullptr when the key is not cached.
  // The pointer stays valid until that entry is evicted.
  const TVal* Get(const TKey& key) {
    const auto it = m_map.find(key);
    if (it == m_map.end()) {
      return nullptr;
    }
    // splice() relinks the node in place: the stored iterator and any pointer
    // to the value remain valid.
    m_list.splice(m_list.begin(), m_list, it->second);
    return &(it->second->second);
  }

  // Stores `value` for `key` as the most recently used entry, replacing an
  // existing value and evicting the least recently used entry when full.
  void Set(const TKey& key, TVal value) {
    if (m_capacity == 0) {
      return;  // nothing can be stored; also avoids back() on an empty list
    }
    if (Get(key) != nullptr) {
      // Get() moved the entry to the front.
      m_list.front().second = std::move(value);
      return;
    }
    if (m_list.size() >= m_capacity) {
      m_map.erase(m_list.back().first);
      m_list.pop_back();
    }
    m_list.emplace_front(key, std::move(value));
    m_map.emplace(key, m_list.begin());
  }

 private:
  using ListElem = std::list<std::pair<TKey, TVal>>;
  ListElem m_list;                                              // front = most recently used
  std::unordered_map<TKey, typename ListElem::iterator> m_map;  // key -> node in m_list
  const size_t m_capacity;
};
466
// Returns true when `ch` is one of the code points CPython treats as
// whitespace (the set comes from _PyUnicode_IsWhitespace in unicodetype_db.h).
bool IsUnicodeSpace(char32_t ch) {
  const auto between = [ch](char32_t low, char32_t high) {
    return (ch >= low) && (ch <= high);
  };

  // Contiguous runs: U+0009..U+000D, U+001C..U+0020 and U+2000..U+200A.
  if (between(0x0009, 0x000D) || between(0x001C, 0x0020) || between(0x2000, 0x200A)) {
    return true;
  }

  // Isolated code points.
  return (ch == 0x0085) || (ch == 0x00A0) || (ch == 0x1680) ||
         (ch == 0x2028) || (ch == 0x2029) || (ch == 0x202F) ||
         (ch == 0x205F) || (ch == 0x3000);
}
503} // namespace
504
505struct KernelBpeTokenizer : BaseKernel {
506 KernelBpeTokenizer(OrtApi api, const VocabData* global_data)
507 : BaseKernel(api)
508 , vocab_data_(global_data)
509 , token2id_cache_(30 * 1024){
510 }
511
512 static size_t const p_max_len = 1024;
513 using StringConverter = std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t>;
514
515 private:
516 LruCache<std::string, std::list<int>> token2id_cache_;
517 std::list<int> byte_list_;
518 const VocabData* vocab_data_;
519
520public:
521 size_t Tokenize(const std::u32string& input_32, int* p_index_array, size_t p_max_len) {
522 bool all_space_chars = true;
523 for (auto ch : input_32) {
524 if (IsUnicodeSpace(ch)) {
525 all_space_chars = false;
526 break;
527 }
528 }
529 if (all_space_chars) return 0;
530
531 auto special_token_split_res = vocab_data_->SplitBySpeicalTokens(input_32);
532 size_t cur_id = 0;
533 TokenWithRegularExp regcmp;
534 StringConverter str_convert;
535 for (auto& seg_id : special_token_split_res) {
536 if (cur_id >= p_max_len) break;
537 if (seg_id.second != -1) {
538 p_index_array[cur_id] = seg_id.second;
539 ++cur_id;
540 continue;
541 }
542
543 auto cur_input = std::move(seg_id.first);
544 // Note: keep ptr to make sure the string_view is valid in the following process
545 const char32_t* ptr = cur_input.c_str();
546 regcmp.Set(ptr);
547
548 while (cur_id < p_max_len) {
549 auto [b, tok] = regcmp.GetNextToken();
550 if (!b) break;
551
552 std::string utf8_token = str_convert.to_bytes(tok.data(), tok.data() + tok.size());
553 auto cache_res = token2id_cache_.Get(utf8_token);
554 if (cache_res) {
555 UpdateOutputBuffer(p_index_array, p_max_len, cur_id, *cache_res);
556 } else {
557 byte_list_.clear();
558 for (char& cp : utf8_token)
559 byte_list_.push_back(vocab_data_->ByteEncoder()[(unsigned char)cp]);
560 vocab_data_->bpe(byte_list_);
561 token2id_cache_.Set(utf8_token, byte_list_);
562 UpdateOutputBuffer(p_index_array, p_max_len, cur_id, byte_list_);
563 }
564 }
565 }
566 return cur_id;
567 }
568
569 void UpdateOutputBuffer(int* p_index_array, size_t p_max_len, size_t& cur_id, const std::list<int>& byte_lst) {
570 size_t aim_len = byte_lst.size();
571 if (aim_len + cur_id > p_max_len) aim_len = p_max_len - cur_id;
572
573 for (auto p : byte_lst) {
574 p_index_array[cur_id] = p;
575 ++cur_id;
576 --aim_len;
577 if (aim_len == 0) break;
578 }
579 }
580
581 size_t Tokenize(const std::string& p_input, int* p_index_array, size_t p_max_len) {
582 std::u32string input_32 = StringConverter().from_bytes(p_input);
583 return Tokenize(input_32, p_index_array, p_max_len);
584 }
585
586 void Compute(OrtKernelContext* context) {
587 // Setup inputs
588 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
589 const std::string* str_input = ort_.GetTensorData<std::string>(input);
590
591 OrtTensorDimensions dimensions(ort_, input);
592 int tok_res[p_max_len];
593 auto indexed_len = Tokenize(str_input[0], tok_res, p_max_len);
594 if (dimensions.size() != 1 || dimensions[0] != 1) {
595 throw std::runtime_error("only support 1-d string input");
596 }
597
598 // Setup output
599 int64_t output_shape[2] = {1, static_cast<int64_t>(indexed_len)};
600 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_shape, 2);
601 int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
602
603 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
604 int64_t size = ort_.GetTensorShapeElementCount(output_info);
605 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
606
607 for (size_t j = 0; j < indexed_len; j++) {
608 out[j] = tok_res[j];
609 }
610 }
611};
612
613struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
614
615 VocabData bbpe_tokenizer_;
616
617 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
618 return new KernelBpeTokenizer(api, &bbpe_tokenizer_);
619 }
620
621 const char* GetName() const {
622 return "GPT2Tokenizer";
623 }
624
625 size_t GetInputTypeCount() const {
626 return 1;
627 }
628
629 ONNXTensorElementDataType GetInputType(size_t index) const {
630 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
631 }
632 size_t GetOutputTypeCount() const {
633 return 1;
634 }
635
636 ONNXTensorElementDataType GetOutputType(size_t index) const {
637 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
638 }
639
640 CustomOpBpeTokenizer() {
641 if (const char* filepath = std::getenv("GPT2TOKFILE")) {
642 std::string merges(filepath);
643 std::string vocab_ext(".vocab");
644 if (size_t pos = merges.find(vocab_ext); pos != std::string::npos) {
645 merges.replace(pos, vocab_ext.length(), ".merges.txt");
646 bbpe_tokenizer_.Load(filepath, merges.c_str(), "<|endoftext|>", "<|endoftext|>");
647 }
648 else{
649 throw std::runtime_error(std::string("cannot find the vocab file: ") + filepath);
650 }
651 }
652 }
653
654 ~CustomOpBpeTokenizer() {
655 // STL container objects in global data will be cleaned on their own.
656 }
657};
658
659const OrtCustomOp** LoadTokenizerSchemaList() {
660 // create the global objects here to let the ORT catch the expection if any
661 static std::unique_ptr<CustomOpBpeTokenizer> p_CoBpeTokenizer;
662 static const OrtCustomOp* c_DomainList[2] = {nullptr}; // {&c_CoBpeTokenizer, nullptr};
663 static std::mutex mtx_loaded;
664 std::lock_guard<std::mutex> lck(mtx_loaded);
665 if (p_CoBpeTokenizer.get() == nullptr) {
666 p_CoBpeTokenizer = std::make_unique<CustomOpBpeTokenizer>();
667 c_DomainList[0] = p_CoBpeTokenizer.get();
668 }
669
670 return c_DomainList;
671}
672