openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
tiktoken/core.py
439 lines · mode: code
| 1 | from __future__ import annotations |
| 2 | |
| 3 | import functools |
| 4 | from concurrent.futures import ThreadPoolExecutor |
| 5 | from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence |
| 6 | |
| 7 | import regex |
| 8 | |
| 9 | from tiktoken import _tiktoken |
| 10 | |
| 11 | if TYPE_CHECKING: |
| 12 | import numpy as np |
| 13 | import numpy.typing as npt |
| 14 | |
| 15 | |
class Encoding:
    """A BPE tokeniser: a thin Python wrapper around the native `_tiktoken.CoreBPE`."""

    def __init__(
        self,
        name: str,
        *,
        pat_str: str,
        mergeable_ranks: dict[bytes, int],
        special_tokens: dict[str, int],
        explicit_n_vocab: int | None = None,
    ):
        """Creates an Encoding object.

        See openai_public.py for examples of how to construct an Encoding object.

        Args:
            name: The name of the encoding. It should be clear from the name of the encoding
                what behaviour to expect, in particular, encodings with different special tokens
                should have different names.
            pat_str: A regex pattern string that is used to split the input text.
            mergeable_ranks: A dictionary mapping mergeable token bytes to their ranks. The ranks
                must correspond to merge priority.
            special_tokens: A dictionary mapping special token strings to their token values.
            explicit_n_vocab: The number of tokens in the vocabulary. If provided, it is checked
                that the number of mergeable tokens and special tokens is equal to this number.

        Raises:
            AssertionError: If `explicit_n_vocab` is given and does not match the supplied
                vocabulary.
        """
        self.name = name

        self._pat_str = pat_str
        self._mergeable_ranks = mergeable_ranks
        self._special_tokens = special_tokens

        self.max_token_value = max(
            max(mergeable_ranks.values()), max(special_tokens.values(), default=0)
        )
        # Compare against None rather than relying on truthiness, so that an explicit
        # (and necessarily wrong) value of 0 is validated instead of silently skipped.
        if explicit_n_vocab is not None:
            assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab, (
                "explicit_n_vocab does not match the number of mergeable and special tokens"
            )
            assert self.max_token_value == explicit_n_vocab - 1, (
                "explicit_n_vocab does not match the largest token value"
            )

        self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)

    def __repr__(self) -> str:
        return f"<Encoding {self.name!r}>"

    # ====================
    # Encoding
    # ====================

    def _resolve_allowed_special(
        self,
        text: str,
        allowed_special: Literal["all"] | AbstractSet[str],
        disallowed_special: Literal["all"] | Collection[str],
    ) -> AbstractSet[str]:
        """Expands the "all" sentinels and rejects text containing a disallowed special token.

        Returns the concrete set of allowed special tokens to hand to the core encoder.

        Raises `ValueError` if `text` contains any token from the resolved `disallowed_special`.
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:
            # The compiled pattern is cached per token set, which needs a hashable key.
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())
        return allowed_special

    def encode_ordinary(self, text: str) -> list[int]:
        """Encodes a string into tokens, ignoring special tokens.

        This is equivalent to `encode(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary("hello world")
        [31373, 995]
        ```
        """
        try:
            return self._core_bpe.encode_ordinary(text)
        except UnicodeEncodeError:
            # See comment in encode
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode_ordinary(text)

    def encode(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        """Encodes a string into tokens.

        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.

        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.

        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises ValueError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```
        """
        allowed_special = self._resolve_allowed_special(text, allowed_special, disallowed_special)

        try:
            return self._core_bpe.encode(text, allowed_special)
        except UnicodeEncodeError:
            # BPE operates on bytes, but the regex operates on unicode. If we pass a str that is
            # invalid UTF-8 to Rust, it will rightfully complain. Here we do a quick and dirty
            # fixup for any surrogate pairs that may have sneaked their way into the text.
            # Technically, this introduces a place where encode + decode doesn't roundtrip a Python
            # string, but given that this is input we want to support, maybe that's okay.
            # Also we use errors="replace" to handle weird things like lone surrogates.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode(text, allowed_special)

    def encode_to_numpy(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> npt.NDArray[np.uint32]:
        """Encodes a string into tokens, returning a numpy array.

        Avoids the overhead of copying the token buffer into a Python list.

        See `encode` for more details on `allowed_special` and `disallowed_special`.
        """
        allowed_special = self._resolve_allowed_special(text, allowed_special, disallowed_special)

        import numpy as np

        # Pass the resolved `allowed_special`, not every special token: otherwise text that the
        # caller asked to be treated as ordinary (e.g. `disallowed_special=()`) would still be
        # emitted as special tokens, disagreeing with `encode`.
        try:
            buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
        except UnicodeEncodeError:
            # Same surrogate fixup as in `encode`.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
        return np.frombuffer(buffer, dtype=np.uint32)

    def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel, ignoring special tokens.

        This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster).

        ```
        >>> enc.encode_ordinary_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(self.encode_ordinary, text))

    def encode_batch(
        self,
        text: list[str],
        *,
        num_threads: int = 8,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[list[int]]:
        """Encodes a list of strings into tokens, in parallel.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        ```
        >>> enc.encode_batch(["hello world", "goodbye world"])
        [[31373, 995], [11274, 16390, 995]]
        ```
        """
        # Resolve the sentinels once up front so each worker call to `encode` receives
        # concrete sets (and a frozenset it does not need to rebuild).
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if not isinstance(disallowed_special, frozenset):
            disallowed_special = frozenset(disallowed_special)

        encoder = functools.partial(
            self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special
        )
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(encoder, text))

    def encode_with_unstable(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> tuple[list[int], list[list[int]]]:
        """Encodes a string into stable tokens and possible completion sequences.

        Note that the stable tokens will only represent a substring of `text`.

        See `encode` for more details on `allowed_special` and `disallowed_special`.

        This API should itself be considered unstable.

        ```
        >>> enc.encode_with_unstable("hello fanta")
        ([31373], [(277, 4910), (5113, 265), ..., (8842,)])

        >>> text = "..."
        >>> stable_tokens, completions = enc.encode_with_unstable(text)
        >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens))
        >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions)
        ```
        """
        allowed_special = self._resolve_allowed_special(text, allowed_special, disallowed_special)
        return self._core_bpe.encode_with_unstable(text, allowed_special)

    def encode_single_token(self, text_or_bytes: str | bytes) -> int:
        """Encodes text corresponding to a single token to its token value.

        NOTE: this will encode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.encode_single_token("hello")
        31373
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_token(text_or_bytes)

    # ====================
    # Decoding
    # ====================

    def decode_bytes(self, tokens: Sequence[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        ```
        >>> enc.decode_bytes([31373, 995])
        b'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens)

    def decode(self, tokens: Sequence[int], errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.

        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter,
        for instance, setting `errors=strict`.

        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

    def decode_single_token_bytes(self, token: int) -> bytes:
        """Decodes a token into bytes.

        NOTE: this will decode all special tokens.

        Raises `KeyError` if the token is not in the vocabulary.

        ```
        >>> enc.decode_single_token_bytes(31373)
        b'hello'
        ```
        """
        return self._core_bpe.decode_single_token_bytes(token)

    def decode_tokens_bytes(self, tokens: Sequence[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising tokenisation.

        ```
        >>> enc.decode_tokens_bytes([31373, 995])
        [b'hello', b' world']
        ```
        """
        return [self.decode_single_token_bytes(token) for token in tokens]

    def decode_with_offsets(self, tokens: Sequence[int]) -> tuple[str, list[int]]:
        """Decodes a list of tokens into a string and a list of offsets.

        Each offset is the index into text corresponding to the start of each token.
        If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
        of the first character that contains bytes from the token.

        This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
        change in the future to be more permissive.

        ```
        >>> enc.decode_with_offsets([31373, 995])
        ('hello world', [0, 5])
        ```
        """
        token_bytes = self.decode_tokens_bytes(tokens)

        text_len = 0
        offsets = []
        for token in token_bytes:
            # Bytes in 0x80..0xBF are UTF-8 continuation bytes. A token that starts with one
            # begins in the middle of a character, so it is attributed to the previous character.
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            # Each non-continuation byte starts exactly one character.
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)

        # TODO: assess correctness for errors="ignore" and errors="replace"
        text = b"".join(token_bytes).decode("utf-8", errors="strict")
        return text, offsets

    def decode_batch(
        self, batch: Sequence[Sequence[int]], *, errors: str = "replace", num_threads: int = 8
    ) -> list[str]:
        """Decodes a batch (list of lists of tokens) into a list of strings."""
        decoder = functools.partial(self.decode, errors=errors)
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(decoder, batch))

    def decode_bytes_batch(
        self, batch: Sequence[Sequence[int]], *, num_threads: int = 8
    ) -> list[bytes]:
        """Decodes a batch (list of lists of tokens) into a list of bytes."""
        with ThreadPoolExecutor(num_threads) as e:
            return list(e.map(self.decode_bytes, batch))

    # ====================
    # Miscellaneous
    # ====================

    def token_byte_values(self) -> list[bytes]:
        """Returns the list of all token byte values."""
        return self._core_bpe.token_byte_values()

    @property
    def eot_token(self) -> int:
        """The token value of `<|endoftext|>`. Raises `KeyError` if the encoding has none."""
        return self._special_tokens["<|endoftext|>"]

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        """The set of all special token strings known to this encoding."""
        return set(self._special_tokens.keys())

    @functools.cached_property
    def _special_token_values(self) -> frozenset[int]:
        """The set of all special token values, computed lazily for `is_special_token`."""
        return frozenset(self._special_tokens.values())

    def is_special_token(self, token: int) -> bool:
        """Returns True if `token` is the value of one of this encoding's special tokens."""
        assert isinstance(token, int)
        return token in self._special_token_values

    @property
    def n_vocab(self) -> int:
        """For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
        return self.max_token_value + 1

    # ====================
    # Private
    # ====================

    def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:
        """Encodes text corresponding to bytes without a regex split.

        NOTE: this will not encode any special tokens.

        ```
        >>> enc._encode_single_piece("helloqqqq")
        [31373, 38227, 38227]
        ```
        """
        if isinstance(text_or_bytes, str):
            text_or_bytes = text_or_bytes.encode("utf-8")
        return self._core_bpe.encode_single_piece(text_or_bytes)

    def _encode_only_native_bpe(self, text: str) -> list[int]:
        """Encodes a string into tokens, but do regex splitting in Python."""
        pat = regex.compile(self._pat_str)
        ret = []
        for piece in pat.findall(text):
            # The regex yields `str` pieces; go through `_encode_single_piece` so they are
            # converted to UTF-8 bytes before reaching the core, as that method does.
            ret.extend(self._encode_single_piece(piece))
        return ret

    def _encode_bytes(self, text: bytes) -> list[int]:
        """Encodes raw bytes by delegating to the core's `_encode_bytes`."""
        return self._core_bpe._encode_bytes(text)

    def __getstate__(self) -> object:
        """Returns the pickle state: the name for registered encodings, else the constructor args."""
        import tiktoken.registry

        # As an optimisation, pickle registered encodings by reference
        if self is tiktoken.registry.ENCODINGS.get(self.name):
            return self.name
        return {
            "name": self.name,
            "pat_str": self._pat_str,
            "mergeable_ranks": self._mergeable_ranks,
            "special_tokens": self._special_tokens,
        }

    def __setstate__(self, value: object) -> None:
        """Restores from pickle state produced by `__getstate__`."""
        import tiktoken.registry

        if isinstance(value, str):
            # Pickled by reference: adopt the registered encoding's attribute dict.
            self.__dict__ = tiktoken.registry.get_encoding(value).__dict__
            return
        self.__init__(**value)
