openai/tiktoken
Publicmirrored fromhttps://github.com/openai/tiktokenAvailable
src/lib.rs
703lines · modecode
| 1 | // This check is new and seems buggy (possibly with PyO3 interaction) |
| 2 | #![allow(clippy::borrow_deref_ref)] |
| 3 | |
| 4 | use std::collections::HashSet; |
| 5 | use std::num::NonZeroU64; |
| 6 | use std::thread; |
| 7 | |
| 8 | use fancy_regex::Regex; |
| 9 | use pyo3::exceptions; |
| 10 | use pyo3::prelude::*; |
| 11 | use pyo3::pybacked::PyBackedStr; |
| 12 | use pyo3::types::{PyBytes, PyList, PyTuple}; |
| 13 | use pyo3::PyResult; |
| 14 | use rustc_hash::FxHashMap as HashMap; |
| 15 | |
| 16 | type Rank = u32; |
| 17 | |
| 18 | fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { |
| 19 | // This is a vector of (start, rank). |
| 20 | // The rank is of the pair starting at position start. |
| 21 | let mut parts = Vec::with_capacity(piece.len() + 1); |
| 22 | |
| 23 | // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE |
| 24 | // the way we currently do, this is equivalent. An easy way to break this would be to decouple |
| 25 | // merge priority from token index or to prevent specific token merges. |
| 26 | let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX); |
| 27 | for i in 0..piece.len() - 1 { |
| 28 | let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); |
| 29 | if rank < min_rank.0 { |
| 30 | min_rank = (rank, i); |
| 31 | } |
| 32 | parts.push((i, rank)); |
| 33 | } |
| 34 | parts.push((piece.len() - 1, Rank::MAX)); |
| 35 | parts.push((piece.len(), Rank::MAX)); |
| 36 | |
| 37 | let get_rank = { |
| 38 | #[inline(always)] |
| 39 | |parts: &Vec<(usize, Rank)>, i: usize| { |
| 40 | if (i + 3) < parts.len() { |
| 41 | // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted |
| 42 | // parts[i + 1], see comment in the main loop. |
| 43 | *ranks |
| 44 | .get(&piece[parts[i].0..parts[i + 3].0]) |
| 45 | .unwrap_or(&Rank::MAX) |
| 46 | } else { |
| 47 | Rank::MAX |
| 48 | } |
| 49 | } |
| 50 | }; |
| 51 | |
| 52 | // If you have n parts and m merges, this does O(mn) work. |
| 53 | // We could do something with a heap and do O(m log n) work. |
| 54 | // n is often very small so considerations like cache-locality outweigh the algorithmic |
| 55 | // complexity downsides of the `parts` vector. |
| 56 | while min_rank.0 != Rank::MAX { |
| 57 | let i = min_rank.1; |
| 58 | // Update parts[i] and parts[i - 1] before removing parts[i + 1], since |
| 59 | // `parts.remove(i + 1)` will thrash the cache. |
| 60 | if i > 0 { |
| 61 | parts[i - 1].1 = get_rank(&parts, i - 1); |
| 62 | } |
| 63 | parts[i].1 = get_rank(&parts, i); |
| 64 | parts.remove(i + 1); |
| 65 | |
| 66 | min_rank = (Rank::MAX, usize::MAX); |
| 67 | for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { |
| 68 | if rank < min_rank.0 { |
| 69 | min_rank = (rank, i); |
| 70 | } |
| 71 | } |
| 72 | } |
| 73 | parts |
| 74 | } |
| 75 | |
| 76 | pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> { |
| 77 | if piece.len() == 1 { |
| 78 | return vec![ranks[piece]]; |
| 79 | } |
| 80 | _byte_pair_merge(ranks, piece) |
| 81 | .windows(2) |
| 82 | .map(|part| ranks[&piece[part[0].0..part[1].0]]) |
| 83 | .collect() |
| 84 | } |
| 85 | |
| 86 | pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> { |
| 87 | assert!(piece.len() > 1); |
| 88 | _byte_pair_merge(ranks, piece) |
| 89 | .windows(2) |
| 90 | .map(|part| &piece[part[0].0..part[1].0]) |
| 91 | .collect() |
| 92 | } |
| 93 | |
| 94 | // Various performance notes: |
| 95 | // |
| 96 | // Regex |
| 97 | // ===== |
| 98 | // Most of the time is spent in regex. The easiest way to speed this up is by using less fancy |
| 99 | // regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than |
| 100 | // the usual regex we use. |
| 101 | // |
| 102 | // However, given that we're using a regex parse-able by `regex`, there isn't much difference |
| 103 | // between using the `regex` crate and using the `fancy_regex` crate. |
| 104 | // |
| 105 | // There is an important interaction between threading, `regex` and `fancy_regex`. |
| 106 | // When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on |
| 107 | // some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain |
| 108 | // old `regex`, we don't hit this, because `find_iter` has a different code path. |
| 109 | // Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md |
| 110 | // Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for |
| 111 | // each thread. |
| 112 | // |
| 113 | // Threading |
| 114 | // ========= |
| 115 | // I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. |
| 116 | // So goodbye `rayon`! Let thread count etc be in control of our Python users. |
| 117 | // |
| 118 | // Caching |
| 119 | // ======= |
| 120 | // The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. |
| 121 | // Originally, we had one too! Without it, we were only vaguely faster than Python. |
| 122 | // I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance |
| 123 | // noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect |
| 124 | // multi-threaded performance even when I only had readers (maybe I messed something up?). |
| 125 | // Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! |
| 126 | // These are exactly the set of merges that are likely to be hot. And now we don't have to think |
| 127 | // about interior mutability, memory use, or cloning. |
| 128 | // |
| 129 | // Hashing |
| 130 | // ======= |
| 131 | // We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? |
| 132 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made |
| 133 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. |
| 134 | |
/// Stand-in type that `hash_current_thread` transmutes a `std::thread::ThreadId` into so
/// that the numeric id can be read out as a `NonZeroU64`.
pub struct FakeThreadId(NonZeroU64);
| 136 | |
| 137 | fn hash_current_thread() -> usize { |
| 138 | // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter |
| 139 | // that works great for our use case of avoiding collisions in our array. Unfortunately, |
| 140 | // it's private. However, there are only so many ways you can layout a u64, so just transmute |
| 141 | // https://github.com/rust-lang/rust/issues/67939 |
| 142 | const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()]; |
| 143 | const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()]; |
| 144 | let x = unsafe { |
| 145 | std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0 |
| 146 | }; |
| 147 | u64::from(x) as usize |
| 148 | } |
| 149 | |
| 150 | #[derive(Debug, Clone)] |
| 151 | struct DecodeKeyError { |
| 152 | token: Rank, |
| 153 | } |
| 154 | |
| 155 | impl std::fmt::Display for DecodeKeyError { |
| 156 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
| 157 | write!(f, "Invalid token for decoding: {}", self.token) |
| 158 | } |
| 159 | } |
| 160 | |
// Number of regex clones kept per `CoreBPE`; a thread picks its clone by taking its
// thread id modulo this value.
const MAX_NUM_THREADS: usize = 128;
| 162 | |
/// BPE tokeniser state exposed to Python.
#[pyclass]
struct CoreBPE {
    // Token bytes -> rank for ordinary tokens.
    encoder: HashMap<Vec<u8>, Rank>,
    // Special token text -> rank.
    special_tokens_encoder: HashMap<String, Rank>,
    // Inverse of `encoder`, built in `new`.
    decoder: HashMap<Rank, Vec<u8>>,
    // Inverse of `special_tokens_encoder` (values are the UTF-8 bytes of the text).
    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    // MAX_NUM_THREADS clones of the splitting regex, indexed by thread id modulo that count.
    regex_tls: Vec<Regex>,
    // Same arrangement for the regex that matches any special token.
    special_regex_tls: Vec<Regex>,
    // All keys of `encoder`, sorted, so prefix ranges can be found by binary search.
    sorted_token_bytes: Vec<Vec<u8>>,
}
| 173 | |
| 174 | impl CoreBPE { |
| 175 | fn _get_tl_regex(&self) -> &Regex { |
| 176 | // See performance notes above for what this is about |
| 177 | // It's also a little janky, please make a better version of it! |
| 178 | // However, it's nice that this doesn't leak memory to short-lived threads |
| 179 | &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 180 | } |
| 181 | |
| 182 | fn _get_tl_special_regex(&self) -> &Regex { |
| 183 | &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 184 | } |
| 185 | |
| 186 | fn _decode_native(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> { |
| 187 | let mut ret = Vec::with_capacity(tokens.len() * 2); |
| 188 | for &token in tokens { |
| 189 | let token_bytes = match self.decoder.get(&token) { |
| 190 | Some(bytes) => bytes, |
| 191 | None => self |
| 192 | .special_tokens_decoder |
| 193 | .get(&token) |
| 194 | .ok_or(DecodeKeyError { token })?, |
| 195 | }; |
| 196 | ret.extend(token_bytes); |
| 197 | } |
| 198 | Ok(ret) |
| 199 | } |
| 200 | |
| 201 | fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> { |
| 202 | // This is the core of the encoding logic; the other functions in here |
| 203 | // just make things complicated :-) |
| 204 | let regex = self._get_tl_regex(); |
| 205 | let mut ret = vec![]; |
| 206 | for mat in regex.find_iter(text) { |
| 207 | let piece = mat.unwrap().as_str().as_bytes(); |
| 208 | match self.encoder.get(piece) { |
| 209 | Some(token) => ret.push(*token), |
| 210 | None => ret.extend(&byte_pair_encode(piece, &self.encoder)), |
| 211 | } |
| 212 | } |
| 213 | ret |
| 214 | } |
| 215 | |
| 216 | fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) { |
| 217 | let special_regex = self._get_tl_special_regex(); |
| 218 | let regex = self._get_tl_regex(); |
| 219 | let mut ret = vec![]; |
| 220 | |
| 221 | let mut start = 0; |
| 222 | let mut last_piece_token_len = 0; |
| 223 | loop { |
| 224 | let mut next_special; |
| 225 | let mut start_find = start; |
| 226 | loop { |
| 227 | // Find the next allowed special token, if any |
| 228 | next_special = special_regex.find_from_pos(text, start_find).unwrap(); |
| 229 | match next_special { |
| 230 | Some(m) => { |
| 231 | if allowed_special.contains(&text[m.start()..m.end()]) { |
| 232 | break; |
| 233 | } |
| 234 | start_find = m.start() + 1; |
| 235 | } |
| 236 | None => break, |
| 237 | } |
| 238 | } |
| 239 | let end = next_special.map_or(text.len(), |m| m.start()); |
| 240 | |
| 241 | // Okay, here we go, compare this logic to _encode_ordinary_native |
| 242 | for mat in regex.find_iter(&text[start..end]) { |
| 243 | let piece = mat.unwrap().as_str().as_bytes(); |
| 244 | if let Some(token) = self.encoder.get(piece) { |
| 245 | last_piece_token_len = 1; |
| 246 | ret.push(*token); |
| 247 | continue; |
| 248 | } |
| 249 | let tokens = byte_pair_encode(piece, &self.encoder); |
| 250 | last_piece_token_len = tokens.len(); |
| 251 | ret.extend(&tokens); |
| 252 | } |
| 253 | |
| 254 | match next_special { |
| 255 | // And here we push the special token |
| 256 | Some(m) => { |
| 257 | let piece = m.as_str(); |
| 258 | let token = self.special_tokens_encoder[piece]; |
| 259 | ret.push(token); |
| 260 | start = m.end(); |
| 261 | last_piece_token_len = 0; |
| 262 | } |
| 263 | None => break, |
| 264 | } |
| 265 | } |
| 266 | |
| 267 | // last_piece_token_len is how many tokens came from the last regex split. This is used |
| 268 | // for determining unstable tokens, since you can't merge across (stable) regex splits |
| 269 | (ret, last_piece_token_len) |
| 270 | } |
| 271 | |
| 272 | fn _increase_last_piece_token_len( |
| 273 | &self, |
| 274 | tokens: Vec<Rank>, |
| 275 | mut last_piece_token_len: usize, |
| 276 | ) -> (Vec<Rank>, usize) { |
| 277 | // Unfortunately, the locations where our regex splits can be unstable. |
| 278 | // For the purposes of determining unstable tokens, unstable regex splitting |
| 279 | // is only a problem if a split that was present disappears, since this can |
| 280 | // lead to merging of tokens otherwise thought to be stable. |
| 281 | // cl100k_base makes our life hard by including the \s*[\r\n]+ |
| 282 | // pattern. This can e.g. cause "\n" + " " to become "\n \n". |
| 283 | // Here is a quick and dirty fix: |
| 284 | { |
| 285 | let token_is_all_space = |token| { |
| 286 | self.decoder |
| 287 | .get(token) |
| 288 | .map(|token_bytes| { |
| 289 | token_bytes |
| 290 | .iter() |
| 291 | .rev() |
| 292 | .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) |
| 293 | }) |
| 294 | .unwrap_or(false) |
| 295 | }; |
| 296 | if last_piece_token_len > 0 |
| 297 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) |
| 298 | { |
| 299 | while (last_piece_token_len < tokens.len()) |
| 300 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) |
| 301 | { |
| 302 | last_piece_token_len += 1; |
| 303 | } |
| 304 | } |
| 305 | } |
| 306 | debug_assert!(last_piece_token_len <= tokens.len()); |
| 307 | |
| 308 | (tokens, last_piece_token_len) |
| 309 | } |
| 310 | |
/// Encodes `text` and enumerates the possible token sequences for its unstable tail.
///
/// Returns `(stable_tokens, completions)`. The tokens of the last regex piece (widened by
/// `_increase_last_piece_token_len`) are removed from the stable tokens and decoded back to
/// `unstable_bytes`; `completions` holds candidate token sequences that cover those bytes.
/// `completions` is empty when the text ended with a special token or produced no tokens.
fn _encode_unstable_native(
    &self,
    text: &str,
    allowed_special: &HashSet<&str>,
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
    let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
    if last_piece_token_len == 0 {
        // If last_piece_token_len is zero, the last token was a special token and we have
        // no unstable bytes
        return (tokens, HashSet::new());
    }
    let (mut tokens, last_piece_token_len) =
        self._increase_last_piece_token_len(tokens, last_piece_token_len);

    // Decode the trailing tokens back to bytes, then drop them from the stable output.
    let unstable_bytes = self
        ._decode_native(&tokens[tokens.len() - last_piece_token_len..])
        .unwrap();
    tokens.truncate(tokens.len() - last_piece_token_len);

    // TODO: we should try harder to find additional stable tokens
    // This would reduce the amount of retokenising when determining completions
    // Refer to the logic in an older version of this file

    let mut completions = HashSet::new();
    if unstable_bytes.is_empty() {
        return (tokens, completions);
    }

    // This is the easy bit. Just find all single tokens that start with unstable_bytes
    // (including tokens that exactly match unstable_bytes)
    // Separating this from the loop below helps with performance in a common case.
    // `sorted_token_bytes` is sorted, so all tokens sharing a prefix are contiguous and
    // the first one is found by binary search.
    let mut point = self
        .sorted_token_bytes
        .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
    while point < self.sorted_token_bytes.len()
        && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
    {
        completions.insert(vec![
            self.encoder[self.sorted_token_bytes[point].as_slice()],
        ]);
        point += 1;
    }

    // Now apply even more brute force. At every (other) possible position for the straddling
    // token, concatenate additional bytes from that token (if any) to unstable_bytes,
    // and retokenise the whole thing and see what we get.
    for i in 1..unstable_bytes.len() {
        let prefix = &unstable_bytes[..i];
        let suffix = &unstable_bytes[i..];
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < suffix);
        // TODO: Perf optimisation if suffix starts with " "?
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(suffix)
        {
            let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
            let encoded = match std::str::from_utf8(&possibility) {
                // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                // But we might have introduced a regex split which would prevent merges.
                // (particularly possible in the presence of unstable regex splits)
                // So convert to UTF-8 and do regex splitting.
                // E.g. with cl100k_base "  !" gets split to " " + " !",
                // but byte_pair_encode("  !") != byte_pair_encode(" ")
                Ok(s) => self._encode_ordinary_native(s),

                // Technically, whether or not this arm is correct depends on whether there
                // would be a regex split before the UTF-8 truncation point.
                // Probably niche enough that no one will ever notice (after all, people didn't
                // notice all the big holes in the previous unstable token implementation)
                Err(_) => byte_pair_encode(&possibility, &self.encoder),
                // Something like the following is intriguing but incorrect:
                // Err(e) => self._encode_ordinary_native(unsafe {
                //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                // }),
            };
            // Keep only as many leading tokens as are needed to cover unstable_bytes.
            let mut seq = Vec::new();
            let mut seq_len = 0;
            for token in encoded {
                seq.push(token);
                seq_len += self.decoder[&token].len();
                if seq_len >= unstable_bytes.len() {
                    break;
                }
            }
            completions.insert(seq);
            point += 1;
        }
    }

    // This is also not straightforward. While we generally assume that regex splits are stable,
    // unfortunately, they are not. That is, if adding bytes were to make a split appear in
    // unstable_bytes, this could make tokens possible which our logic would otherwise think
    // would be merged.
    // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
    // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
    // Here is a quick and dirty fix:
    // This isn't right if we ever remove \s+(?!\S)
    if unstable_bytes.len() > 1 {
        // `last_decoded` is (decoded char if any, number of bytes it occupied).
        let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
        if unstable_bytes.len() - last_decoded.1 > 0
            && last_decoded.0.map_or(false, |c| c.is_whitespace())
        {
            // Encode everything before the trailing whitespace char and the char itself
            // separately, as if a split had appeared between them.
            let mut reencoded = byte_pair_encode(
                &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                &self.encoder,
            );
            reencoded.extend(byte_pair_encode(
                &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                &self.encoder,
            ));
            completions.insert(reencoded);
        }
    }

    (tokens, completions)
}
| 428 | } |
| 429 | |
| 430 | #[pymethods] |
| 431 | impl CoreBPE { |
| 432 | #[new] |
| 433 | fn new( |
| 434 | encoder: HashMap<Vec<u8>, Rank>, |
| 435 | special_tokens_encoder: HashMap<String, Rank>, |
| 436 | pattern: &str, |
| 437 | ) -> PyResult<Self> { |
| 438 | let regex = Regex::new(pattern) |
| 439 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?; |
| 440 | |
| 441 | let special_regex = { |
| 442 | let _parts = special_tokens_encoder |
| 443 | .keys() |
| 444 | .map(|s| fancy_regex::escape(s)) |
| 445 | .collect::<Vec<_>>(); |
| 446 | Regex::new(&_parts.join("|")) |
| 447 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))? |
| 448 | }; |
| 449 | |
| 450 | let decoder: HashMap<Rank, Vec<u8>> = |
| 451 | encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); |
| 452 | |
| 453 | assert!( |
| 454 | encoder.len() == decoder.len(), |
| 455 | "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?" |
| 456 | ); |
| 457 | |
| 458 | let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder |
| 459 | .iter() |
| 460 | .map(|(k, v)| (*v, k.as_bytes().to_vec())) |
| 461 | .collect(); |
| 462 | |
| 463 | // Clone because I don't know how to tell Rust I'm not going to change the map |
| 464 | let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect(); |
| 465 | sorted_token_bytes.sort(); |
| 466 | |
| 467 | Ok(CoreBPE { |
| 468 | encoder, |
| 469 | special_tokens_encoder, |
| 470 | decoder, |
| 471 | special_tokens_decoder, |
| 472 | regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), |
| 473 | special_regex_tls: (0..MAX_NUM_THREADS) |
| 474 | .map(|_| special_regex.clone()) |
| 475 | .collect(), |
| 476 | sorted_token_bytes, |
| 477 | }) |
| 478 | } |
| 479 | |
| 480 | // ==================== |
| 481 | // Encoding |
| 482 | // ==================== |
| 483 | |
| 484 | fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> { |
| 485 | py.allow_threads(|| self._encode_ordinary_native(text)) |
| 486 | } |
| 487 | |
| 488 | fn encode(&self, py: Python, text: &str, allowed_special: HashSet<PyBackedStr>) -> Vec<Rank> { |
| 489 | py.allow_threads(|| { |
| 490 | let allowed_special: HashSet<&str> = |
| 491 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 492 | self._encode_native(text, &allowed_special).0 |
| 493 | }) |
| 494 | } |
| 495 | |
| 496 | fn encode_to_tiktoken_buffer( |
| 497 | &self, |
| 498 | py: Python, |
| 499 | text: &str, |
| 500 | allowed_special: HashSet<PyBackedStr>, |
| 501 | ) -> Py<PyAny> { |
| 502 | let tokens = py.allow_threads(|| { |
| 503 | let allowed_special: HashSet<&str> = |
| 504 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 505 | self._encode_native(text, &allowed_special).0 |
| 506 | }); |
| 507 | let buffer = TiktokenBuffer { tokens }; |
| 508 | buffer.into_py(py) |
| 509 | } |
| 510 | |
| 511 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> { |
| 512 | py.allow_threads(|| { |
| 513 | match std::str::from_utf8(bytes) { |
| 514 | Ok(text) => self._encode_ordinary_native(text), |
| 515 | Err(e) => { |
| 516 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 517 | let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); |
| 518 | let (mut tokens, last_piece_token_len) = |
| 519 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 520 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 521 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 522 | // Somewhat niche, but this may not be correct if we'd have had a regex |
| 523 | // split between the valid UTF-8 and the invalid bytes, which is why this |
| 524 | // method is private |
| 525 | let mut unstable_bytes = self |
| 526 | ._decode_native(&tokens[tokens.len() - last_piece_token_len..]) |
| 527 | .unwrap(); |
| 528 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 529 | |
| 530 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 531 | match self.encoder.get(&unstable_bytes) { |
| 532 | Some(token) => tokens.push(*token), |
| 533 | None => { |
| 534 | tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)) |
| 535 | } |
| 536 | } |
| 537 | } |
| 538 | tokens |
| 539 | } |
| 540 | } |
| 541 | }) |
| 542 | } |
| 543 | |
| 544 | fn encode_with_unstable( |
| 545 | &self, |
| 546 | py: Python, |
| 547 | text: &str, |
| 548 | allowed_special: HashSet<PyBackedStr>, |
| 549 | ) -> Py<PyTuple> { |
| 550 | let (tokens, completions) = py.allow_threads(|| { |
| 551 | let allowed_special: HashSet<&str> = |
| 552 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 553 | self._encode_unstable_native(text, &allowed_special) |
| 554 | }); |
| 555 | let py_completions = PyList::new_bound( |
| 556 | py, |
| 557 | completions |
| 558 | .iter() |
| 559 | .map(|seq| PyList::new_bound(py, &seq[..])), |
| 560 | ); |
| 561 | (tokens, py_completions).into_py(py) |
| 562 | } |
| 563 | |
| 564 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> { |
| 565 | if let Some(token) = self.encoder.get(piece).copied() { |
| 566 | return Ok(token); |
| 567 | } |
| 568 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 569 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 570 | return Ok(token); |
| 571 | } |
| 572 | } |
| 573 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 574 | } |
| 575 | |
| 576 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> { |
| 577 | if let Some(token) = self.encoder.get(piece) { |
| 578 | return vec![*token]; |
| 579 | } |
| 580 | byte_pair_encode(piece, &self.encoder) |
| 581 | } |
| 582 | |
| 583 | // ==================== |
| 584 | // Decoding |
| 585 | // ==================== |
| 586 | |
| 587 | fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> { |
| 588 | match py.allow_threads(|| self._decode_native(&tokens)) { |
| 589 | Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()), |
| 590 | Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), |
| 591 | } |
| 592 | } |
| 593 | |
| 594 | fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> { |
| 595 | if let Some(bytes) = self.decoder.get(&token) { |
| 596 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 597 | } |
| 598 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 599 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 600 | } |
| 601 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 602 | } |
| 603 | |
| 604 | // ==================== |
| 605 | // Miscellaneous |
| 606 | // ==================== |
| 607 | |
| 608 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 609 | self.sorted_token_bytes |
| 610 | .iter() |
| 611 | .map(|x| PyBytes::new_bound(py, x).into()) |
| 612 | .collect() |
| 613 | } |
| 614 | } |
| 615 | |
/// Python-visible owner of an encoded token vector; its contents are exposed to Python
/// through the buffer protocol methods implemented on this type.
#[pyclass]
struct TiktokenBuffer {
    // The encoded tokens; the buffer handed to Python points directly at this storage.
    tokens: Vec<Rank>,
}
| 620 | |
| 621 | #[pymethods] |
| 622 | impl TiktokenBuffer { |
| 623 | // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25 |
| 624 | unsafe fn __getbuffer__( |
| 625 | slf: Bound<'_, Self>, |
| 626 | view: *mut pyo3::ffi::Py_buffer, |
| 627 | flags: std::os::raw::c_int, |
| 628 | ) -> PyResult<()> { |
| 629 | if view.is_null() { |
| 630 | return Err(pyo3::exceptions::PyBufferError::new_err("View is null")); |
| 631 | } |
| 632 | if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE { |
| 633 | return Err(pyo3::exceptions::PyBufferError::new_err( |
| 634 | "Object is not writable", |
| 635 | )); |
| 636 | } |
| 637 | |
| 638 | (*view).obj = slf.clone().into_any().into_ptr(); |
| 639 | |
| 640 | let data = &slf.borrow().tokens; |
| 641 | (*view).buf = data.as_ptr() as *mut std::os::raw::c_void; |
| 642 | (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize; |
| 643 | (*view).readonly = 1; |
| 644 | (*view).itemsize = std::mem::size_of::<Rank>() as isize; |
| 645 | (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT { |
| 646 | let msg = std::ffi::CString::new("I").unwrap(); |
| 647 | msg.into_raw() |
| 648 | } else { |
| 649 | std::ptr::null_mut() |
| 650 | }; |
| 651 | (*view).ndim = 1; |
| 652 | (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { |
| 653 | &mut (*view).len |
| 654 | } else { |
| 655 | std::ptr::null_mut() |
| 656 | }; |
| 657 | (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES { |
| 658 | &mut (*view).itemsize |
| 659 | } else { |
| 660 | std::ptr::null_mut() |
| 661 | }; |
| 662 | (*view).suboffsets = std::ptr::null_mut(); |
| 663 | (*view).internal = std::ptr::null_mut(); |
| 664 | |
| 665 | Ok(()) |
| 666 | } |
| 667 | |
| 668 | unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) { |
| 669 | std::mem::drop(std::ffi::CString::from_raw((*view).format)); |
| 670 | } |
| 671 | } |
| 672 | |
| 673 | #[pymodule] |
| 674 | fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> { |
| 675 | m.add_class::<CoreBPE>()?; |
| 676 | Ok(()) |
| 677 | } |
| 678 | |
#[cfg(test)]
mod tests {

    use rustc_hash::FxHashMap as HashMap;

    use crate::{byte_pair_split, Rank};

    /// Minimal vocabulary: "ab" has rank 0 and "cd" has rank 1.
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        ranks
    }

    /// Two distinct known pairs are each merged into one part.
    #[test]
    fn test_simple_characters() {
        let split = byte_pair_split(b"abcd", &setup_ranks());
        assert_eq!(split, vec![b"ab", b"cd"]);
    }

    /// The same known pair occurring twice is merged at both positions.
    #[test]
    fn test_repeated_characters() {
        let split = byte_pair_split(b"abab", &setup_ranks());
        assert_eq!(split, vec![b"ab", b"ab"]);
    }
}
| 704 | |