microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ba200b4a0e391b45c0df4f9b1a506f0a9f574dd4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/static_test/test_tokenizer.cc

119lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "gtest/gtest.h"
5#include "string_utils.h"
6#include "wordpiece_tokenizer.hpp"
7
8TEST(tokenizer, bert_word_split) {
9 ustring ind("##");
10 ustring text("A AAA B BB");
11 std::vector<std::u32string> words;
12 KernelWordpieceTokenizer_Split(ind, text, words);
13 std::vector<std::u32string> expected{ustring("A"), ustring("AAA"), ustring("B"), ustring("BB")};
14 EXPECT_EQ(expected, words);
15
16 text = ustring(" A AAA B BB ");
17 KernelWordpieceTokenizer_Split(ind, text, words);
18 EXPECT_EQ(words, expected);
19}
20
21std::unordered_map<std::u32string, int32_t> get_vocabulary_basic() {
22 std::vector<ustring> vocab_tokens = {
23 ustring("[UNK]"),
24 ustring("[CLS]"),
25 ustring("[SEP]"),
26 ustring("[PAD]"),
27 ustring("[MASK]"),
28 ustring("want"),
29 ustring("##want"),
30 ustring("##ed"),
31 ustring("wa"),
32 ustring("un"),
33 ustring("runn"),
34 ustring("##ing"),
35 ustring(","),
36 ustring("low"),
37 ustring("lowest"),
38 };
39 std::unordered_map<std::u32string, int32_t> vocab;
40 for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
41 vocab[*it] = vocab.size();
42 }
43 return vocab;
44}
45
46TEST(tokenizer, wordpiece_basic_tokenizer) {
47 auto vocab = get_vocabulary_basic();
48 std::vector<ustring> text = {ustring("UNwant\u00E9d,running")};
49 std::vector<ustring> tokens;
50 std::vector<int32_t> indices;
51 std::vector<int64_t> rows;
52 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
53 //EXPECT_EQ(indices, std::vector<int32_t>({9, 6, 7, 12, 10, 11}));
54 //EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
55}
56
57std::unordered_map<std::u32string, int32_t> get_vocabulary_wordpiece() {
58 std::vector<ustring> vocab_tokens = {
59 ustring("[UNK]"), // 0
60 ustring("[CLS]"), // 1
61 ustring("[SEP]"), // 2
62 ustring("want"), // 3
63 ustring("##want"), // 4
64 ustring("##ed"), // 5
65 ustring("wa"), // 6
66 ustring("un"), // 7
67 ustring("runn"), // 8
68 ustring("##ing"), // 9
69 };
70 std::unordered_map<std::u32string, int32_t> vocab;
71 for (auto it = vocab_tokens.begin(); it != vocab_tokens.end(); ++it) {
72 vocab[*it] = vocab.size();
73 }
74 return vocab;
75}
76
77TEST(tokenizer, wordpiece_wordpiece_tokenizer) {
78 auto vocab = get_vocabulary_wordpiece();
79 std::vector<int32_t> indices;
80 std::vector<int64_t> rows;
81 std::vector<ustring> tokens;
82
83 std::vector<ustring> text = {ustring("unwanted running")}; // "un", "##want", "##ed", "runn", "##ing"
84 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
85 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
86 ustring("runn"), ustring("##ing")}));
87 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9}));
88 EXPECT_EQ(rows, std::vector<int64_t>({0, 5}));
89
90 text = std::vector<ustring>({ustring("unwantedX running")}); // "[UNK]", "runn", "##ing"
91 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows);
92 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
93 ustring("[UNK]"), ustring("runn"), ustring("##ing")}));
94 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, -1, 8, 9}));
95 EXPECT_EQ(rows, std::vector<int64_t>({0, 6}));
96
97 text = std::vector<ustring>({ustring("")}); //
98 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[unk]"), text, tokens, indices, rows);
99 EXPECT_EQ(tokens, std::vector<ustring>());
100 EXPECT_EQ(indices, std::vector<int32_t>());
101 EXPECT_EQ(rows, std::vector<int64_t>({0, 0}));
102}
103
104TEST(tokenizer, bert_wordpiece_tokenizer_rows) {
105 auto vocab = get_vocabulary_wordpiece();
106 std::vector<int32_t> indices;
107 std::vector<int64_t> rows;
108 std::vector<ustring> tokens;
109
110 std::vector<int64_t> existing_indices({0, 2, 3});
111 std::vector<ustring> text = {ustring("unwanted"), ustring("running"), ustring("running")};
112 KernelWordpieceTokenizer_Tokenizer(vocab, ustring("##"), ustring("[UNK]"), text, tokens, indices, rows,
113 existing_indices.data(), existing_indices.size());
114 EXPECT_EQ(tokens, std::vector<ustring>({ustring("un"), ustring("##want"), ustring("##ed"),
115 ustring("runn"), ustring("##ing"),
116 ustring("runn"), ustring("##ing")}));
117 EXPECT_EQ(indices, std::vector<int32_t>({7, 4, 5, 8, 9, 8, 9}));
118 EXPECT_EQ(rows, std::vector<int64_t>({0, 5, 7}));
119}
120