openai/tiktoken

Public

mirrored from https://github.com/openai/tiktokenAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
subquad2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tiktoken/core.py

439lines · modecode

1from __future__ import annotations
2
3import functools
4from concurrent.futures import ThreadPoolExecutor
5from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence
6
7import regex
8
9from tiktoken import _tiktoken
10
11if TYPE_CHECKING:
12 import numpy as np
13 import numpy.typing as npt
14
15
class Encoding:
    """Converts between text and token ids by delegating to a `_tiktoken.CoreBPE` instance."""

    def __init__(
        self,
        name: str,
        *,
        pat_str: str,
        mergeable_ranks: dict[bytes, int],
        special_tokens: dict[str, int],
        explicit_n_vocab: int | None = None,
    ):
        """Creates an Encoding object.

        See openai_public.py for examples of how to construct an Encoding object.

        Args:
            name: The name of the encoding. It should be clear from the name of the encoding
                what behaviour to expect, in particular, encodings with different special tokens
                should have different names.
            pat_str: A regex pattern string that is used to split the input text.
            mergeable_ranks: A dictionary mapping mergeable token bytes to their ranks. The ranks
                must correspond to merge priority.
            special_tokens: A dictionary mapping special token strings to their token values.
            explicit_n_vocab: The number of tokens in the vocabulary. If provided, it is checked
                that the number of mergeable tokens and special tokens is equal to this number.
        """
        self.name = name

        self._pat_str = pat_str
        self._mergeable_ranks = mergeable_ranks
        self._special_tokens = special_tokens

        # `default=0` lets an encoding have no special tokens at all.
        self.max_token_value = max(
            max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
        )
        if explicit_n_vocab:
            assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab
            assert self.max_token_value == explicit_n_vocab - 1

        self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)

    def __repr__(self) -> str:
        return f"<Encoding {self.name!r}>"

    # ====================
    # Encoding
    # ====================

    def encode_ordinary(self, text: str) -> list[int]:
        """Encodes a string into tokens, ignoring special tokens.

        This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary("hello world")
        [31373, 995]
        ```
        """
        try:
            return self._core_bpe.encode_ordinary(text)
        except UnicodeEncodeError:
            # See comment in encode
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode_ordinary(text)

    def encode(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        """Encodes a string into tokens.

        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.

        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.

        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises ValueError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```
        """
        allowed_special = self._check_special(text, allowed_special, disallowed_special)

        try:
            return self._core_bpe.encode(text, allowed_special)
        except UnicodeEncodeError:
            # BPE operates on bytes, but the regex operates on unicode. If we pass a str that is
            # invalid UTF-8 to Rust, it will rightfully complain. Here we do a quick and dirty
            # fixup for any surrogate pairs that may have sneaked their way into the text.
            # Technically, this introduces a place where encode + decode doesn't roundtrip a Python
            # string, but given that this is input we want to support, maybe that's okay.
            # Also we use errors="replace" to handle weird things like lone surrogates.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode(text, allowed_special)

    def encode_to_numpy(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> npt.NDArray[np.uint32]:
        """Encodes a string into tokens, returning a numpy array.

        Avoids the overhead of copying the token buffer into a Python list.

        See `encode` for more details on `allowed_special` and `disallowed_special`.
        """
        allowed_special = self._check_special(text, allowed_special, disallowed_special)

        # Imported lazily so that numpy is only required by callers of this method.
        import numpy as np

        # Forward the caller's `allowed_special`: passing the full special token set here
        # would encode every special token as special regardless of what was requested.
        buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
        return np.frombuffer(buffer, dtype=np.uint32)

    def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel, ignoring special tokens.

        This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(self.encode_ordinary, text))

    def encode_batch(
        self,
        text: list[str],
        *,
        num_threads: int = 8,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        ```
        >>> enc.encode_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        # Normalise once up front so each per-string `encode` call receives a ready-made
        # frozenset and does not have to rebuild it.
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if not isinstance(disallowed_special, frozenset):
            disallowed_special = frozenset(disallowed_special)

        encoder = functools.partial(
            self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
        )
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(encoder, text))

    def encode_with_unstable(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> tuple[list[int], list[list[int]]]:
        """Encodes a string into stable tokens and possible completion sequences.

        Note that the stable tokens will only represent a substring of `text`.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        This API should itself be considered unstable.

        ```
        >>> enc.encode_with_unstable("hello fanta")
        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])

        >>> text = "..."
        >>> stable_tokens, completions = enc.encode_with_unstable(text)
        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
        ```
        """
        allowed_special = self._check_special(text, allowed_special, disallowed_special)
        return self._core_bpe.encode_with_unstable(text, allowed_special)

    def encode_single_token(self, text_or_bytes: str | bytes) -> int:
        """Encodes text corresponding to a single token to its token value.

        NOTE: this will encode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.encode_single_token("hello")
        31373
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_token(text_or_bytes)

    # ====================
    # Decoding
    # ====================

    def decode_bytes(self, tokens: Sequence[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        ```
        >>> enc.decode_bytes([31373, 995])
        b'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens)

    def decode(self, tokens: Sequence[int], errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.

        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
        for instance, setting `errors="strict"`.

        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

    def decode_single_token_bytes(self, token: int) -> bytes:
        """Decodes a token into bytes.

        NOTE: this will decode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.decode_single_token_bytes(31373)
        b'hello'
        ```
        """
        return self._core_bpe.decode_single_token_bytes(token)

    def decode_tokens_bytes(self, tokens: Sequence[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising tokenisation.

        ```
        >>> enc.decode_tokens_bytes([31373, 995])
        [b'hello', b' world']
        ```
        """
        return [self.decode_single_token_bytes(token) for token in tokens]

    def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]:
        """Decodes a list of tokens into a string and a list of offsets.

        Each offset is the index into text corresponding to the start of each token.
        If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
        of the first character that contains bytes from the token.

        This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
        change in the future to be more permissive.

        ```
        >>> enc.decode_with_offsets([31373, 995])
        ('hello world', [0, 5])
        ```
        """
        token_bytes = self.decode_tokens_bytes(tokens)

        text_len = 0
        offsets = []
        for token in token_bytes:
            # Bytes in 0x80..0xBF are UTF-8 continuation bytes. A token that starts with one
            # continues a character begun in the previous token, so point at that character.
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            # Every byte that is not a continuation byte starts a new character.
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)

        # TODO: assess correctness for errors="ignore" and errors="replace"
        text = b"".join(token_bytes).decode("utf-8", errors="strict")
        return text, offsets

    def decode_batch(
        self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8
    ) -> list[str]:
        """Decodes a batch (list of lists of tokens) into a list of strings."""
        decoder = functools.partial(self.decode, errors=errors)
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(decoder, batch))

    def decode_bytes_batch(
        self, batch: Sequence[Sequence[int]], *, num_threads: int = 8
    ) -> list[bytes]:
        """Decodes a batch (list of lists of tokens) into a list of bytes."""
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(self.decode_bytes, batch))

    # ====================
    # Miscellaneous
    # ====================

    def token_byte_values(self) -> list[bytes]:
        """Returns the list of all token byte values."""
        return self._core_bpe.token_byte_values()

    @property
    def eot_token(self) -> int:
        """The token value of `<|endoftext|>`; raises `KeyError` if the encoding lacks it."""
        return self._special_tokens["<|endoftext|>"]

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        """The set of all special token strings known to this encoding."""
        return set(self._special_tokens.keys())

    def is_special_token(self, token: int) -> bool:
        """Returns whether `token` is the value of one of this encoding's special tokens."""
        assert isinstance(token, int)
        return token in self._special_token_values

    @property
    def n_vocab(self) -> int:
        """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
        return self.max_token_value + 1

    # ====================
    # Private
    # ====================

    @functools.cached_property
    def _special_token_values(self) -> frozenset[int]:
        """The token values of all special tokens, computed lazily from `_special_tokens`."""
        return frozenset(self._special_tokens.values())

    def _check_special(
        self,
        text: str,
        allowed_special: Literal["all"] | AbstractSet[str],
        disallowed_special: Literal["all"] | Collection[str],
    ) -> AbstractSet[str]:
        """Resolves the special token arguments shared by the `encode*` methods.

        Expands the "all" shorthands, then raises `ValueError` (via
        `raise_disallowed_special_token`) if `text` contains any disallowed special token.

        Returns:
            The resolved `allowed_special` set to hand to the backend.
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            # `_special_token_regex` is cached on its argument, which must be hashable.
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())
        return allowed_special

    def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:
        """Encodes text corresponding to bytes without a regex split.

        NOTE: this will not encode any special tokens.

        ```
        >>> enc._encode_single_piece("helloqqqq")
        [31373, 38227, 38227]
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_piece(text_or_bytes)

    def _encode_only_native_bpe(self, text: str) -> list[int]:
        """Encodes a string into tokens, but do regex splitting in Python."""
        pat = regex.compile(self._pat_str)
        ret = []
        for piece in pat.findall(text):
            # Go through `_encode_single_piece` so each str piece is UTF-8 encoded before it
            # reaches the backend, exactly as that method does for its own str input.
            ret.extend(self._encode_single_piece(piece))
        return ret

    def _encode_bytes(self, text: bytes) -> list[int]:
        """Encodes raw bytes by delegating to the backend's `_encode_bytes`."""
        return self._core_bpe._encode_bytes(text)

    def __getstate__(self) -> object:
        import tiktoken.registry

        # As an optimisation, pickle registered encodings by reference
        if self is tiktoken.registry.ENCODINGS.get(self.name):
            return self.name
        return {
            "name": self.name,
            "pat_str": self._pat_str,
            "mergeable_ranks": self._mergeable_ranks,
            "special_tokens": self._special_tokens,
        }

    def __setstate__(self, value: object) -> None:
        import tiktoken.registry

        # A bare string is the name of a registered encoding (see `__getstate__`); share
        # that encoding's state instead of rebuilding it.
        if isinstance(value, str):
            self.__dict__ = tiktoken.registry.get_encoding(value).__dict__
            return
        self.__init__(**value)
423
424
@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
    """Compiles a pattern that matches any one of the given special tokens literally.

    Each token is escaped before being joined into a single capturing alternation.
    Results are memoised per token set by `functools.lru_cache`, which is why the
    argument has to be a hashable `frozenset`.
    """
    alternatives = "|".join(map(regex.escape, tokens))
    return regex.compile("(" + alternatives + ")")
429
430
def raise_disallowed_special_token(token: str) -> NoReturn:
    """Raises a `ValueError` explaining that `token` is a disallowed special token.

    The message names the offending token and lists the ways to allow it, to encode it
    as normal text, or to switch the check off entirely.
    """
    message = (
        f"Encountered text corresponding to disallowed special token {token!r}.\n"
        "If you want this text to be encoded as a special token, "
        f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n"
        f"If you want this text to be encoded as normal text, disable the check for this token "
        f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n"
        "To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )
    raise ValueError(message)