openai/tiktoken

Public

mirrored from https://github.com/openai/tiktoken (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
next

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_encoding.py

245 lines · mode: blame

c373d9b9Shantanu Jain3 years ago1# Note that there are more actual tests, they're just not currently public :-)
2
3from typing import Callable
4
5import hypothesis
6import hypothesis.strategies as st
7import pytest
8
9import tiktoken
10
11from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
12
13
def test_simple():
    """Spot-check gpt2 / cl100k_base, then single-token round trips for every encoding."""
    # (encoding name, tokens for "hello world", tokens for "hello <|endoftext|>")
    expectations = [
        ("gpt2", [31373, 995], [31373, 220, 50256]),
        ("cl100k_base", [15339, 1917], [15339, 220, 100257]),
    ]
    for name, hello_tokens, eot_tokens in expectations:
        encoding = tiktoken.get_encoding(name)
        assert encoding.encode("hello world") == hello_tokens
        assert encoding.decode(hello_tokens) == "hello world"
        assert encoding.encode("hello <|endoftext|>", allowed_special="all") == eot_tokens

    # For every registered encoding, the first (at most 10k) token ids must map
    # bytes -> id back to themselves.
    for name in tiktoken.list_encoding_names():
        encoding = tiktoken.get_encoding(name)
        limit = min(10_000, encoding.max_token_value - 1)
        for token in range(limit):
            raw = encoding.decode_single_token_bytes(token)
            assert encoding.encode_single_token(raw) == token
29
30
def test_simple_repeated():
    """gpt2 splits runs of 1..17 zeros into the expected digit-chunk tokens."""
    encoding = tiktoken.get_encoding("gpt2")
    # run length of "0" -> expected token ids
    expected_by_run_length = {
        1: [15],
        2: [405],
        3: [830],
        4: [2388],
        5: [20483],
        6: [10535],
        7: [24598],
        8: [8269],
        9: [10535, 830],
        10: [8269, 405],
        11: [8269, 830],
        12: [8269, 2388],
        13: [8269, 20483],
        14: [8269, 10535],
        15: [8269, 24598],
        16: [25645],
        17: [8269, 10535, 830],
    }
    for run_length, expected in expected_by_run_length.items():
        assert encoding.encode("0" * run_length) == expected
50
51
def test_simple_regex():
    """Spot-check cl100k_base tokenization around apostrophes, newlines and spaces."""
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    assert enc.encode("today\n \n") == [31213, 27907]
    # Two spaces before the final newline. The single-space input above is
    # pinned to a different result, so this input must not be identical to it.
    assert enc.encode("today\n  \n") == [31213, 14211]
59
60
def test_basic_encode():
    """"hello world" encodes to the known ids across r50k/p50k/cl100k."""
    # r50k_base and p50k_base agree on this input.
    for name in ("r50k_base", "p50k_base"):
        encoding = tiktoken.get_encoding(name)
        assert encoding.encode("hello world") == [31373, 995]

    encoding = tiktoken.get_encoding("cl100k_base")
    assert encoding.encode("hello world") == [15339, 1917]
    # Space, the C1 control character U+0085, then the digit "0".
    assert encoding.encode(" \x850") == [220, 126, 227, 15]
71
72
def test_encode_empty():
    """Encoding the empty string yields an empty token list."""
    tokens = tiktoken.get_encoding("r50k_base").encode("")
    assert tokens == []
76
77
def test_encode_bytes():
    """_encode_bytes handles a byte string that is not valid UTF-8."""
    encoding = tiktoken.get_encoding("cl100k_base")
    # A space, a complete 3-byte UTF-8 sequence, then a dangling lead byte 0xED.
    raw = b" \xec\x8b\xa4\xed"
    assert encoding._encode_bytes(raw) == [62085]
81
82
def test_encode_surrogate_pairs():
    """Surrogates in str input are normalized before encoding."""
    encoding = tiktoken.get_encoding("cl100k_base")
    thumbs_up_tokens = [9468, 239, 235]

    assert encoding.encode("\U0001f44d") == thumbs_up_tokens  # U+1F44D, the thumbs-up emoji
    # A well-formed surrogate pair is treated as the code point it encodes.
    assert encoding.encode("\ud83d\udc4d") == thumbs_up_tokens
    # An unpaired surrogate encodes like U+FFFD REPLACEMENT CHARACTER.
    assert encoding.encode("\ud83d") == encoding.encode("\ufffd")
