microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.7.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/clip_tokenizer.cc

162 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3// Partial code comes from other Microsoft employee.
4#include "clip_tokenizer.hpp"
5#include "string_utils.h"
6
7KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
8 : BaseKernel(api, info) {
9 std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
10 if (vocab.empty()) {
11 ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
12 }
13
14 std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
15 if (merges.empty()) {
16 ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
17 }
18
19 if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
20 padding_length_ = -1;
21 }
22
23 if (padding_length_ != -1 && padding_length_ <= 0) {
24 ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
25 }
26
27 std::stringstream vocabu_stream(vocab);
28 std::stringstream merges_stream(merges);
29 bbpe_tokenizer_ = std::make_shared<VocabData>();
30 bbpe_tokenizer_->Load(vocabu_stream, merges_stream, "<|endoftext|>", "<|endoftext|>");
31}
32
33std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t max_length) {
34 std::vector<int64_t> res;
35
36 if (IsEmptyUString(input)) {
37 return res;
38 }
39 // Add <|startoftext|> token to result
40 res.push_back(bbpe_tokenizer_->GetEncoding("<|startoftext|>"));
41
42 // Convert to lowercase
43 std::transform(input.begin(), input.end(), input.begin(), [](char32_t c) { return static_cast<char32_t>(ToLower(c)); });
44
45 // Parse input
46 auto special_token_split_res = bbpe_tokenizer_->SplitBySpecialTokens(input);
47 TokenWithRegularExp regcmp;
48
49 for (auto& seg_id : special_token_split_res) {
50 if (static_cast<int64_t>(res.size()) >= max_length) break;
51
52 if (seg_id.second != -1) {
53 res.push_back(seg_id.second);
54 continue;
55 }
56
57 auto cur_input = std::move(seg_id.first);
58 // Note: keep ptr to make sure the string_view is valid in the following process
59 const char32_t* ptr = cur_input.c_str();
60 regcmp.Set(ptr);
61
62 while (static_cast<int64_t>(res.size()) < max_length) {
63 auto [b, tok] = regcmp.GetNextToken();
64 if (!b) break;
65
66 std::string utf8_token = std::string(ustring(tok));
67
68 // Whitespace clean
69 utf8_token.erase(std::remove(utf8_token.begin(), utf8_token.end(), ' '), utf8_token.end());
70
71 // Get byte encodings prior to performing BPE
72 byte_list_.clear();
73 for (int i = 0; i < utf8_token.length(); i++) {
74 if (i == utf8_token.length() - 1) {
75 std::string boundary(1, utf8_token[i]);
76 byte_list_.push_back(bbpe_tokenizer_->GetEncoding(boundary + "</w>"));
77 } else {
78 byte_list_.push_back(bbpe_tokenizer_->ByteEncoder()[static_cast<unsigned char>(utf8_token[i])]);
79 }
80 }
81
82 // Perform BPE
83 bbpe_tokenizer_->bpe(byte_list_);
84
85 // Add output to result
86 for (auto p : byte_list_) {
87 if (static_cast<int64_t>(res.size()) >= max_length) {
88 break;
89 }
90
91 res.push_back(p);
92 }
93 }
94 }
95 // Add <|endoftext|> token to result
96 res.push_back(bbpe_tokenizer_->GetEncoding("<|endoftext|>"));
97 return res;
98}
99
100void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
101 // Setup inputs
102 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
103 std::vector<std::string> str_input;
104 GetTensorMutableDataString(api_, ort_, context, input, str_input);
105 OrtTensorDimensions input_dim(ort_, input);
106
107 std::vector<std::vector<int64_t>> tokenize_results;
108 for (auto& str : str_input) {
109 ustring ustr = ustring(str);
110 tokenize_results.emplace_back(Tokenize(ustr, padding_length_ < 0 ? INT64_MAX : padding_length_));
111 }
112
113 size_t max_length = 0;
114 if (padding_length_ == -1) {
115 for (auto& res : tokenize_results) {
116 max_length = std::max(max_length, res.size());
117 }
118 } else {
119 max_length = static_cast<size_t>(padding_length_);
120 }
121
122 OrtTensorDimensions output_dim = input_dim;
123 output_dim.push_back(max_length);
124 OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
125 OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
126 auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
127 auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
128
129 int idx = 0;
130 for (auto& res : tokenize_results) {
131 for (int64_t id : res) {
132 token[idx] = id;
133 mask[idx] = 1;
134 idx++;
135 }
136
137 for (size_t i = res.size(); i < max_length; i++) {
138 token[idx] = 0;
139 mask[idx] = 0;
140 idx++;
141 }
142 }
143}
144
145const char* CustomOpClipBpeTokenizer::GetName() const {
146 return "CLIPTokenizer";
147}
148
149size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
150 return 1;
151}
152
153ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
154 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
155}
156size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
157 return 2;
158}
159
160ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
161 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
162}
163