openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
tests/test_encoding.py
245 lines · mode: blame
c373d9b9Shantanu Jain3 years ago | 1 | # Note that there are more actual tests, they're just not currently public :-) |
| 2 | | |
| 3 | from typing import Callable | |
| 4 | | |
| 5 | import hypothesis | |
| 6 | import hypothesis.strategies as st | |
| 7 | import pytest | |
| 8 | | |
| 9 | import tiktoken | |
| 10 | | |
| 11 | from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES | |
| 12 | | |
| 13 | | |
def test_simple():
    """Smoke-test fixed token ids and single-token round trips.

    Checks known encodings of "hello world" (and of the end-of-text special
    token) for ``gpt2`` and ``cl100k_base``, then verifies for every listed
    encoding that the leading token ids survive a token -> bytes -> token
    round trip.
    """
    gpt2 = tiktoken.get_encoding("gpt2")
    gpt2_hello = [31373, 995]
    assert gpt2.encode("hello world") == gpt2_hello
    assert gpt2.decode(gpt2_hello) == "hello world"
    gpt2_with_eot = gpt2.encode("hello <|endoftext|>", allowed_special="all")
    assert gpt2_with_eot == [31373, 220, 50256]

    cl100k = tiktoken.get_encoding("cl100k_base")
    cl100k_hello = [15339, 1917]
    assert cl100k.encode("hello world") == cl100k_hello
    assert cl100k.decode(cl100k_hello) == "hello world"
    cl100k_with_eot = cl100k.encode("hello <|endoftext|>", allowed_special="all")
    assert cl100k_with_eot == [15339, 220, 100257]

    for name in tiktoken.list_encoding_names():
        encoding = tiktoken.get_encoding(name)
        # Only the first 10_000 ids (at most) are checked per encoding.
        upper_bound = min(10_000, encoding.max_token_value - 1)
        for token_id in range(upper_bound):
            raw = encoding.decode_single_token_bytes(token_id)
            assert encoding.encode_single_token(raw) == token_id
| 29 | | |
| 30 | | |
def test_simple_repeated():
    """Check the gpt2 tokenisation of runs of 1 to 17 consecutive "0" characters."""
    enc = tiktoken.get_encoding("gpt2")

    # expected_tokens[i] is the expected encoding of "0" repeated (i + 1) times.
    expected_tokens = [
        [15],
        [405],
        [830],
        [2388],
        [20483],
        [10535],
        [24598],
        [8269],
        [10535, 830],
        [8269, 405],
        [8269, 830],
        [8269, 2388],
        [8269, 20483],
        [8269, 10535],
        [8269, 24598],
        [25645],
        [8269, 10535, 830],
    ]
    for run_length, expected in enumerate(expected_tokens, start=1):
        assert enc.encode("0" * run_length) == expected
| 50 | | |
| 51 | | |
def test_simple_regex():
    """Check cl100k_base splitting around apostrophes and trailing whitespace.

    Each input below is distinct; the last two differ only in the number of
    spaces between the two newlines.
    """
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    assert enc.encode("today\n \n") == [31213, 27907]
    # Two spaces between the newlines: the same single-space input as the line
    # above cannot also encode to a different token list.
    # NOTE(review): the two-space form is reconstructed, not visible in the
    # collapsed original; confirm against the upstream test.
    assert enc.encode("today\n  \n") == [31213, 14211]
| 59 | | |
| 60 | | |
def test_basic_encode():
    """Check fixed encodings of short strings for r50k, p50k and cl100k."""
    # r50k_base and p50k_base are expected to agree on this input.
    for name in ("r50k_base", "p50k_base"):
        legacy = tiktoken.get_encoding(name)
        assert legacy.encode("hello world") == [31373, 995]

    cl100k = tiktoken.get_encoding("cl100k_base")
    assert cl100k.encode("hello world") == [15339, 1917]
    # Space, then U+0085, then the digit "0".
    assert cl100k.encode(" \x850") == [220, 126, 227, 15]
| 71 | | |
| 72 | | |
def test_encode_empty():
    """Encoding the empty string with r50k_base yields an empty token list."""
    r50k = tiktoken.get_encoding("r50k_base")
    tokens = r50k.encode("")
    assert tokens == []
| 76 | | |
| 77 | | |
def test_encode_bytes():
    """Check that ``_encode_bytes`` maps a fixed raw byte string to one token."""
    cl100k = tiktoken.get_encoding("cl100k_base")
    raw = b" \xec\x8b\xa4\xed"
    tokens = cl100k._encode_bytes(raw)
    assert tokens == [62085]
| 81 | | |
| 82 | | |
def test_encode_surrogate_pairs():
    """Check how cl100k_base encodes paired and lone surrogate code units."""
    cl100k = tiktoken.get_encoding("cl100k_base")
    thumbs_up_tokens = [9468, 239, 235]

    assert cl100k.encode("👍") == thumbs_up_tokens
    # A surrogate pair is converted to the code point it represents.
    assert cl100k.encode("\ud83d\udc4d") == thumbs_up_tokens

    # A lone surrogate encodes the same as the replacement character.
    lone_surrogate_tokens = cl100k.encode("\ud83d")
    replacement_tokens = cl100k.encode("�")
    assert lone_surrogate_tokens == replacement_tokens