92
93
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_catastrophically_repetitive(make_enc: Callable[[], tiktoken.Encoding]):
    """Very long runs of a single piece of text must round-trip losslessly."""
    encoding = make_enc()
    for piece in ["^", "0", "a", "'s", " ", "\n"]:
        run = piece * 10_000
        # The bare run, the run with a leading space, and then with a trailing newline too.
        for candidate in (run, " " + run, " " + run + "\n"):
            assert candidate == encoding.decode(encoding.encode(candidate))
106
107
c373d9b9Shantanu Jain3 years ago108# ====================
109# Roundtrip
110# ====================
111
112
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Short strings with leading/trailing whitespace survive encode -> decode.

    Each value is checked through both ``encode`` and ``encode_ordinary``.
    """
    enc = make_enc()
    for value in (
        "hello",
        "hello ",
        "hello  ",  # two trailing spaces (previously a duplicate of "hello ")
        " hello",
        " hello ",
        " hello  ",  # leading space + two trailing spaces (previously a duplicate)
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))
127assert value == enc.decode(enc.encode_ordinary(value))
128
129
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text: str):
    """Arbitrary text round-trips through encode -> decode.

    Uses the shared MAX_EXAMPLES budget from test_helpers, as
    test_hyp_special_ordinary does, so one knob bounds all hypothesis tests.
    """
    enc = make_enc()

    assert text == enc.decode(enc.encode(text))
137
138
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Every decodable id in range(n_vocab) maps back to itself via encode_single_token."""
    encoding = make_enc()

    for token_id in range(encoding.n_vocab):
        try:
            raw = encoding.decode_single_token_bytes(token_id)
        except KeyError:
            # decode_single_token_bytes rejected this id (presumably unassigned),
            # so there is nothing to round-trip.
            pass
        else:
            assert encoding.encode_single_token(raw) == token_id
149
150
151# ====================
152# Special tokens
153# ====================
154
155
def test_special_token():
    """Special tokens are emitted only when allowed, and raise ValueError when disallowed."""
    encoding = tiktoken.get_encoding("cl100k_base")

    eot = encoding.encode_single_token("<|endoftext|>")
    assert eot == encoding.eot_token
    fip = encoding.encode_single_token("<|fim_prefix|>")
    fim = encoding.encode_single_token("<|fim_middle|>")

    short_text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in encoding.encode(short_text, disallowed_special=())
    # Default arguments must reject the special-token text...
    with pytest.raises(ValueError):
        encoding.encode(short_text)
    # ...as must explicitly disallowing all of them, or just one that occurs.
    for disallowed in ("all", {"<|endoftext|>"}, {"<|fim_prefix|>"}):
        with pytest.raises(ValueError):
            encoding.encode(short_text, disallowed_special=disallowed)

    long_text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    # (encode keyword arguments, special token ids expected in the output)
    cases = [
        ({"disallowed_special": ()}, set()),
        ({"allowed_special": "all", "disallowed_special": ()}, {eot, fip, fim}),
        ({"allowed_special": "all", "disallowed_special": "all"}, {eot, fip, fim}),
        ({"allowed_special": {"<|fim_prefix|>"}, "disallowed_special": ()}, {fip}),
        ({"allowed_special": {"<|endoftext|>"}, "disallowed_special": ()}, {eot}),
        ({"allowed_special": {"<|fim_middle|>"}, "disallowed_special": ()}, {fim}),
    ]
    for kwargs, expected_present in cases:
        tokens = encoding.encode(long_text, **kwargs)
        for special in (eot, fip, fim):
            assert (special in tokens) == (special in expected_present)
205
206
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """encode_ordinary agrees with encode when special-token checks are disabled."""
    encoding = make_enc()
    ordinary_tokens = encoding.encode_ordinary(text)
    unchecked_tokens = encoding.encode(text, disallowed_special=())
    assert ordinary_tokens == unchecked_tokens
213
214
215# ====================
216# Batch encoding
217# ====================
218
219
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Batch encoders match element-wise single-text encoding for 1- and 2-item batches."""
    encoding = make_enc()
    texts = ["hello world", "goodbye world"]

    for size in (1, 2):
        batch = texts[:size]
        assert encoding.encode_batch(batch) == [encoding.encode(t) for t in batch]

    for size in (1, 2):
        batch = texts[:size]
        assert encoding.encode_ordinary_batch(batch) == [
            encoding.encode_ordinary(t) for t in batch
        ]
234
235
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """encode_batch matches per-item encode, and decode_batch inverts it.

    Uses the shared MAX_EXAMPLES budget from test_helpers, as
    test_hyp_special_ordinary does. This matters most here because each
    example is a whole list of texts.
    """
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    assert encoded == [enc.encode(t) for t in batch]
    decoded = enc.decode_batch(encoded)
    assert decoded == batch