openai/tiktoken
Publicmirrored fromhttps://github.com/openai/tiktokenAvailable
src/lib.rs
608lines · modecode
| 1 | // This check is new and seems buggy (possibly with PyO3 interaction) |
| 2 | #![allow(clippy::borrow_deref_ref)] |
| 3 | |
| 4 | use std::collections::HashSet; |
| 5 | use std::thread; |
| 6 | |
| 7 | use fancy_regex::Regex; |
| 8 | use pyo3::exceptions; |
| 9 | use pyo3::prelude::*; |
| 10 | use pyo3::types::{PyBytes, PyList, PyTuple}; |
| 11 | use pyo3::PyResult; |
| 12 | use rustc_hash::FxHashMap as HashMap; |
| 13 | |
/// Greedily merges adjacent byte ranges of `piece`, always merging the pair whose
/// concatenated bytes have the lowest rank in `ranks`, until no mergeable pair remains.
///
/// `f` is called once per surviving range (as a `start..end` byte range into `piece`),
/// in order, and its results are returned.
///
/// An empty `piece` yields an empty vector; a single-byte `piece` yields one range.
fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // This is a vector of (start, rank).
    // The rank is of the byte pair starting at position start.
    // The rank of the last item in the vector is not a valid value.
    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();

    // Rank of the bytes spanning from `parts[start_idx]` up to the part that lies
    // `skip + 2` entries later, or `None` if that end is out of range or the bytes
    // are not a known token.
    let get_rank = {
        #[inline(always)]
        |parts: &[(usize, usize)], start_idx: usize, skip: usize| {
            parts
                .get(start_idx + skip + 2)
                .and_then(|end| ranks.get(&piece[parts[start_idx].0..end.0]).copied())
        }
    };

    // We look up the ranks once in the beginning and iteratively update
    // them during each merge, which reduces the number of rank lookups.
    // `saturating_sub` keeps this well-defined for an empty `piece`, where
    // `parts` has a single element and there are no pairs to rank.
    for i in 0..parts.len().saturating_sub(2) {
        if let Some(rank) = get_rank(&parts, i, 0) {
            // usize::MAX is a sentinel value and cannot be a valid rank
            debug_assert!(rank != usize::MAX);
            parts[i].1 = rank;
        }
    }

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // It is important to consider that n is often small (<100), and as such
    // the cache-locality benefits outweigh the algorithmic complexity downsides
    // of the `parts` vector data structure above.

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    while parts.len() > 1 {
        // usize::MAX is a sentinel rank value allowing us to
        // take the min more quickly
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }

        if min_rank.0 == usize::MAX {
            // Nothing left that can be merged.
            break;
        }
        let i = min_rank.1;

        // NOTE: We are about to remove parts[i + 1]. We do not do it
        // yet because there are cache-locality benefits to updating
        // parts[i] and parts[i-1] before removing, which could thrash
        // the cache. Thus, we update the rank calculation by skipping over
        // parts[i + 1], by invoking `get_rank` with `skip = 1`.
        parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
        }

        parts.remove(i + 1);
    }

    // Each pair of neighbouring boundaries delimits one output range.
    parts.windows(2).map(|w| f(w[0].0..w[1].0)).collect()
}
| 99 | |
| 100 | pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> { |
| 101 | if piece.len() == 1 { |
| 102 | return vec![ranks[piece]]; |
| 103 | } |
| 104 | _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) |
| 105 | } |
| 106 | |
| 107 | pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> { |
| 108 | if piece.len() == 1 { |
| 109 | return vec![piece]; |
| 110 | } |
| 111 | _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) |
| 112 | } |
| 113 | |
| 114 | // Various performance notes: |
| 115 | // |
| 116 | // Regex |
| 117 | // ===== |
| 118 | // Most of the time is spent in regex. The easiest way to speed this up is by using less fancy |
| 119 | // regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than |
| 120 | // the usual regex we use. |
| 121 | // |
| 122 | // However, given that we're using a regex parse-able by `regex`, there isn't much difference |
| 123 | // between using the `regex` crate and using the `fancy_regex` crate. |
| 124 | // |
| 125 | // There is an important interaction between threading, `regex` and `fancy_regex`. |
| 126 | // When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on |
| 127 | // some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain |
| 128 | // old `regex`, we don't hit this, because `find_iter` has a different code path. |
| 129 | // Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md |
| 130 | // Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for |
| 131 | // each thread. |
| 132 | // |
| 133 | // Threading |
| 134 | // ========= |
| 135 | // I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. |
| 136 | // So goodbye `rayon`! Let thread count etc be in control of our Python users. |
| 137 | // |
| 138 | // Caching |
| 139 | // ======= |
| 140 | // The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. |
| 141 | // Originally, we had one too! Without it, we were only vaguely faster than Python. |
| 142 | // I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance |
| 143 | // noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect |
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
| 147 | // about interior mutability, memory use, or cloning. |
| 148 | // |
| 149 | // Hashing |
| 150 | // ======= |
| 151 | // We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? |
| 152 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made |
| 153 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. |
| 154 | |
| 155 | use std::num::NonZeroU64; |
/// Stand-in used to read the numeric value out of a `std::thread::ThreadId`.
///
/// NOTE(review): relies on `ThreadId` being laid out as a single non-zero `u64`;
/// only the size is checked at compile time below.
pub struct FakeThreadId(NonZeroU64);

