openai/tiktoken

Public

mirrored from https://github.com/openai/tiktoken · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0.7.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_encoding.py

231 lines · mode: code

# Note: there are more tests than these; they're just not currently public :-)
2
3from typing import Callable
4
5import hypothesis
6import hypothesis.strategies as st
7import pytest
8
9import tiktoken
10
11from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
12
13
def test_simple():
    """Spot-check known token ids, then roundtrip the first 10k ids of every encoding."""
    known_cases = (
        # (encoding name, ids for "hello world", ids for "hello <|endoftext|>")
        ("gpt2", [31373, 995], [31373, 220, 50256]),
        ("cl100k_base", [15339, 1917], [15339, 220, 100257]),
    )
    for name, hello_world_ids, hello_eot_ids in known_cases:
        encoding = tiktoken.get_encoding(name)
        assert encoding.encode("hello world") == hello_world_ids
        assert encoding.decode(hello_world_ids) == "hello world"
        assert encoding.encode("hello <|endoftext|>", allowed_special="all") == hello_eot_ids

    # For every registered encoding, decode each id below 10_000 to its bytes
    # and check that re-encoding those bytes yields the same id.
    for name in tiktoken.list_encoding_names():
        encoding = tiktoken.get_encoding(name)
        for token_id in range(10_000):
            token_bytes = encoding.decode_single_token_bytes(token_id)
            assert encoding.encode_single_token(token_bytes) == token_id
29
30
def test_simple_repeated():
    """Runs of '0' of length 1..17 must encode to the expected gpt2 token ids."""
    enc = tiktoken.get_encoding("gpt2")
    # Maps the number of consecutive "0" characters to the expected ids.
    expected_ids_by_run_length = {
        1: [15],
        2: [405],
        3: [830],
        4: [2388],
        5: [20483],
        6: [10535],
        7: [24598],
        8: [8269],
        9: [10535, 830],
        10: [8269, 405],
        11: [8269, 830],
        12: [8269, 2388],
        13: [8269, 20483],
        14: [8269, 10535],
        15: [8269, 24598],
        16: [25645],
        17: [8269, 10535, 830],
    }
    for run_length, expected_ids in expected_ids_by_run_length.items():
        assert enc.encode("0" * run_length) == expected_ids
50
51
def test_simple_regex():
    """Spot-check cl100k_base on short inputs around apostrophes, newlines and spaces."""
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    # One space between the two newlines.
    assert enc.encode("today\n \n") == [31213, 27907]
    # Two spaces between the two newlines. This must be a different input from
    # the line above; encoding the same string cannot produce two different results.
    assert enc.encode("today\n  \n") == [31213, 14211]
59
60
def test_basic_encode():
    """Encode "hello world" with r50k_base, p50k_base and cl100k_base."""
    # r50k_base and p50k_base are expected to produce the same ids here.
    for name in ("r50k_base", "p50k_base"):
        assert tiktoken.get_encoding(name).encode("hello world") == [31373, 995]

    cl100k = tiktoken.get_encoding("cl100k_base")
    assert cl100k.encode("hello world") == [15339, 1917]
    # U+0085 between a space and a digit; per the expected ids it maps to two tokens.
    assert cl100k.encode(" \x850") == [220, 126, 227, 15]
71
72
def test_encode_empty():
    """Encoding the empty string yields an empty token list."""
    tokens = tiktoken.get_encoding("r50k_base").encode("")
    assert tokens == []
76
77
def test_encode_bytes():
    """The private bytes encoder handles input ending in an incomplete UTF-8 sequence."""
    encoder = tiktoken.get_encoding("cl100k_base")
    # A space, a complete 3-byte UTF-8 character, then a dangling 0xED lead byte.
    raw = b" \xec\x8b\xa4\xed"
    assert encoder._encode_bytes(raw) == [62085]
81
82
def test_encode_surrogate_pairs():
    """A surrogate pair encodes like its code point; a lone surrogate encodes like U+FFFD."""
    enc = tiktoken.get_encoding("cl100k_base")
    thumbs_up_ids = [9468, 239, 235]

    # U+1F44D THUMBS UP SIGN
    assert enc.encode("\U0001f44d") == thumbs_up_ids
    # The same character spelled as a UTF-16 surrogate pair is converted to the code point.
    assert enc.encode("\ud83d\udc4d") == thumbs_up_ids

    # A lone high surrogate is replaced, so it encodes like the replacement character.
    assert enc.encode("\ud83d") == enc.encode("\ufffd")
