openai/tiktoken at 0.2.0 (mirrored from https://github.com/openai/tiktoken)

src/lib.rs

// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::thread;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;

fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();

    // If you have n parts and m merges, this does O(mn) work
    // We could do something with a heap and do O(m log n) work

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    loop {
        // `<= 1` rather than `== 1`: for an empty piece, `parts.len() - 1` below would underflow
        if parts.len() <= 1 {
            break;
        }
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                *r
            } else {
                continue;
            };
            if min_rank.is_none() || rank < min_rank.unwrap().0 {
                min_rank = Some((rank, i));
            }
        }
        if let Some((_, i)) = min_rank {
            parts[i] = parts[i].start..parts[i + 1].end;
            parts.remove(i + 1);
        } else {
            break;
        }
    }
    parts
}
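The comment in `_byte_pair_merge` suggests a heap could replace the full rescan of adjacent pairs on every merge. A minimal sketch of that idea, assuming the same rank table shape (std's `HashMap` stands in for `FxHashMap` so the block is self-contained); `merge_with_heap` is a hypothetical name, not part of tiktoken. It uses lazy deletion: each pop is checked against the current parts and skipped if stale. Ties on rank go to the leftmost pair, which matches the strict `<` in the original loop, so the resulting parts should be identical.

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::ops::Range;

/// Hypothetical heap-based variant of `_byte_pair_merge`: same parts, but each merge
/// costs O(log n) heap work instead of rescanning every adjacent pair.
fn merge_with_heap(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<Range<usize>> {
    let n = piece.len();
    // Parts are identified by their start offset. `end[s]` is the exclusive end of the
    // part starting at `s`; it is only meaningful while `alive[s]` is true.
    let mut end: Vec<usize> = (1..=n).collect();
    let mut prev: Vec<Option<usize>> = (0..n).map(|i| i.checked_sub(1)).collect();
    let mut alive = vec![true; n];
    // Entries are (rank, left_start, right_start, right_end). `Reverse` turns the max-heap
    // into a min-heap, so the lowest rank pops first and ties go to the leftmost pair.
    let mut heap = BinaryHeap::new();

    // Push the candidate merge of the part starting at `left` with its right neighbour.
    let push = |heap: &mut BinaryHeap<Reverse<(usize, usize, usize, usize)>>,
                end: &[usize],
                left: usize| {
        let mid = end[left];
        if mid < n {
            let right_end = end[mid];
            if let Some(&rank) = ranks.get(&piece[left..right_end]) {
                heap.push(Reverse((rank, left, mid, right_end)));
            }
        }
    };

    for start in 0..n {
        push(&mut heap, &end, start);
    }

    while let Some(Reverse((_, left, mid, right_end))) = heap.pop() {
        // Lazy deletion: skip entries whose parts changed after they were pushed.
        if !alive[left] || end[left] != mid || !alive[mid] || end[mid] != right_end {
            continue;
        }
        // Merge [left, mid) and [mid, right_end) into [left, right_end).
        end[left] = right_end;
        alive[mid] = false;
        if right_end < n {
            prev[right_end] = Some(left);
        }
        // Only the two pairs touching the merged part can have changed.
        if let Some(p) = prev[left] {
            push(&mut heap, &end, p);
        }
        push(&mut heap, &end, left);
    }

    // Walk the surviving parts from left to right.
    let mut parts = Vec::new();
    let mut start = 0;
    while start < n {
        parts.push(start..end[start]);
        start = end[start];
    }
    parts
}
```

With ranks `ab=0, cd=1, abcd=2`, the piece `abcd` merges `ab`, then `cd`, then `abcd`, giving the single range `0..4`. Every merge pushes at most two entries, so the total work is roughly O((n + m) log n) including the initial pushes.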

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| ranks[&piece[p.start..p.end]])
        .collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks)
        .iter()
        .map(|p| &piece[p.start..p.end])
        .collect()
}

// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
// each thread.
//
// Threading
// =========
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
// So goodbye `rayon`! Let our Python users control the thread count etc.
//
// Caching
// =======
// The reference tokeniser has an LRU cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
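The hashing note says lookups could hash two-tuples of ints instead of byte slices. A minimal sketch of that idea, assuming (as the comment in `_byte_pair_merge` does) that a token's merge priority is its id, and that every single byte is a token; `build_pair_table` and `encode_with_pair_table` are hypothetical names, and std's `HashMap` stands in for `FxHashMap`. Because every part in the merge loop is already a token, looking up the concatenated bytes is equivalent to looking up the pair of part ids, provided the table holds every pair of tokens whose concatenation is itself a token.

```rust
use std::collections::HashMap;

/// Hypothetical: map (left_id, right_id) -> merged_id for every pair of tokens whose
/// concatenation is itself a token.
fn build_pair_table(encoder: &HashMap<Vec<u8>, usize>) -> HashMap<(usize, usize), usize> {
    let mut table = HashMap::new();
    for (bytes, &id) in encoder {
        for split in 1..bytes.len() {
            if let (Some(&left), Some(&right)) =
                (encoder.get(&bytes[..split]), encoder.get(&bytes[split..]))
            {
                table.insert((left, right), id);
            }
        }
    }
    table
}

/// Same merge order as the byte-keyed loop (lowest rank first, leftmost on ties),
/// but each lookup hashes two integers. Assumes each single byte of `piece` is a token.
fn encode_with_pair_table(
    piece: &[u8],
    encoder: &HashMap<Vec<u8>, usize>,
    table: &HashMap<(usize, usize), usize>,
) -> Vec<usize> {
    let mut ids: Vec<usize> = piece
        .iter()
        .map(|b| encoder[std::slice::from_ref(b)])
        .collect();
    loop {
        // (merged_id, index) tuples compare by id first, then by position.
        let best = ids
            .windows(2)
            .enumerate()
            .filter_map(|(i, w)| table.get(&(w[0], w[1])).map(|&merged| (merged, i)))
            .min();
        match best {
            Some((merged, i)) => {
                ids[i] = merged;
                ids.remove(i + 1);
            }
            None => return ids,
        }
    }
}
```

The table is built once per encoder, so the extra memory is one entry per (left, right) split of each multi-byte token.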

use std::num::NonZeroU64;
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can lay out a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
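`hash_current_thread` transmutes `ThreadId` and so depends on its private layout. A minimal safe alternative, assuming it is acceptable to hash the id rather than read the raw counter: `ThreadId` implements `Hash`, so it can be fed to std's `DefaultHasher` and reduced modulo the slot count, mirroring how `_get_tl_regex` indexes `regex_tls`. `thread_slot` and `PerThread` are hypothetical names. The trade-off is that a hash spreads consecutive threads less evenly than the raw counter does, so two live threads may share a slot more often; that costs some contention but not correctness, since every slot holds an identical clone.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::thread;

/// Hypothetical: choose a slot in `0..num_slots` for the calling thread without `unsafe`.
/// `num_slots` must be non-zero.
fn thread_slot(num_slots: usize) -> usize {
    // `DefaultHasher::new()` always starts from the same keys, so the same thread
    // gets the same slot every time within a process.
    let mut hasher = DefaultHasher::new();
    thread::current().id().hash(&mut hasher);
    (hasher.finish() % num_slots as u64) as usize
}

/// Hypothetical: a fixed pool of identical clones (like `regex_tls`), one picked per thread.
struct PerThread<T> {
    slots: Vec<T>,
}

impl<T: Clone> PerThread<T> {
    fn new(value: T, num_slots: usize) -> Self {
        assert!(num_slots > 0, "need at least one slot");
        PerThread {
            slots: vec![value; num_slots],
        }
    }

    fn get(&self) -> &T {
        &self.slots[thread_slot(self.slots.len())]
    }
}
```

Like the original, this never allocates per thread, so short-lived threads do not leak anything.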

const MAX_NUM_THREADS: usize = 128;
#[pyclass]
struct CoreBPE {
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    decoder: HashMap<usize, Vec<u8>>,
    special_tokens_decoder: HashMap<usize, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for token in tokens {
            let token_bytes = self
                .decoder
                .get(token)
                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
            ret.extend(token_bytes);
        }
        ret
    }

    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                ret.push(*token);
                continue;
            }
            ret.extend(&byte_pair_encode(piece, &self.encoder));
        }
        ret
    }

    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<usize>,
        mut last_piece_token_len: usize,
    ) -> (Vec<usize>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base "  !" gets split to " " + " !",
                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}
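`_encode_native` walks the text from one allowed special token to the next, stepping past matches of special tokens that are not in `allowed_special`, and runs the ordinary regex/BPE path on the text in between. A std-only sketch of just that splitting step, with plain substring search standing in for `special_regex`; `Segment`, `find_special` and `split_on_specials` are hypothetical names. One assumption to flag: when two special tokens start at the same position, this sketch prefers the longer one, whereas an alternation regex built from a `HashMap`'s keys prefers whichever alternative happens to come first in the pattern.

```rust
use std::collections::HashSet;

/// Hypothetical output type: ordinary text to be regex-split and BPE-encoded,
/// or a special token to be emitted as a single id.
#[derive(Debug, PartialEq)]
enum Segment<'a> {
    Ordinary(&'a str),
    Special(&'a str),
}

/// Leftmost occurrence of any non-empty special token at or after byte `from`.
/// Ties at the same position go to the longest token (an assumption, see above).
fn find_special<'s>(text: &str, from: usize, specials: &[&'s str]) -> Option<(usize, &'s str)> {
    specials
        .iter()
        .filter(|s| !s.is_empty())
        .filter_map(|s| text[from..].find(*s).map(|i| (from + i, *s)))
        .min_by(|a, b| a.0.cmp(&b.0).then(b.1.len().cmp(&a.1.len())))
}

/// Split `text` around allowed special tokens, skipping over disallowed ones.
fn split_on_specials<'a>(
    text: &'a str,
    specials: &[&str],
    allowed: &HashSet<&str>,
) -> Vec<Segment<'a>> {
    let mut out = Vec::new();
    let mut start = 0;
    loop {
        // Find the next *allowed* special token, if any.
        let mut search_from = start;
        let next = loop {
            match find_special(text, search_from, specials) {
                Some((pos, tok)) if allowed.contains(tok) => break Some((pos, tok.len())),
                // Disallowed: resume one character later (the original steps one byte,
                // which is equivalent when special tokens start with an ASCII character).
                Some((pos, tok)) => search_from = pos + tok.chars().next().unwrap().len_utf8(),
                None => break None,
            }
        };
        let end = next.map_or(text.len(), |(pos, _)| pos);
        if start < end {
            out.push(Segment::Ordinary(&text[start..end]));
        }
        match next {
            Some((pos, len)) => {
                out.push(Segment::Special(&text[pos..pos + len]));
                start = pos + len;
            }
            None => return out,
        }
    }
}
```

For example, with `<|endoftext|>` allowed and `<|fim|>` not, `a<|fim|>b<|endoftext|>c` splits into the ordinary text `a<|fim|>b`, the special token `<|endoftext|>`, and the ordinary text `c`.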
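The "easy bit" of `_encode_unstable_native` relies on a property of lexicographic order: every byte string that starts with a given prefix sorts at or after the prefix itself, and all such strings form one contiguous run. A minimal sketch of that lookup as a standalone function, assuming a slice sorted the way `sorted_token_bytes` is; `tokens_with_prefix` is a hypothetical name. It finds the end of the run with a second `partition_point` instead of stepping through it one element at a time as the original loop does.

```rust
/// Hypothetical: return every entry of `sorted` (sorted lexicographically) that starts
/// with `prefix`, as a contiguous sub-slice.
fn tokens_with_prefix<'a>(sorted: &'a [Vec<u8>], prefix: &[u8]) -> &'a [Vec<u8>] {
    // First entry that is not smaller than `prefix`: the run of matches starts here.
    let start = sorted.partition_point(|x| x.as_slice() < prefix);
    // Within the tail, matches come before non-matches, so the predicate is partitioned.
    let len = sorted[start..].partition_point(|x| x.starts_with(prefix));
    &sorted[start..start + len]
}
```

Both searches are binary searches, so the lookup costs O(log n) comparisons plus the size of the result.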

#[pymethods]
impl CoreBPE {
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
        };

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(encoder.len() == decoder.len());

        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
        py.allow_threads(|| self._encode_ordinary_native(text))
    }

    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        py.allow_threads(|| self._encode_native(text, &allowed_special).0)
    }

    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
        py.allow_threads(|| {
            match std::str::from_utf8(bytes) {
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
                        // Somewhat niche, but this may not be correct if we'd have had a regex
                        // split between the valid UTF-8 and the invalid bytes, which is why this
                        // method is private
                        let mut unstable_bytes =
                            self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                        tokens.truncate(tokens.len() - last_piece_token_len);
                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                    }
                    tokens
                }
            }
        })
    }

    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> Py<PyTuple> {
        let (tokens, completions) =
            py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
        let py_completions =
            PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
        (tokens, py_completions).into_py(py)
    }

    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
        let bytes = py.allow_threads(|| self._decode_native(&tokens));
        PyBytes::new(py, &bytes).into()
    }

    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new(py, x).into())
            .collect()
    }
}
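`_encode_bytes` depends on `Utf8Error::valid_up_to` to separate the longest valid UTF-8 prefix, which goes through the normal regex path, from the remaining bytes, which are handed to plain BPE together with the last piece. A minimal safe sketch of that split; `split_valid_utf8` is a hypothetical helper, and it re-validates the prefix with `from_utf8` instead of calling `from_utf8_unchecked`, trading one extra linear scan for not needing `unsafe`.

```rust
/// Hypothetical: split `bytes` into its longest valid UTF-8 prefix and everything after it.
fn split_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        // Fully valid: nothing is left over.
        Ok(text) => (text, &bytes[bytes.len()..]),
        Err(e) => {
            let (valid, rest) = bytes.split_at(e.valid_up_to());
            // `valid_up_to` guarantees this prefix is valid UTF-8, so this cannot fail.
            (std::str::from_utf8(valid).unwrap(), rest)
        }
    }
}
```

For instance, a truncated euro sign at the end of `b"hi\xE2\x82"` leaves `"hi"` as text and `[0xE2, 0x82]` as the bytes that still need BPE.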

#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<CoreBPE>()?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    #[test]
    fn very_simple_test() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 1);
        ranks.insert(b"cd".to_vec(), 2);

        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }
}