/// Returns a per-thread integer derived from the current thread's id.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939

    // Compile-time size checks: each array type only matches if the size is exactly 8.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];

    let id = thread::current().id();
    // SAFETY: both types are 8 bytes (checked above); this assumes `ThreadId` wraps a
    // non-zero u64, which is not a documented guarantee of the standard library.
    let FakeThreadId(raw) =
        unsafe { std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(id) };
    raw.get() as usize
}
| 170 | |
| 171 | const MAX_NUM_THREADS: usize = 128; |
| 172 | #[pyclass] |
| 173 | struct CoreBPE { |
| 174 | encoder: HashMap<Vec<u8>, usize>, |
| 175 | special_tokens_encoder: HashMap<String, usize>, |
| 176 | decoder: HashMap<usize, Vec<u8>>, |
| 177 | special_tokens_decoder: HashMap<usize, Vec<u8>>, |
| 178 | regex_tls: Vec<Regex>, |
| 179 | special_regex_tls: Vec<Regex>, |
| 180 | sorted_token_bytes: Vec<Vec<u8>>, |
| 181 | } |
| 182 | |
| 183 | impl CoreBPE { |
| 184 | fn _get_tl_regex(&self) -> &Regex { |
| 185 | // See performance notes above for what this is about |
| 186 | // It's also a little janky, please make a better version of it! |
| 187 | // However, it's nice that this doesn't leak memory to short-lived threads |
| 188 | &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 189 | } |
| 190 | |
| 191 | fn _get_tl_special_regex(&self) -> &Regex { |
| 192 | &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 193 | } |
| 194 | |
| 195 | fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> { |
| 196 | let mut ret = Vec::with_capacity(tokens.len() * 2); |
| 197 | for token in tokens { |
| 198 | let token_bytes = self |
| 199 | .decoder |
| 200 | .get(token) |
| 201 | .unwrap_or_else(|| &self.special_tokens_decoder[token]); |
| 202 | ret.extend(token_bytes); |
| 203 | } |
| 204 | ret |
| 205 | } |
| 206 | |
| 207 | fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> { |
| 208 | // This is the core of the encoding logic; the other functions in here |
| 209 | // just make things complicated :-) |
| 210 | let regex = self._get_tl_regex(); |
| 211 | let mut ret = vec![]; |
| 212 | for mat in regex.find_iter(text) { |
| 213 | let piece = mat.unwrap().as_str().as_bytes(); |
| 214 | if let Some(token) = self.encoder.get(piece) { |
| 215 | ret.push(*token); |
| 216 | continue; |
| 217 | } |
| 218 | ret.extend(&byte_pair_encode(piece, &self.encoder)); |
| 219 | } |
| 220 | ret |
| 221 | } |
| 222 | |
| 223 | fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) { |
| 224 | let special_regex = self._get_tl_special_regex(); |
| 225 | let regex = self._get_tl_regex(); |
| 226 | let mut ret = vec![]; |
| 227 | |
| 228 | let mut start = 0; |
| 229 | let mut last_piece_token_len = 0; |
| 230 | loop { |
| 231 | let mut next_special; |
| 232 | let mut start_find = start; |
| 233 | loop { |
| 234 | // Find the next allowed special token, if any |
| 235 | next_special = special_regex.find_from_pos(text, start_find).unwrap(); |
| 236 | match next_special { |
| 237 | Some(m) => { |
| 238 | if allowed_special.contains(&text[m.start()..m.end()]) { |
| 239 | break; |
| 240 | } |
| 241 | start_find = m.start() + 1; |
| 242 | } |
| 243 | None => break, |
| 244 | } |
| 245 | } |
| 246 | let end = next_special.map_or(text.len(), |m| m.start()); |
| 247 | |
| 248 | // Okay, here we go, compare this logic to _encode_ordinary_native |
| 249 | for mat in regex.find_iter(&text[start..end]) { |
| 250 | let piece = mat.unwrap().as_str().as_bytes(); |
| 251 | if let Some(token) = self.encoder.get(piece) { |
| 252 | last_piece_token_len = 1; |
| 253 | ret.push(*token); |
| 254 | continue; |
| 255 | } |
| 256 | let tokens = byte_pair_encode(piece, &self.encoder); |
| 257 | last_piece_token_len = tokens.len(); |
| 258 | ret.extend(&tokens); |
| 259 | } |
| 260 | |
| 261 | match next_special { |
| 262 | // And here we push the special token |
| 263 | Some(m) => { |
| 264 | let piece = m.as_str(); |
| 265 | let token = self.special_tokens_encoder[piece]; |
| 266 | ret.push(token); |
| 267 | start = m.end(); |
| 268 | last_piece_token_len = 0; |
| 269 | } |
| 270 | None => break, |
| 271 | } |
| 272 | } |
| 273 | |
| 274 | // last_piece_token_len is how many tokens came from the last regex split. This is used |
| 275 | // for determining unstable tokens, since you can't merge across (stable) regex splits |
| 276 | (ret, last_piece_token_len) |
| 277 | } |
| 278 | |
| 279 | fn _increase_last_piece_token_len( |
| 280 | &self, |
| 281 | tokens: Vec<usize>, |
| 282 | mut last_piece_token_len: usize, |
| 283 | ) -> (Vec<usize>, usize) { |
| 284 | // Unfortunately, the locations where our regex splits can be unstable. |
| 285 | // For the purposes of determining unstable tokens, unstable regex splitting |
| 286 | // is only a problem if a split that was present disappears, since this can |
| 287 | // lead to merging of tokens otherwise thought to be stable. |
| 288 | // cl100k_base makes our life hard by including the \s*[\r\n]+ |
| 289 | // pattern. This can e.g. cause "\n" + " " to become "\n \n". |
| 290 | // Here is a quick and dirty fix: |
| 291 | { |
| 292 | let token_is_all_space = |token| { |
| 293 | self.decoder |
| 294 | .get(token) |
| 295 | .map(|token_bytes| { |
| 296 | token_bytes |
| 297 | .iter() |
| 298 | .rev() |
| 299 | .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) |
| 300 | }) |
| 301 | .unwrap_or(false) |
| 302 | }; |
| 303 | if last_piece_token_len > 0 |
| 304 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) |
| 305 | { |
| 306 | while (last_piece_token_len < tokens.len()) |
| 307 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) |
| 308 | { |
| 309 | last_piece_token_len += 1; |
| 310 | } |
| 311 | } |
| 312 | } |
| 313 | debug_assert!(last_piece_token_len <= tokens.len()); |
| 314 | |
| 315 | (tokens, last_piece_token_len) |
| 316 | } |
| 317 | |
| 318 | fn _encode_unstable_native( |
| 319 | &self, |
| 320 | text: &str, |
| 321 | allowed_special: &HashSet<&str>, |
| 322 | ) -> (Vec<usize>, HashSet<Vec<usize>>) { |
| 323 | let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); |
| 324 | if last_piece_token_len == 0 { |
| 325 | // If last_piece_token_len is zero, the last token was a special token and we have |
| 326 | // no unstable bytes |
| 327 | return (tokens, HashSet::new()); |
| 328 | } |
| 329 | let (mut tokens, last_piece_token_len) = |
| 330 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 331 | |
| 332 | let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); |
| 333 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 334 | |
| 335 | // TODO: we should try harder to find additional stable tokens |
| 336 | // This would reduce the amount of retokenising when determining completions |
| 337 | // Refer to the logic in an older version of this file |
| 338 | |
| 339 | let mut completions = HashSet::new(); |
| 340 | if unstable_bytes.is_empty() { |
| 341 | return (tokens, completions); |
| 342 | } |
| 343 | |
| 344 | // This is the easy bit. Just find all single tokens that start with unstable_bytes |
| 345 | // (including tokens that exactly match unstable_bytes) |
| 346 | // Separating this from the loop below helps with performance in a common case. |
| 347 | let mut point = self |
| 348 | .sorted_token_bytes |
| 349 | .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); |
| 350 | while point < self.sorted_token_bytes.len() |
| 351 | && self.sorted_token_bytes[point].starts_with(&unstable_bytes) |
| 352 | { |
| 353 | completions.insert(vec![ |
| 354 | self.encoder[self.sorted_token_bytes[point].as_slice()], |
| 355 | ]); |
| 356 | point += 1; |
| 357 | } |
| 358 | |
| 359 | // Now apply even more brute force. At every (other) possible position for the straddling |
| 360 | // token, concatenate additional bytes from that token (if any) to unstable_bytes, |
| 361 | // and retokenise the whole thing and see what we get. |
| 362 | for i in 1..unstable_bytes.len() { |
| 363 | let prefix = &unstable_bytes[..i]; |
| 364 | let suffix = &unstable_bytes[i..]; |
| 365 | let mut point = self |
| 366 | .sorted_token_bytes |
| 367 | .partition_point(|x| x.as_slice() < suffix); |
| 368 | // TODO: Perf optimisation if suffix starts with " "? |
| 369 | while point < self.sorted_token_bytes.len() |
| 370 | && self.sorted_token_bytes[point].starts_with(suffix) |
| 371 | { |
| 372 | let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); |
| 373 | let encoded = match std::str::from_utf8(&possibility) { |
| 374 | // Morally, this is byte_pair_encode(&possibility, &self.encoder) |
| 375 | // But we might have introduced a regex split which would prevent merges. |
| 376 | // (particularly possible in the presence of unstable regex splits) |
| 377 | // So convert to UTF-8 and do regex splitting. |
| 378 | // E.g. with cl100k_base " !" gets split to " " + " !", |
| 379 | // but byte_pair_encode(" !") != byte_pair_encode(" ") |
| 380 | Ok(s) => self._encode_ordinary_native(s), |
| 381 | |
| 382 | // Technically, whether or not this arm is correct depends on whether there |
| 383 | // would be a regex split before the UTF-8 truncation point. |
| 384 | // Probably niche enough that no one will ever notice (after all, people didn't |
| 385 | // notice all the big holes in the previous unstable token implementation) |
| 386 | Err(_) => byte_pair_encode(&possibility, &self.encoder), |
| 387 | // Something like the following is intriguing but incorrect: |
| 388 | // Err(e) => self._encode_ordinary_native(unsafe { |
| 389 | // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) |
| 390 | // }), |
| 391 | }; |
| 392 | let mut seq = Vec::new(); |
| 393 | let mut seq_len = 0; |
| 394 | for token in encoded { |
| 395 | seq.push(token); |
| 396 | seq_len += self.decoder[&token].len(); |
| 397 | if seq_len >= unstable_bytes.len() { |
| 398 | break; |
| 399 | } |
| 400 | } |
| 401 | completions.insert(seq); |
| 402 | point += 1; |
| 403 | } |
| 404 | } |
| 405 | |
| 406 | // This is also not straightforward. While we generally assume that regex splits are stable, |
| 407 | // unfortunately, they are not. That is, if adding bytes were to make a split appear in |
| 408 | // unstable_bytes, this could make tokens possible which our logic would otherwise think |
| 409 | // would be merged. |
| 410 | // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could |
| 411 | // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. |
| 412 | // Here is a quick and dirty fix: |
| 413 | // This isn't right if we ever remove \s+(?!\S) |
| 414 | if unstable_bytes.len() > 1 { |
| 415 | let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); |
| 416 | if unstable_bytes.len() - last_decoded.1 > 0 |
| 417 | && last_decoded.0.map_or(false, |c| c.is_whitespace()) |
| 418 | { |
| 419 | let mut reencoded = byte_pair_encode( |
| 420 | &unstable_bytes[..unstable_bytes.len() - last_decoded.1], |
| 421 | &self.encoder, |
| 422 | ); |
| 423 | reencoded.extend(byte_pair_encode( |
| 424 | &unstable_bytes[unstable_bytes.len() - last_decoded.1..], |
| 425 | &self.encoder, |
| 426 | )); |
| 427 | completions.insert(reencoded); |
| 428 | } |
| 429 | } |
| 430 | |
| 431 | (tokens, completions) |
| 432 | } |
| 433 | } |
| 434 | |
| 435 | #[pymethods] |
| 436 | impl CoreBPE { |
| 437 | #[new] |
| 438 | fn new( |
| 439 | encoder: HashMap<Vec<u8>, usize>, |
| 440 | special_tokens_encoder: HashMap<String, usize>, |
| 441 | pattern: &str, |
| 442 | ) -> PyResult<Self> { |
| 443 | let regex = Regex::new(pattern) |
| 444 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?; |
| 445 | |
| 446 | let special_regex = { |
| 447 | let _parts = special_tokens_encoder |
| 448 | .keys() |
| 449 | .map(|s| fancy_regex::escape(s)) |
| 450 | .collect::<Vec<_>>(); |
| 451 | Regex::new(&_parts.join("|")) |
| 452 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))? |
| 453 | }; |
| 454 | |
| 455 | let decoder: HashMap<usize, Vec<u8>> = |
| 456 | encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); |
| 457 | |
| 458 | assert!( |
| 459 | encoder.len() == decoder.len(), |
| 460 | "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?" |
| 461 | ); |
| 462 | |
| 463 | let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder |
| 464 | .iter() |
| 465 | .map(|(k, v)| (*v, k.as_bytes().to_vec())) |
| 466 | .collect(); |
| 467 | |
| 468 | // Clone because I don't know how to tell Rust I'm not going to change the map |
| 469 | let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect(); |
| 470 | sorted_token_bytes.sort(); |
| 471 | |
| 472 | Ok(CoreBPE { |
| 473 | encoder, |
| 474 | special_tokens_encoder, |
| 475 | decoder, |
| 476 | special_tokens_decoder, |
| 477 | regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), |
| 478 | special_regex_tls: (0..MAX_NUM_THREADS) |
| 479 | .map(|_| special_regex.clone()) |
| 480 | .collect(), |
| 481 | sorted_token_bytes, |
| 482 | }) |
| 483 | } |
| 484 | |
| 485 | // ==================== |
| 486 | // Encoding |
| 487 | // ==================== |
| 488 | |
| 489 | fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> { |
| 490 | py.allow_threads(|| self._encode_ordinary_native(text)) |
| 491 | } |
| 492 | |
| 493 | fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> { |
| 494 | py.allow_threads(|| self._encode_native(text, &allowed_special).0) |
| 495 | } |
| 496 | |
| 497 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> { |
| 498 | py.allow_threads(|| { |
| 499 | match std::str::from_utf8(bytes) { |
| 500 | Ok(text) => self._encode_ordinary_native(text), |
| 501 | Err(e) => { |
| 502 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 503 | let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); |
| 504 | let (mut tokens, last_piece_token_len) = |
| 505 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 506 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 507 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 508 | // Somewhat niche, but this may not be correct if we'd have had a regex |
| 509 | // split between the valid UTF-8 and the invalid bytes, which is why this |
| 510 | // method is private |
| 511 | let mut unstable_bytes = |
| 512 | self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); |
| 513 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 514 | |
| 515 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 516 | tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); |
| 517 | } |
| 518 | tokens |
| 519 | } |
| 520 | } |
| 521 | }) |
| 522 | } |
| 523 | |
| 524 | fn encode_with_unstable( |
| 525 | &self, |
| 526 | py: Python, |
| 527 | text: &str, |
| 528 | allowed_special: HashSet<&str>, |
| 529 | ) -> Py<PyTuple> { |
| 530 | let (tokens, completions) = |
| 531 | py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); |
| 532 | let py_completions = |
| 533 | PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); |
| 534 | (tokens, py_completions).into_py(py) |
| 535 | } |
| 536 | |
| 537 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> { |
| 538 | if let Some(token) = self.encoder.get(piece).copied() { |
| 539 | return Ok(token); |
| 540 | } |
| 541 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 542 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 543 | return Ok(token); |
| 544 | } |
| 545 | } |
| 546 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 547 | } |
| 548 | |
| 549 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> { |
| 550 | if let Some(token) = self.encoder.get(piece) { |
| 551 | return vec![*token]; |
| 552 | } |
| 553 | byte_pair_encode(piece, &self.encoder) |
| 554 | } |
| 555 | |
| 556 | // ==================== |
| 557 | // Decoding |
| 558 | // ==================== |
| 559 | |
| 560 | fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> { |
| 561 | let bytes = py.allow_threads(|| self._decode_native(&tokens)); |
| 562 | PyBytes::new(py, &bytes).into() |
| 563 | } |
| 564 | |
| 565 | fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> { |
| 566 | if let Some(bytes) = self.decoder.get(&token) { |
| 567 | return Ok(PyBytes::new(py, bytes).into()); |
| 568 | } |
| 569 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 570 | return Ok(PyBytes::new(py, bytes).into()); |
| 571 | } |
| 572 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 573 | } |
| 574 | |
| 575 | // ==================== |
| 576 | // Miscellaneous |
| 577 | // ==================== |
| 578 | |
| 579 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 580 | self.sorted_token_bytes |
| 581 | .iter() |
| 582 | .map(|x| PyBytes::new(py, x).into()) |
| 583 | .collect() |
| 584 | } |
| 585 | } |
| 586 | |
| 587 | #[pymodule] |
| 588 | fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { |
| 589 | m.add_class::<CoreBPE>()?; |
| 590 | Ok(()) |
| 591 | } |
| 592 | |
#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    /// "abcd" with only "ab" and "cd" known must split into exactly those two pieces.
    #[test]
    fn very_simple_test() {
        let mut vocab = HashMap::default();
        for (rank, bytes) in [(1, b"ab"), (2, b"cd")].iter() {
            vocab.insert(bytes.to_vec(), *rank);
        }

        let pieces = byte_pair_split(b"abcd", &vocab);
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }
}
| 609 | |