openai/tiktoken

Public

mirrored from https://github.com/openai/tiktokenAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
0.11.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

src/lib.rs

574lines · modecode

1use std::borrow::Borrow;
2use std::borrow::Cow;
3use std::collections::HashSet;
4use std::num::NonZeroU64;
5use std::thread;
6
7use fancy_regex::Regex;
8#[cfg(feature = "python")]
9use pyo3::prelude::*;
10use rustc_hash::FxHashMap as HashMap;
11
12#[cfg(feature = "python")]
13mod py;
14
/// Numeric rank of a token.
///
/// Used as the value type of the encoder maps (token bytes -> rank) and as the merge priority
/// in `_byte_pair_merge`, where the pair with the lowest rank is merged first.
pub type Rank = u32;
16
17fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
18 // This is a vector of (start, rank).
19 // The rank is of the pair starting at position start.
20 let mut parts = Vec::with_capacity(piece.len() + 1);
21
22 // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
23 // the way we currently do, this is equivalent. An easy way to break this would be to decouple
24 // merge priority from token index or to prevent specific token merges.
25 let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
26 for i in 0..piece.len() - 1 {
27 let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
28 if rank < min_rank.0 {
29 min_rank = (rank, i);
30 }
31 parts.push((i, rank));
32 }
33 parts.push((piece.len() - 1, Rank::MAX));
34 parts.push((piece.len(), Rank::MAX));
35
36 let get_rank = {
37 #[inline(always)]
38 |parts: &Vec<(usize, Rank)>, i: usize| {
39 if (i + 3) < parts.len() {
40 // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
41 // parts[i + 1], see comment in the main loop.
42 *ranks
43 .get(&piece[parts[i].0..parts[i + 3].0])
44 .unwrap_or(&Rank::MAX)
45 } else {
46 Rank::MAX
47 }
48 }
49 };
50
51 // If you have n parts and m merges, this does O(mn) work.
52 // We could do something with a heap and do O(m log n) work.
53 // n is often very small so considerations like cache-locality outweigh the algorithmic
54 // complexity downsides of the `parts` vector.
55 while min_rank.0 != Rank::MAX {
56 let i = min_rank.1;
57 // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
58 // `parts.remove(i + 1)` will thrash the cache.
59 if i > 0 {
60 parts[i - 1].1 = get_rank(&parts, i - 1);
61 }
62 parts[i].1 = get_rank(&parts, i);
63 parts.remove(i + 1);
64
65 min_rank = (Rank::MAX, usize::MAX);
66 for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
67 if rank < min_rank.0 {
68 min_rank = (rank, i);
69 }
70 }
71 }
72 parts
73}
74
75pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
76 if piece.len() == 1 {
77 return vec![ranks[piece]];
78 }
79 _byte_pair_merge(ranks, piece)
80 .windows(2)
81 .map(|part| ranks[&piece[part[0].0..part[1].0]])
82 .collect()
83}
84
85pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
86 assert!(piece.len() > 1);
87 _byte_pair_merge(ranks, piece)
88 .windows(2)
89 .map(|part| &piece[part[0].0..part[1].0])
90 .collect()
91}
92
93// Various performance notes:
94//
95// Regex
96// =====
97// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
98// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
99// the usual regex we use.
100//
101// However, given that we're using a regex parse-able by `regex`, there isn't much difference
102// between using the `regex` crate and using the `fancy_regex` crate.
103//
104// There is an important interaction between threading, `regex` and `fancy_regex`.
105// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
106// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
107// old `regex`, we don't hit this, because `find_iter` has a different code path.
108// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
109// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
110// each thread.
111//
112// Threading
113// =========
114// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
115// So goodbye `rayon`! Let thread count etc be in control of our Python users.
116//
117// Caching
118// =======
119// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
120// Originally, we had one too! Without it, we were only vaguely faster than Python.
121// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
122// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
124// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
126// about interior mutability, memory use, or cloning.
127//
128// Hashing
129// =======
130// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
131// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
132// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
133
/// Stand-in with the same size as `std::thread::ThreadId`, used only as a transmute target.
///
/// `repr(transparent)` pins this wrapper's layout to that of its single `NonZeroU64` field,
/// so the transmute below does not additionally depend on unspecified struct layout here.
#[repr(transparent)]
struct FakeThreadId(NonZeroU64);