92
93
94# ====================
95# Roundtrip
96# ====================
97
98
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    """decode(encode(s)) and decode(encode_ordinary(s)) must reproduce each sample s."""
    enc = make_enc()
    # Leading/trailing padding variants must all be distinct, otherwise the
    # multi-space cases are silently never exercised.
    for value in (
        "hello",
        "hello ",
        "hello  ",  # two trailing spaces
        " hello",
        " hello ",
        " hello  ",  # two trailing spaces
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))
114
115
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
# Use the shared MAX_EXAMPLES budget, as test_hyp_special_ordinary does, so one
# setting controls every property test in this module.
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
    """Property: decoding the encoding of arbitrary text returns the text unchanged."""
    enc = make_enc()

    assert text == enc.decode(enc.encode(text))
123
124
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Every id below n_vocab that decodes to bytes must re-encode to the same id."""
    encoding = make_enc()

    for token_id in range(encoding.n_vocab):
        try:
            token_bytes = encoding.decode_single_token_bytes(token_id)
        except KeyError:
            # Presumably an id with no token assigned (a gap in the vocabulary); skip it.
            pass
        else:
            assert encoding.encode_single_token(token_bytes) == token_id
135
136
137# ====================
138# Special tokens
139# ====================
140
141
def test_special_token():
    """Check how allowed_special / disallowed_special control special-token handling."""
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    # With the disallow check turned off, the end-of-text id must not appear.
    # The default call and each disallowed_special setting below must raise ValueError.
    text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in enc.encode(text, disallowed_special=())
    for disallow_kwargs in (
        {},
        {"disallowed_special": "all"},
        {"disallowed_special": {"<|endoftext|>"}},
        {"disallowed_special": {"<|fim_prefix|>"}},
    ):
        with pytest.raises(ValueError):
            enc.encode(text, **disallow_kwargs)

    # Each row: keyword arguments for encode() -> whether (eot, fip, fim) appear.
    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    scenarios = (
        ({"disallowed_special": ()}, (False, False, False)),
        ({"allowed_special": "all", "disallowed_special": ()}, (True, True, True)),
        ({"allowed_special": "all", "disallowed_special": "all"}, (True, True, True)),
        ({"allowed_special": {"<|fim_prefix|>"}, "disallowed_special": ()}, (False, True, False)),
        ({"allowed_special": {"<|endoftext|>"}, "disallowed_special": ()}, (True, False, False)),
        ({"allowed_special": {"<|fim_middle|>"}, "disallowed_special": ()}, (False, False, True)),
    )
    for encode_kwargs, expected_presence in scenarios:
        tokens = enc.encode(text, **encode_kwargs)
        assert (eot in tokens, fip in tokens, fim in tokens) == expected_presence
191
192
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """Property: encode_ordinary agrees with encode when disallowed_special is empty."""
    encoding = make_enc()
    via_ordinary = encoding.encode_ordinary(text)
    via_encode = encoding.encode(text, disallowed_special=())
    assert via_ordinary == via_encode
199
200
201# ====================
202# Batch encoding
203# ====================
204
205
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Batch encoders must agree element-wise with their single-text counterparts."""
    enc = make_enc()
    texts = ["hello world", "goodbye world"]

    pairings = (
        (enc.encode_batch, enc.encode),
        (enc.encode_ordinary_batch, enc.encode_ordinary),
    )
    for encode_many, encode_one in pairings:
        # First a one-element batch, then both texts together.
        for size in (1, 2):
            batch = texts[:size]
            assert encode_many(batch) == [encode_one(t) for t in batch]
220
221
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
# Use the shared MAX_EXAMPLES budget, as test_hyp_special_ordinary does, so one
# setting controls every property test in this module.
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """Property: encode_batch matches per-item encode, and decode_batch inverts it."""
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    assert encoded == [enc.encode(t) for t in batch]
    decoded = enc.decode_batch(encoded)
    assert decoded == batch
232