microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
bfbfa5a3044ec8d1312f3782c78ea3b9246bf667

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/static_test/test_tokenizer.cc

249lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "gtest/gtest.h"
5#include "string_utils.h"
6#include "wordpiece_tokenizer.hpp"
7#include "bert_tokenizer.hpp"
8
9#include <clocale>
10
11
12class LocaleBaseTest : public testing::Test{
13 public:
14 // Remember that SetUp() is run immediately before a test starts.
15 void SetUp() override {
16#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
17 default_locale_ = std::locale().name();
18 std::setlocale(LC_CTYPE, "C");
19#else
20 default_locale_ = std::locale("").name();
21 std::setlocale(LC_CTYPE, "en_US.UTF-8");
22#endif
23 }
24 // TearDown() is invoked immediately after a test finishes.
25 void TearDown() override {
26 if (!default_locale_.empty()) {
27 std::setlocale(LC_CTYPE, default_locale_.c_str());
28 }
29 }
30
31 private:
32 std::string default_locale_;
33};
34
35TEST(tokenizer, bert_word_split) {
36 ustring ind("##");
37 ustring text("A AAA B BB");
38 std::vector<std::u32string> words;
39 KernelWordpieceTokenizer_Split(ind, text, words);
40 std::vector<std::u32string> expected{ustring("A"), ustring("AAA"), ustring("B"), ustring("BB")};
41 EXPECT_EQ(expected, words);
42
43 text = ustring(" A AAA B BB ");
44 KernelWordpieceTokenizer_Split(ind, text, words);
45 EXPECT_EQ(words, expected);
46}
47
// Builds the vocabulary used by the basic wordpiece test: every token is
// mapped to its zero-based position in the list below.
std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
  // The map is keyed by std::u32string and all tokens are plain ASCII, so they
  // are written directly as UTF-32 literals.
  const std::vector<std::u32string> vocab_tokens = {
      U"[UNK]",
      U"[CLS]",
      U"[SEP]",
      U"[PAD]",
      U"[MASK]",
      U"want",
      U"##want",
      U"##ed",
      U"wa",
      U"un",
      U"runn",
      U"##ing",
      U",",
      U"low",
      U"lowest",
  };

  std::unordered_map<std::u32string, int32_t> vocab;
  vocab.reserve(vocab_tokens.size());
  for (const auto& token : vocab_tokens) {
    // Read the size *before* operator[] inserts the key. In the single
    // expression `vocab[token] = vocab.size()` the order of those two steps is
    // unspecified before C++17, which could make the ids 1-based.
    const auto id = static_cast<int32_t>(vocab.size());
    vocab[token] = id;
  }
  return vocab;
}
72
73std::vector<ustring> ustring_vector_convertor(std::vector<std::string> input) {
74 std::vector<ustring> result;
75 for (const auto& str : input) {
76 result.emplace_back(str);
77 }
78 return result;
79}
80
81TEST(tokenizer, wordpiece_basic_tokenizer) {
82 auto vocab = get_vocabulary_basic();
83 std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
84 std::vector<ustring> tokens;
85 std::vector<int32_t> indices;
86 std::vector<int64_t> rows;
87 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
88 // EXPECT_EQ(indices, std::vector<int32_t>({9, 6, 7, 12, 10, 11}));
89 // EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
90}
91
// Builds the vocabulary used by the wordpiece tests: every token is mapped to
// its zero-based position in the list below (the ids the tests assert on).
std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
  // The map is keyed by std::u32string and all tokens are plain ASCII, so they
  // are written directly as UTF-32 literals.
  const std::vector<std::u32string> vocab_tokens = {
      U"[UNK]",   // 0
      U"[CLS]",   // 1
      U"[SEP]",   // 2
      U"want",    // 3
      U"##want",  // 4
      U"##ed",    // 5
      U"wa",      // 6
      U"un",      // 7
      U"runn",    // 8
      U"##ing",   // 9
  };

  std::unordered_map<std::u32string, int32_t> vocab;
  vocab.reserve(vocab_tokens.size());
  for (const auto& token : vocab_tokens) {
    // Read the size *before* operator[] inserts the key. In the single
    // expression `vocab[token] = vocab.size()` the order of those two steps is
    // unspecified before C++17, which could shift every id by one.
    const auto id = static_cast<int32_t>(vocab.size());
    vocab[token] = id;
  }
  return vocab;
}
111
112TEST(tokenizer, wordpiece_wordpiece_tokenizer) {
113 auto vocab = get_vocabulary_wordpiece();
114 std::vector<int32_t> indices;
115 std::vector<int64_t> rows;
116 std::vector<ustring> tokens;
117
118 std::vector<ustring> text = {ustring("unwanted running")}; // "un", "##want", "##ed", "runn", "##ing"
119 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
120 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
121 ustring("runn"), ustring("##ing")}));
122 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9}));
123 EXPECT_EQ(rows, std::vector<int64_t>({0, 5}));
124
125 text = std::vector<ustring>({ustring("unwantedX running")}); // "[UNK]", "runn", "##ing"
126 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
127 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
128 ustring("[UNK]"), ustring("runn"), ustring("##ing")}));
129 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, -1, 8, 9}));
130 EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
131
132 text = std::vector<ustring>({ustring("")}); //
133 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
134 EXPECT_EQ(tokens, std::vector<ustring>());
135 EXPECT_EQ(indices, std::vector<int32_t>());
136 EXPECT_EQ(rows, std::vector<int64_t>({0, 0}));
137}
138
139TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
140 auto vocab = get_vocabulary_wordpiece();
141 std::vector<int32_t> indices;
142 std::vector<int64_t> rows;
143 std::vector<ustring> tokens;
144
145 std::vector<int64_t> existing_indices({0, 2, 3});
146 std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
147 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
148 existing_indices.data(), existing_indices.size());
149 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
150 ustring("runn"), ustring("##ing"),
151 ustring("runn"), ustring("##ing")}));
152 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
153 EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
154}
155
156TEST_F(LocaleBaseTest, basic_tokenizer_chinese) {
157 ustring test_case = ustring("ÀÁÂÃÄÅÇÈÉÊËÌÍÎÑÒÓÔÕÖÚÜ\t䗓𨖷虴𨀐辘𧄋脟𩑢𡗶镇伢𧎼䪱轚榶𢑌㺽𤨡!#$%&(Tom@microsoft.com)*+,-./:;<=>?@[\\]^_`{|}~");
158 std::vector<ustring> expect_result = ustring_vector_convertor({"aaaaaaceeeeiiinooooouu",
159 "䗓", "𨖷", "虴", "𨀐", "辘", "𧄋", "脟", "𩑢", "𡗶", "镇", "伢", "𧎼", "䪱", "轚", "榶", "𢑌", "㺽", "𤨡",
160 "!", "#", "$", "%", "&", "(", "tom", "@", "microsoft", ".", "com", ")", "*", "+", ",", "-", ".", "/", ":",
161 ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~"});
162 BasicTokenizer tokenizer(true, true, true, true, true);
163 auto result = tokenizer.Tokenize(test_case);
164 EXPECT_EQ(result, expect_result);
165}
166
167TEST_F(LocaleBaseTest, basic_tokenizer_russia) {
168 ustring test_case = ustring("A $100,000 price-tag@big>small на русском языке");
169 std::vector<ustring> expect_result = ustring_vector_convertor({"a", "$", "100", ",", "000", "price", "-", "tag", "@", "big", ">", "small", "на", "русском", "языке"});
170 BasicTokenizer tokenizer(true, true, true, true, true);
171 auto result = tokenizer.Tokenize(test_case);
172 EXPECT_EQ(result, expect_result);
173}
174
175TEST_F(LocaleBaseTest, basic_tokenizer) {
176 ustring test_case = ustring("I mean, you’ll need something to talk about next Sunday, right?");
177 std::vector<ustring> expect_result = ustring_vector_convertor({"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"});
178 BasicTokenizer tokenizer(false, true, true, true, true);
179 auto result = tokenizer.Tokenize(test_case);
180 EXPECT_EQ(result, expect_result);
181}
182
183TEST(tokenizer, truncation_one_input) {
184 TruncateStrategy truncate("longest_first");
185
186 std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
187 std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
188
189 auto test_input = init_vector1;
190 truncate.Truncate(test_input, -1);
191 EXPECT_EQ(test_input, init_vector1);
192
193 test_input = init_vector1;
194 truncate.Truncate(test_input, 5);
195 EXPECT_EQ(test_input, std::vector<int64_t>({1, 2, 3, 4, 5}));
196
197 test_input = init_vector2;
198 truncate.Truncate(test_input, 6);
199 EXPECT_EQ(test_input, init_vector2);
200}
201
202TEST(tokenizer, truncation_longest_first) {
203 TruncateStrategy truncate("longest_first");
204
205 std::vector<int64_t> init_vector1({1, 2, 3, 4, 5, 6, 7, 9});
206 std::vector<int64_t> init_vector2({1, 2, 3, 4, 5});
207
208 auto test_input1 = init_vector1;
209 auto test_input2 = init_vector2;
210 truncate.Truncate(test_input1, test_input2, -1);
211 EXPECT_EQ(test_input1, init_vector1);
212 EXPECT_EQ(test_input2, init_vector2);
213
214 test_input1 = init_vector1;
215 test_input2 = init_vector2;
216 truncate.Truncate(test_input1, test_input2, 15);
217 EXPECT_EQ(test_input1, init_vector1);
218 EXPECT_EQ(test_input2, init_vector2);
219
220 test_input1 = init_vector1;
221 test_input2 = init_vector2;
222 truncate.Truncate(test_input1, test_input2, 14);
223 EXPECT_EQ(test_input1, init_vector1);
224 EXPECT_EQ(test_input2, init_vector2);
225
226 test_input1 = init_vector1;
227 test_input2 = init_vector2;
228 truncate.Truncate(test_input1, test_input2, 8);
229 EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4}));
230 EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
231
232 test_input1 = init_vector1;
233 test_input2 = init_vector2;
234 truncate.Truncate(test_input1, test_input2, 9);
235 EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
236 EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4}));
237
238 test_input1 = init_vector1;
239 test_input2 = init_vector2;
240 truncate.Truncate(test_input1, test_input2, 12);
241 EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7}));
242 EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5}));
243
244 test_input1 = init_vector2;
245 test_input2 = init_vector1;
246 truncate.Truncate(test_input1, test_input2, 12);
247 EXPECT_EQ(test_input1, std::vector<int64_t>({1, 2, 3, 4, 5}));
248 EXPECT_EQ(test_input2, std::vector<int64_t>({1, 2, 3, 4, 5, 6 ,7}));
249}