| 423 | |
| 424 | |
@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
    """Compiles a pattern that matches any one of the given special token strings.

    The result is cached per token set, which is why `tokens` must be a hashable frozenset.
    """
    # Alternation tries alternatives left to right, so put longer tokens first: if one token is
    # a prefix of another, the longer one is reported. Sorting also makes the pattern (and hence
    # the token named in error messages) independent of set iteration order.
    ordered = sorted(tokens, key=lambda token: (-len(token), token))
    inner = "|".join(regex.escape(token) for token in ordered)
    return regex.compile(f"({inner})")
| 429 | |
| 430 | |
def raise_disallowed_special_token(token: str) -> NoReturn:
    """Raises a `ValueError` explaining that `token` is a disallowed special token.

    The message names the offending token and lists the ways to either allow it as a special
    token or have it encoded as ordinary text.
    """
    quoted = repr(token)
    message = "".join(
        (
            f"Encountered text corresponding to disallowed special token {quoted}.\n",
            "If you want this text to be encoded as a special token, ",
            f"pass it to `allowed_special`, e.g. `allowed_special={{{quoted}, ...}}`.\n",
            "If you want this text to be encoded as normal text, disable the check for this token ",
            f"by passing `disallowed_special=(enc.special_tokens_set - {{{quoted}}})`.\n",
            "To disable this check for all special tokens, pass `disallowed_special=()`.\n",
        )
    )
    raise ValueError(message)