// Source: openai/tiktoken, src/lib.rs (0.5.2)
// Mirrored from https://github.com/openai/tiktoken
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::thread;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;

fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // This is a vector of (start, rank).
    // The rank is of the byte pair starting at position start.
    // The rank of the last item in the vector is not a valid value.
    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
            if (start_idx + skip + 2) < parts.len() {
                ranks
                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
                    .copied()
            } else {
                None
            }
        }
    };

    // We look up the ranks once in the beginning and iteratively update
    // them during each merge, which reduces the number of rank lookups.
    // NOTE: `piece` must be non-empty, otherwise `parts.len() - 2` underflows.
    for i in 0..parts.len() - 2 {
        match get_rank(&parts, i, 0) {
            Some(rank) => {
                // usize::MAX is a sentinel value and cannot be a valid rank
                debug_assert!(rank != usize::MAX);
                parts[i].1 = rank;
            }
            None => {
                continue;
            }
        };
    }

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // It is important to consider that n is often small (<100), and as such
    // the cache-locality benefits outweigh the algorithmic complexity downsides
    // of the `parts` vector data structure above.

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    loop {
        if parts.len() == 1 {
            break;
        }

        // usize::MAX is a sentinel rank value allowing us to
        // take the min more quickly
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }

        if min_rank.0 != usize::MAX {
            let i = min_rank.1;

            // NOTE: We are about to remove parts[i + 1]. We do not do it
            // yet because there are cache-locality benefits to updating
            // parts[i] and parts[i-1] before removing, which could thrash
            // the cache. Thus, we update the rank calculation by skipping over
            // parts[i + 1], by invoking `get_rank` with `skip = 1`.
            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
            if i > 0 {
                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
            }

            parts.remove(i + 1);
        } else {
            break;
        }
    }
    let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
    for i in 0..parts.len() - 1 {
        out.push(f(parts[i].0..parts[i + 1].0));
    }
    out
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
}
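
/*
A minimal sketch of the greedy lowest-rank merge that `_byte_pair_merge` performs, written
against std's `HashMap` so it runs without `rustc_hash`. `naive_bpe_split` is a hypothetical
name, not part of this crate. It rescans every adjacent pair on each pass instead of caching
ranks in `parts`, so it shows the idea rather than the optimised loop.

```rust
use std::collections::HashMap;

// Hypothetical helper: repeatedly merges the adjacent pair with the lowest rank
// until no adjacent pair appears in `ranks`.
fn naive_bpe_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    // Boundaries between parts; initially every byte is its own part.
    let mut bounds: Vec<usize> = (0..=piece.len()).collect();
    loop {
        // Find the lowest-ranked adjacent pair as (rank, index of its first part).
        let mut best: Option<(usize, usize)> = None;
        for i in 0..bounds.len().saturating_sub(2) {
            if let Some(&rank) = ranks.get(&piece[bounds[i]..bounds[i + 2]]) {
                // Strict comparison: on ties the leftmost pair wins.
                if best.map_or(true, |(r, _)| rank < r) {
                    best = Some((rank, i));
                }
            }
        }
        match best {
            // Merging two parts means dropping the boundary between them.
            Some((_, i)) => {
                bounds.remove(i + 1);
            }
            None => break,
        }
    }
    bounds.windows(2).map(|w| &piece[w[0]..w[1]]).collect()
}
```

With ranks {"ab": 1, "cd": 2}, the input "abcd" splits into "ab" and "cd", the same case the
test at the bottom of this file checks for `byte_pair_split`.
*/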

// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is by having a (mostly) thread local clone of the regex for
// each thread.
//
// Threading
// =========
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
// So goodbye `rayon`! Let our Python users control the thread count etc.
//
// Caching
// =======
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an `RwLock` to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.

use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can lay out a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
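
/*
A hedged sketch of the "(mostly) thread local" slot lookup without `unsafe`. This is not what
the code above does: it swaps the transmute for hashing `ThreadId` with std's `DefaultHasher`.
`NUM_SLOTS`, `current_thread_slot` and `pick_slot` are hypothetical names. Hashed ids may
collide more often than the sequential counter does, so two live threads could share a slot.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::thread;

// Hypothetical constant playing the role of MAX_NUM_THREADS.
const NUM_SLOTS: usize = 128;

// Hypothetical helper: maps the current thread to a slot index by hashing its ThreadId.
fn current_thread_slot() -> usize {
    let mut hasher = DefaultHasher::new();
    thread::current().id().hash(&mut hasher);
    (hasher.finish() as usize) % NUM_SLOTS
}

// Hypothetical helper: picks this thread's entry from a pool of pre-cloned resources,
// e.g. one regex clone per slot. Panics if `slots` is empty.
fn pick_slot<T>(slots: &[T]) -> &T {
    &slots[current_thread_slot() % slots.len()]
}
```

A thread always gets the same slot for its whole lifetime, which is the property the regex
pool relies on.
*/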

const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
            let token_bytes = self
                .decoder
                .get(token)
                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
            ret.extend(token_bytes);
        }
        ret
    }

    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
            }
            ret.extend(&byte_pair_encode(piece, &self.encoder));
        }
        ret
    }

    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<usize>,
        mut last_piece_token_len: usize,
    ) -> (Vec<usize>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base "  !" gets split to " " + " !",
                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}
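
/*
A small sketch of the sorted-prefix scan that `_encode_unstable_native` runs over
`sorted_token_bytes`: binary-search to the first candidate with `partition_point`, then walk
forward while entries still start with the prefix. `with_prefix` is a hypothetical helper name,
and the input is assumed to be sorted in ascending byte order.

```rust
// Hypothetical helper: returns the contiguous run of entries in `sorted` that start with
// `prefix`. `sorted` must be in ascending lexicographic byte order.
fn with_prefix<'a>(sorted: &'a [Vec<u8>], prefix: &[u8]) -> &'a [Vec<u8>] {
    // Everything ordered before `prefix` cannot start with it, so skip past it.
    let start = sorted.partition_point(|x| x.as_slice() < prefix);
    // Entries sharing the prefix are adjacent in sorted order.
    let len = sorted[start..]
        .iter()
        .take_while(|x| x.starts_with(prefix))
        .count();
    &sorted[start..start + len]
}
```

The lookup costs one binary search plus one step per match, which is why the token bytes are
sorted once in `new` rather than scanned linearly on every call.
*/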

#[pymethods]
impl CoreBPE {
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
        };

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(
            encoder.len() == decoder.len(),
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        );

        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
        py.allow_threads(|| self._encode_ordinary_native(text))
    }

    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        py.allow_threads(|| self._encode_native(text, &allowed_special).0)
    }

    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
        py.allow_threads(|| {
            match std::str::from_utf8(bytes) {
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
                        // Somewhat niche, but this may not be correct if we'd have had a regex
                        // split between the valid UTF-8 and the invalid bytes, which is why this
                        // method is private
                        let mut unstable_bytes =
                            self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                        tokens.truncate(tokens.len() - last_piece_token_len);
                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                    }
                    tokens
                }
            }
        })
    }

    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> Py<PyTuple> {
        let (tokens, completions) =
            py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
        let py_completions =
            PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
        (tokens, py_completions).into_py(py)
    }

    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
        let bytes = py.allow_threads(|| self._decode_native(&tokens));
        PyBytes::new(py, &bytes).into()
    }

    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new(py, x).into())
            .collect()
    }
}
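
/*
`_encode_bytes` splits its input at `e.valid_up_to()` and reinterprets the prefix with
`from_utf8_unchecked`. A hedged sketch of the same split without `unsafe` follows. It is an
alternative for illustration, not what this file does. `split_valid_utf8` is a hypothetical
name, and the sketch pays for a second validation pass over the prefix.

```rust
// Hypothetical helper: splits `bytes` into its longest valid UTF-8 prefix and the
// remaining (invalid or truncated) tail.
fn split_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        // Fully valid input: the tail is empty.
        Ok(text) => (text, &bytes[bytes.len()..]),
        Err(e) => {
            let (valid, rest) = bytes.split_at(e.valid_up_to());
            // `valid_up_to` guarantees this prefix is valid, so the re-check cannot fail.
            let text = std::str::from_utf8(valid).expect("prefix is valid UTF-8");
            (text, rest)
        }
    }
}
```

The tail returned here corresponds to the bytes that `_encode_bytes` appends to
`unstable_bytes` before running `byte_pair_encode` on them.
*/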

#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<CoreBPE>()?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    #[test]
    fn very_simple_test() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 1);
        ranks.insert(b"cd".to_vec(), 2);

        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }
}
609