/// Returns the current thread's numeric id as a `usize`, for use as an array-slot hash.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    //
    // Compile-time size checks: these fail to build if either type stops being 8 bytes.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    // SAFETY: both types are exactly 8 bytes (checked above) and `FakeThreadId` is a transparent
    // `NonZeroU64`. This still assumes `ThreadId` is itself a non-zero u64 internally, which is
    // a std implementation detail — re-verify when upgrading the toolchain.
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
148
149#[derive(Debug, Clone)]
150pub struct DecodeKeyError {
151 pub token: Rank,
152}
153
154impl std::fmt::Display for DecodeKeyError {
155 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
156 write!(f, "Invalid token for decoding: {}", self.token)
157 }
158}
159
160impl std::error::Error for DecodeKeyError {}
161
/// Error describing a failure to decode tokens.
#[derive(Debug, Clone)]
pub struct DecodeError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    /// Writes the fixed prefix followed by `message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        f.write_fmt(format_args!("Could not decode tokens: {}", message))
    }
}

impl std::error::Error for DecodeError {}
174
/// Error describing a failure to encode a string.
#[derive(Debug, Clone)]
pub struct EncodeError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for EncodeError {
    /// Writes the fixed prefix followed by `message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail = self.message.as_str();
        f.write_fmt(format_args!("Could not encode string: {}", detail))
    }
}