| 92 | | |
| 93 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_catastrophically_repetitive(make_enc: Callable[[], tiktoken.Encoding]):
    """Round-trip very long single-pattern strings through encode/decode.

    For each short unit, three inputs are checked: the unit repeated 10_000
    times, that string with a leading space, and that string with a trailing
    newline added as well.
    """
    enc = make_enc()
    for unit in ("^", "0", "a", "'s", " ", "\n"):
        repeated = unit * 10_000
        with_leading_space = " " + repeated
        with_trailing_newline = with_leading_space + "\n"
        for sample in (repeated, with_leading_space, with_trailing_newline):
            assert sample == enc.decode(enc.encode(sample))
| 106 | | |
| 107 | | |
c373d9b9Shantanu Jain3 years ago | 108 | # ==================== |
| 109 | # Roundtrip | |
| 110 | # ==================== | |
| 111 | | |
| 112 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    """Round-trip a handful of fixed strings through both encode variants.

    Each value must decode back to itself after ``encode`` and after
    ``encode_ordinary``. The values cover no, single and double
    leading/trailing spaces plus a non-ASCII sample.
    """
    enc = make_enc()
    for value in (
        "hello",
        "hello ",
        # Double trailing space: distinct from the single-space case above,
        # which would otherwise be tested twice.
        "hello  ",
        " hello",
        " hello ",
        # Leading space plus double trailing space, again replacing an exact
        # duplicate of the previous entry.
        " hello  ",
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))
| 128 | | |
| 129 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
    """Property test: decoding the encoding of any generated text returns it."""
    enc = make_enc()
    tokens = enc.encode(text)
    restored = enc.decode(tokens)
    assert text == restored
| 137 | | |
| 138 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Every decodable token id must map back to itself via its bytes."""
    enc = make_enc()

    for token_id in range(enc.n_vocab):
        try:
            raw = enc.decode_single_token_bytes(token_id)
        except KeyError:
            # Ids for which decoding raises KeyError are skipped.
            pass
        else:
            assert enc.encode_single_token(raw) == token_id
| 149 | | |
| 150 | | |
| 151 | # ==================== | |
| 152 | # Special tokens | |
| 153 | # ==================== | |
| 154 | | |
| 155 | | |
def test_special_token():
    """Exercise ``allowed_special`` / ``disallowed_special`` on cl100k_base.

    Covers three things: special-token text is encoded as ordinary text when
    nothing is disallowed, ``encode`` raises ``ValueError`` when the text
    contains a disallowed special token (including with default arguments),
    and only the tokens named in ``allowed_special`` appear in the output.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in enc.encode(text, disallowed_special=())

    # Each of these keyword combinations must reject the text above.
    rejecting_kwargs = (
        {},
        {"disallowed_special": "all"},
        {"disallowed_special": {"<|endoftext|>"}},
        {"disallowed_special": {"<|fim_prefix|>"}},
    )
    for kwargs in rejecting_kwargs:
        with pytest.raises(ValueError):
            enc.encode(text, **kwargs)

    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    specials = (eot, fip, fim)

    # (encode keyword arguments, special tokens expected in the output)
    cases = (
        ({"disallowed_special": ()}, ()),
        ({"allowed_special": "all", "disallowed_special": ()}, specials),
        ({"allowed_special": "all", "disallowed_special": "all"}, specials),
        ({"allowed_special": {"<|fim_prefix|>"}, "disallowed_special": ()}, (fip,)),
        ({"allowed_special": {"<|endoftext|>"}, "disallowed_special": ()}, (eot,)),
        ({"allowed_special": {"<|fim_middle|>"}, "disallowed_special": ()}, (fim,)),
    )
    for kwargs, expected_present in cases:
        tokens = enc.encode(text, **kwargs)
        for special in specials:
            assert (special in tokens) == (special in expected_present)
| 205 | | |
| 206 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """Property test: ``encode_ordinary`` matches ``encode`` with nothing disallowed."""
    enc = make_enc()
    ordinary_tokens = enc.encode_ordinary(text)
    permissive_tokens = enc.encode(text, disallowed_special=())
    assert ordinary_tokens == permissive_tokens
| 213 | | |
| 214 | | |
| 215 | # ==================== | |
| 216 | # Batch encoding | |
| 217 | # ==================== | |
| 218 | | |
| 219 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Batch encoders must agree with encoding each text individually."""
    enc = make_enc()
    first, second = "hello world", "goodbye world"
    batches = ([first], [first, second])

    for batch in batches:
        assert enc.encode_batch(batch) == [enc.encode(item) for item in batch]

    for batch in batches:
        expected = [enc.encode_ordinary(item) for item in batch]
        assert enc.encode_ordinary_batch(batch) == expected
| 234 | | |
| 235 | | |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """Property test: batch encode matches per-item encode and round-trips."""
    enc = make_enc()

    token_lists = enc.encode_batch(batch)
    one_by_one = [enc.encode(item) for item in batch]
    assert token_lists == one_by_one

    restored = enc.decode_batch(token_lists)
    assert restored == batch
| 245 | assert decoded == batch |