impl std::error::Error for EncodeError {}
187
/// Number of regex clones kept per `CoreBPE` in `regex_tls` / `special_regex_tls`.
/// The calling thread's id hash is reduced modulo this value to pick a clone.
const MAX_NUM_THREADS: usize = 128;
189
/// Byte-pair-encoding tokeniser state: vocabulary maps plus per-thread regex clones.
#[cfg_attr(feature = "python", pyclass)]
#[derive(Clone)]
pub struct CoreBPE {
    /// Token bytes -> rank for ordinary (non-special) tokens.
    encoder: HashMap<Vec<u8>, Rank>,
    /// Special token text -> rank.
    special_tokens_encoder: HashMap<String, Rank>,
    /// Rank -> token bytes; built as the inverse of `encoder`.
    decoder: HashMap<Rank, Vec<u8>>,
    /// Rank -> UTF-8 bytes of the special token text; inverse of `special_tokens_encoder`.
    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    /// `MAX_NUM_THREADS` clones of the splitting regex, indexed by thread-id hash.
    regex_tls: Vec<Regex>,
    /// `MAX_NUM_THREADS` clones of the special-token regex, indexed by thread-id hash.
    special_regex_tls: Vec<Regex>,
    /// All keys of `encoder`, sorted, so prefix matches can be found by binary search.
    sorted_token_bytes: Vec<Vec<u8>>,
}
201
202impl CoreBPE {
203 fn _get_tl_regex(&self) -> &Regex {
204 // See performance notes above for what this is about
205 // It's also a little janky, please make a better version of it!
206 // However, it's nice that this doesn't leak memory to short-lived threads
207 &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
208 }
209
210 fn _get_tl_special_regex(&self) -> &Regex {
211 &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
212 }
213
214 /// Decodes tokens into a list of bytes.
215 ///
216 /// The bytes are not gauranteed to be a valid utf-8 string.
217 fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
218 let mut ret = Vec::with_capacity(tokens.len() * 2);
219 for &token in tokens {
220 let token_bytes = match self.decoder.get(&token) {
221 Some(bytes) => bytes,
222 None => self
223 .special_tokens_decoder
224 .get(&token)
225 .ok_or(DecodeKeyError { token })?,
226 };
227 ret.extend(token_bytes);
228 }
229 Ok(ret)
230 }
231
232 pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
233 // This is the core of the encoding logic; the other functions in here
234 // just make things complicated :-)
235 let regex = self._get_tl_regex();
236 let mut ret = vec![];
237 for mat in regex.find_iter(text) {
238 let piece = mat.unwrap().as_str().as_bytes();
239 match self.encoder.get(piece) {
240 Some(token) => ret.push(*token),
241 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
242 }
243 }
244 ret
245 }
246
247 pub fn encode(
248 &self,
249 text: &str,
250 allowed_special: &HashSet<&str>,
251 ) -> Result<(Vec<Rank>, usize), EncodeError> {
252 let special_regex = self._get_tl_special_regex();
253 let regex = self._get_tl_regex();
254 let mut ret = vec![];
255
256 let mut start = 0;
257 let mut last_piece_token_len = 0;
258 loop {
259 let mut next_special;
260 let mut start_find = start;
261 loop {
262 // Find the next allowed special token, if any
263 next_special = special_regex.find_from_pos(text, start_find).unwrap();
264 match next_special {
265 Some(m) => {
266 if allowed_special.contains(&text[m.start()..m.end()]) {
267 break;
268 }
269 start_find = m.start() + 1;
270 }
271 None => break,
272 }
273 }
274 let end = next_special.map_or(text.len(), |m| m.start());
275
276 // Okay, here we go, compare this logic to encode_ordinary
277 for mat_res in regex.find_iter(&text[start..end]) {
278 let mat = match mat_res {
279 Ok(m) => m,
280 Err(e) => {
281 return Err(EncodeError {
282 message: format!("Regex error while tokenizing: {e}"),
283 });
284 }
285 };
286
287 let piece = mat.as_str().as_bytes();
288 if let Some(token) = self.encoder.get(piece) {
289 last_piece_token_len = 1;
290 ret.push(*token);
291 continue;
292 }
293 let tokens = byte_pair_encode(piece, &self.encoder);
294 last_piece_token_len = tokens.len();
295 ret.extend(&tokens);
296 }
297
298 match next_special {
299 // And here we push the special token
300 Some(m) => {
301 let piece = m.as_str();
302 let token = self.special_tokens_encoder[piece];
303 ret.push(token);
304 start = m.end();
305 last_piece_token_len = 0;
306 }
307 None => break,
308 }
309 }
310
311 // last_piece_token_len is how many tokens came from the last regex split. This is used
312 // for determining unstable tokens, since you can't merge across (stable) regex splits
313 Ok((ret, last_piece_token_len))
314 }
315
316 fn _increase_last_piece_token_len(
317 &self,
318 tokens: Vec<Rank>,
319 mut last_piece_token_len: usize,
320 ) -> (Vec<Rank>, usize) {
321 // Unfortunately, the locations where our regex splits can be unstable.
322 // For the purposes of determining unstable tokens, unstable regex splitting
323 // is only a problem if a split that was present disappears, since this can
324 // lead to merging of tokens otherwise thought to be stable.
325 // cl100k_base makes our life hard by including the \s*[\r\n]+
326 // pattern. This can e.g. cause "\n" + " " to become "\n \n".
327 // Here is a quick and dirty fix:
328 {
329 let token_is_all_space = |token| {
330 self.decoder
331 .get(token)
332 .map(|token_bytes| {
333 token_bytes
334 .iter()
335 .rev()
336 .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
337 })
338 .unwrap_or(false)
339 };
340 if last_piece_token_len > 0
341 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
342 {
343 while (last_piece_token_len < tokens.len())
344 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
345 {
346 last_piece_token_len += 1;
347 }
348 }
349 }
350 debug_assert!(last_piece_token_len <= tokens.len());
351
352 (tokens, last_piece_token_len)
353 }
354
355 pub fn _encode_unstable_native(
356 &self,
357 text: &str,
358 allowed_special: &HashSet<&str>,
359 ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
360 let (tokens, last_piece_token_len) = self.encode(text, allowed_special).unwrap();
361 if last_piece_token_len == 0 {
362 // If last_piece_token_len is zero, the last token was a special token and we have
363 // no unstable bytes
364 return (tokens, HashSet::new());
365 }
366 let (mut tokens, last_piece_token_len) =
367 self._increase_last_piece_token_len(tokens, last_piece_token_len);
368
369 let unstable_bytes = self
370 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
371 .unwrap();
372 tokens.truncate(tokens.len() - last_piece_token_len);
373
374 // TODO: we should try harder to find additional stable tokens
375 // This would reduce the amount of retokenising when determining completions
376 // Refer to the logic in an older version of this file
377
378 let mut completions = HashSet::new();
379 if unstable_bytes.is_empty() {
380 return (tokens, completions);
381 }
382
383 // This is the easy bit. Just find all single tokens that start with unstable_bytes
384 // (including tokens that exactly match unstable_bytes)
385 // Separating this from the loop below helps with performance in a common case.
386 let mut point = self
387 .sorted_token_bytes
388 .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
389 while point < self.sorted_token_bytes.len()
390 && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
391 {
392 completions.insert(vec![
393 self.encoder[self.sorted_token_bytes[point].as_slice()],
394 ]);
395 point += 1;
396 }
397
398 // Now apply even more brute force. At every (other) possible position for the straddling
399 // token, concatenate additional bytes from that token (if any) to unstable_bytes,
400 // and retokenise the whole thing and see what we get.
401 for i in 1..unstable_bytes.len() {
402 let prefix = &unstable_bytes[..i];
403 let suffix = &unstable_bytes[i..];
404 let mut point = self
405 .sorted_token_bytes
406 .partition_point(|x| x.as_slice() < suffix);
407 // TODO: Perf optimisation if suffix starts with " "?
408 while point < self.sorted_token_bytes.len()
409 && self.sorted_token_bytes[point].starts_with(suffix)
410 {
411 let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
412 let encoded = match std::str::from_utf8(&possibility) {
413 // Morally, this is byte_pair_encode(&possibility, &self.encoder)
414 // But we might have introduced a regex split which would prevent merges.
415 // (particularly possible in the presence of unstable regex splits)
416 // So convert to UTF-8 and do regex splitting.
417 // E.g. with cl100k_base " !" gets split to " " + " !",
418 // but byte_pair_encode(" !") != byte_pair_encode(" ")
419 Ok(s) => self.encode_ordinary(s),
420
421 // Technically, whether or not this arm is correct depends on whether there
422 // would be a regex split before the UTF-8 truncation point.
423 // Probably niche enough that no one will ever notice (after all, people didn't
424 // notice all the big holes in the previous unstable token implementation)
425 Err(_) => byte_pair_encode(&possibility, &self.encoder),
426 // Something like the following is intriguing but incorrect:
427 // Err(e) => self.encode_ordinary(unsafe {
428 // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
429 // }),
430 };
431 let mut seq = Vec::new();
432 let mut seq_len = 0;
433 for token in encoded {
434 seq.push(token);
435 seq_len += self.decoder[&token].len();
436 if seq_len >= unstable_bytes.len() {
437 break;
438 }
439 }
440 completions.insert(seq);
441 point += 1;
442 }
443 }
444
445 // This is also not straightforward. While we generally assume that regex splits are stable,
446 // unfortunately, they are not. That is, if adding bytes were to make a split appear in
447 // unstable_bytes, this could make tokens possible which our logic would otherwise think
448 // would be merged.
449 // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
450 // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
451 // Here is a quick and dirty fix:
452 // This isn't right if we ever remove \s+(?!\S)
453 if unstable_bytes.len() > 1 {
454 let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
455 if unstable_bytes.len() - last_decoded.1 > 0
456 && last_decoded.0.is_some_and(|c| c.is_whitespace())
457 {
458 let mut reencoded = byte_pair_encode(
459 &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
460 &self.encoder,
461 );
462 reencoded.extend(byte_pair_encode(
463 &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
464 &self.encoder,
465 ));
466 completions.insert(reencoded);
467 }
468 }
469
470 (tokens, completions)
471 }
472
473 pub fn new<E, SE, NSE>(
474 encoder: E,
475 special_tokens_encoder: SE,
476 pattern: &str,
477 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
478 where
479 E: IntoIterator<Item = (Vec<u8>, Rank)>,
480 SE: IntoIterator<Item = (String, Rank)>,
481 NSE: IntoIterator<Item = (String, (Rank, Rank))>,
482 {
483 Self::new_internal(
484 HashMap::from_iter(encoder),
485 HashMap::from_iter(special_tokens_encoder),
486 pattern,
487 )
488 }
489
490 fn new_internal(
491 encoder: HashMap<Vec<u8>, Rank>,
492 special_tokens_encoder: HashMap<String, Rank>,
493 pattern: &str,
494 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
495 let regex = Regex::new(pattern)?;
496
497 let special_regex = {
498 let parts = special_tokens_encoder
499 .keys()
500 .map(|s| fancy_regex::escape(s))
501 .collect::<Vec<_>>();
502 Regex::new(&parts.join("|"))?
503 };
504
505 let decoder: HashMap<Rank, Vec<u8>> =
506 encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
507
508 assert!(
509 encoder.len() == decoder.len(),
510 "Encoder and decoder must be of equal length. Encoder length: {}, decoder length: {}.\nMaybe you had duplicate token indices in your encoder?",
511 encoder.len(),
512 decoder.len()
513 );
514
515 let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
516 .iter()
517 .map(|(k, v)| (*v, k.as_bytes().to_vec()))
518 .collect();
519
520 // Clone because I don't know how to tell Rust I'm not going to change the map
521 let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
522 sorted_token_bytes.sort();
523
524 Ok(Self {
525 encoder,
526 special_tokens_encoder,
527 decoder,
528 special_tokens_decoder,
529 regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
530 special_regex_tls: (0..MAX_NUM_THREADS)
531 .map(|_| special_regex.clone())
532 .collect(),
533 sorted_token_bytes,
534 })
535 }
536
537 pub fn special_tokens(&self) -> HashSet<&str> {
538 self.special_tokens_encoder
539 .keys()
540 .map(|s| s.as_str())
541 .collect()
542 }
543
544 pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
545 let allowed_special = self.special_tokens();
546 self.encode(text, &allowed_special).unwrap().0
547 }
548}
549
#[cfg(test)]
mod tests {
    use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    use crate::{Rank, byte_pair_split};

    /// Two-entry vocabulary shared by the tests: `ab` has rank 0 and `cd` has rank 1.
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut vocab = HashMap::default();
        vocab.insert(b"ab".to_vec(), 0);
        vocab.insert(b"cd".to_vec(), 1);
        vocab
    }

    /// Two distinct known pairs are split into exactly those pairs.
    #[test]
    fn test_simple_characters() {
        let vocab = setup_ranks();
        let pieces = byte_pair_split(b"abcd", &vocab);
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }

    /// The same known pair occurring twice is split into two copies of it.
    #[test]
    fn test_repeated_characters() {
        let vocab = setup_ranks();
        let pieces = byte_pair_split(b"abab", &vocab);
        assert_eq!(pieces, vec![b"ab", b"ab"]);
    }
}
575