openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
src/lib.rs
559 lines · mode code
| 1 | // This check is new and seems buggy (possibly with PyO3 interaction) |
| 2 | #![allow(clippy::borrow_deref_ref)] |
| 3 | |
| 4 | use std::collections::HashSet; |
| 5 | use std::thread; |
| 6 | |
| 7 | use fancy_regex::Regex; |
| 8 | use pyo3::exceptions; |
| 9 | use pyo3::prelude::*; |
| 10 | use pyo3::types::{PyBytes, PyList, PyTuple}; |
| 11 | use pyo3::PyResult; |
| 12 | use rustc_hash::FxHashMap as HashMap; |
| 13 | |
/// Greedily merges adjacent byte ranges of `piece`, lowest rank first.
///
/// Starts with one single-byte range per input byte. On each pass the concatenation of
/// every adjacent pair of ranges is looked up in `ranks`; the pair with the smallest rank
/// is merged (the leftmost pair wins ties). This repeats until no adjacent pair is a key
/// of `ranks` or only one range is left.
///
/// Returns the surviving ranges in order; together they cover `piece` exactly. An empty
/// `piece` yields an empty vector.
fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();

    // If you have n parts and m merges, this does O(mn) work
    // We could do something with a heap and do O(m log n) work

    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
    // currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.

    // `> 1` rather than `!= 1`: with zero parts (empty `piece`) the `parts.len() - 1`
    // below would underflow.
    while parts.len() > 1 {
        // (rank, index of the left part) of the best mergeable pair seen on this pass.
        let mut min_rank: Option<(usize, usize)> = None;
        for i in 0..parts.len() - 1 {
            if let Some(&rank) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
                // Strict `<` keeps the leftmost pair when ranks tie.
                if min_rank.map_or(true, |(best, _)| rank < best) {
                    min_rank = Some((rank, i));
                }
            }
        }
        match min_rank {
            Some((_, i)) => {
                parts[i] = parts[i].start..parts[i + 1].end;
                parts.remove(i + 1);
            }
            // Nothing left that can be merged.
            None => break,
        }
    }
    parts
}
| 47 | |
| 48 | pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> { |
| 49 | if piece.len() == 1 { |
| 50 | return vec![ranks[piece]]; |
| 51 | } |
| 52 | _byte_pair_merge(piece, ranks) |
| 53 | .iter() |
| 54 | .map(|p| ranks[&piece[p.start..p.end]]) |
| 55 | .collect() |
| 56 | } |
| 57 | |
| 58 | pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> { |
| 59 | if piece.len() == 1 { |
| 60 | return vec![piece]; |
| 61 | } |
| 62 | _byte_pair_merge(piece, ranks) |
| 63 | .iter() |
| 64 | .map(|p| &piece[p.start..p.end]) |
| 65 | .collect() |
| 66 | } |
| 67 | |
| 68 | // Various performance notes: |
| 69 | // |
| 70 | // Regex |
| 71 | // ===== |
| 72 | // Most of the time is spent in regex. The easiest way to speed this up is by using less fancy |
| 73 | // regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than |
| 74 | // the usual regex we use. |
| 75 | // |
| 76 | // However, given that we're using a regex parse-able by `regex`, there isn't much difference |
| 77 | // between using the `regex` crate and using the `fancy_regex` crate. |
| 78 | // |
| 79 | // There is an important interaction between threading, `regex` and `fancy_regex`. |
| 80 | // When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on |
| 81 | // some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain |
| 82 | // old `regex`, we don't hit this, because `find_iter` has a different code path. |
| 83 | // Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md |
| 84 | // Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for |
| 85 | // each thread. |
| 86 | // |
| 87 | // Threading |
| 88 | // ========= |
| 89 | // I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. |
| 90 | // So goodbye `rayon`! Let thread count etc be in control of our Python users. |
| 91 | // |
| 92 | // Caching |
| 93 | // ======= |
| 94 | // The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. |
| 95 | // Originally, we had one too! Without it, we were only vaguely faster than Python. |
| 96 | // I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance |
| 97 | // noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect |
| 98 | // multi-threaded performance even when I only had readers (maybe I messed something up?). |
| 99 | // Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! |
| 100 | // These are exactly the set of merges that are likely to be hot. And now we don't have to think |
| 101 | // about interior mutability, memory use, or cloning. |
| 102 | // |
| 103 | // Hashing |
| 104 | // ======= |
| 105 | // We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? |
| 106 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made |
| 107 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. |
| 108 | |
| 109 | use std::num::NonZeroU64; |
/// Stand-in with the layout this file assumes `std::thread::ThreadId` has: a single
/// `NonZeroU64`. Only used as the target of the transmute in `hash_current_thread`.
pub struct FakeThreadId(NonZeroU64);

/// Returns the current thread's numeric id, cast to `usize`.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939

    // Compile-time size checks: each line only type-checks if the type is exactly 8 bytes.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];

    let id = thread::current().id();
    // SAFETY: both types were checked above to be 8 bytes wide. This still relies on
    // `ThreadId` being a thin wrapper around a non-zero u64, which std does not promise.
    let FakeThreadId(raw) =
        unsafe { std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(id) };
    raw.get() as usize
}
| 124 | |
| 125 | const MAX_NUM_THREADS: usize = 128; |
| 126 | #[pyclass] |
| 127 | struct CoreBPE { |
| 128 | encoder: HashMap<Vec<u8>, usize>, |
| 129 | special_tokens_encoder: HashMap<String, usize>, |
| 130 | decoder: HashMap<usize, Vec<u8>>, |
| 131 | special_tokens_decoder: HashMap<usize, Vec<u8>>, |
| 132 | regex_tls: Vec<Regex>, |
| 133 | special_regex_tls: Vec<Regex>, |
| 134 | sorted_token_bytes: Vec<Vec<u8>>, |
| 135 | } |
| 136 | |
| 137 | impl CoreBPE { |
| 138 | fn _get_tl_regex(&self) -> &Regex { |
| 139 | // See performance notes above for what this is about |
| 140 | // It's also a little janky, please make a better version of it! |
| 141 | // However, it's nice that this doesn't leak memory to short-lived threads |
| 142 | &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 143 | } |
| 144 | |
| 145 | fn _get_tl_special_regex(&self) -> &Regex { |
| 146 | &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] |
| 147 | } |
| 148 | |
| 149 | fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> { |
| 150 | let mut ret = Vec::with_capacity(tokens.len() * 2); |
| 151 | for token in tokens { |
| 152 | let token_bytes = self |
| 153 | .decoder |
| 154 | .get(token) |
| 155 | .unwrap_or_else(|| &self.special_tokens_decoder[token]); |
| 156 | ret.extend(token_bytes); |
| 157 | } |
| 158 | ret |
| 159 | } |
| 160 | |
| 161 | fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> { |
| 162 | // This is the core of the encoding logic; the other functions in here |
| 163 | // just make things complicated :-) |
| 164 | let regex = self._get_tl_regex(); |
| 165 | let mut ret = vec![]; |
| 166 | for mat in regex.find_iter(text) { |
| 167 | let piece = mat.unwrap().as_str().as_bytes(); |
| 168 | if let Some(token) = self.encoder.get(piece) { |
| 169 | ret.push(*token); |
| 170 | continue; |
| 171 | } |
| 172 | ret.extend(&byte_pair_encode(piece, &self.encoder)); |
| 173 | } |
| 174 | ret |
| 175 | } |
| 176 | |
| 177 | fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) { |
| 178 | let special_regex = self._get_tl_special_regex(); |
| 179 | let regex = self._get_tl_regex(); |
| 180 | let mut ret = vec![]; |
| 181 | |
| 182 | let mut start = 0; |
| 183 | let mut last_piece_token_len = 0; |
| 184 | loop { |
| 185 | let mut next_special; |
| 186 | let mut start_find = start; |
| 187 | loop { |
| 188 | // Find the next allowed special token, if any |
| 189 | next_special = special_regex.find_from_pos(text, start_find).unwrap(); |
| 190 | match next_special { |
| 191 | Some(m) => { |
| 192 | if allowed_special.contains(&text[m.start()..m.end()]) { |
| 193 | break; |
| 194 | } |
| 195 | start_find = m.start() + 1; |
| 196 | } |
| 197 | None => break, |
| 198 | } |
| 199 | } |
| 200 | let end = next_special.map_or(text.len(), |m| m.start()); |
| 201 | |
| 202 | // Okay, here we go, compare this logic to _encode_ordinary_native |
| 203 | for mat in regex.find_iter(&text[start..end]) { |
| 204 | let piece = mat.unwrap().as_str().as_bytes(); |
| 205 | if let Some(token) = self.encoder.get(piece) { |
| 206 | last_piece_token_len = 1; |
| 207 | ret.push(*token); |
| 208 | continue; |
| 209 | } |
| 210 | let tokens = byte_pair_encode(piece, &self.encoder); |
| 211 | last_piece_token_len = tokens.len(); |
| 212 | ret.extend(&tokens); |
| 213 | } |
| 214 | |
| 215 | match next_special { |
| 216 | // And here we push the special token |
| 217 | Some(m) => { |
| 218 | let piece = m.as_str(); |
| 219 | let token = self.special_tokens_encoder[piece]; |
| 220 | ret.push(token); |
| 221 | start = m.end(); |
| 222 | last_piece_token_len = 0; |
| 223 | } |
| 224 | None => break, |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | // last_piece_token_len is how many tokens came from the last regex split. This is used |
| 229 | // for determining unstable tokens, since you can't merge across (stable) regex splits |
| 230 | (ret, last_piece_token_len) |
| 231 | } |
| 232 | |
| 233 | fn _increase_last_piece_token_len( |
| 234 | &self, |
| 235 | tokens: Vec<usize>, |
| 236 | mut last_piece_token_len: usize, |
| 237 | ) -> (Vec<usize>, usize) { |
| 238 | // Unfortunately, the locations where our regex splits can be unstable. |
| 239 | // For the purposes of determining unstable tokens, unstable regex splitting |
| 240 | // is only a problem if a split that was present disappears, since this can |
| 241 | // lead to merging of tokens otherwise thought to be stable. |
| 242 | // cl100k_base makes our life hard by including the \s*[\r\n]+ |
| 243 | // pattern. This can e.g. cause "\n" + " " to become "\n \n". |
| 244 | // Here is a quick and dirty fix: |
| 245 | { |
| 246 | let token_is_all_space = |token| { |
| 247 | self.decoder |
| 248 | .get(token) |
| 249 | .map(|token_bytes| { |
| 250 | token_bytes |
| 251 | .iter() |
| 252 | .rev() |
| 253 | .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) |
| 254 | }) |
| 255 | .unwrap_or(false) |
| 256 | }; |
| 257 | if last_piece_token_len > 0 |
| 258 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) |
| 259 | { |
| 260 | while (last_piece_token_len < tokens.len()) |
| 261 | && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) |
| 262 | { |
| 263 | last_piece_token_len += 1; |
| 264 | } |
| 265 | } |
| 266 | } |
| 267 | debug_assert!(last_piece_token_len <= tokens.len()); |
| 268 | |
| 269 | (tokens, last_piece_token_len) |
| 270 | } |
| 271 | |
| 272 | fn _encode_unstable_native( |
| 273 | &self, |
| 274 | text: &str, |
| 275 | allowed_special: &HashSet<&str>, |
| 276 | ) -> (Vec<usize>, HashSet<Vec<usize>>) { |
| 277 | let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); |
| 278 | if last_piece_token_len == 0 { |
| 279 | // If last_piece_token_len is zero, the last token was a special token and we have |
| 280 | // no unstable bytes |
| 281 | return (tokens, HashSet::new()); |
| 282 | } |
| 283 | let (mut tokens, last_piece_token_len) = |
| 284 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 285 | |
| 286 | let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); |
| 287 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 288 | |
| 289 | // TODO: we should try harder to find additional stable tokens |
| 290 | // This would reduce the amount of retokenising when determining completions |
| 291 | // Refer to the logic in an older version of this file |
| 292 | |
| 293 | let mut completions = HashSet::new(); |
| 294 | if unstable_bytes.is_empty() { |
| 295 | return (tokens, completions); |
| 296 | } |
| 297 | |
| 298 | // This is the easy bit. Just find all single tokens that start with unstable_bytes |
| 299 | // (including tokens that exactly match unstable_bytes) |
| 300 | // Separating this from the loop below helps with performance in a common case. |
| 301 | let mut point = self |
| 302 | .sorted_token_bytes |
| 303 | .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); |
| 304 | while point < self.sorted_token_bytes.len() |
| 305 | && self.sorted_token_bytes[point].starts_with(&unstable_bytes) |
| 306 | { |
| 307 | completions.insert(vec![ |
| 308 | self.encoder[self.sorted_token_bytes[point].as_slice()], |
| 309 | ]); |
| 310 | point += 1; |
| 311 | } |
| 312 | |
| 313 | // Now apply even more brute force. At every (other) possible position for the straddling |
| 314 | // token, concatenate additional bytes from that token (if any) to unstable_bytes, |
| 315 | // and retokenise the whole thing and see what we get. |
| 316 | for i in 1..unstable_bytes.len() { |
| 317 | let prefix = &unstable_bytes[..i]; |
| 318 | let suffix = &unstable_bytes[i..]; |
| 319 | let mut point = self |
| 320 | .sorted_token_bytes |
| 321 | .partition_point(|x| x.as_slice() < suffix); |
| 322 | // TODO: Perf optimisation if suffix starts with " "? |
| 323 | while point < self.sorted_token_bytes.len() |
| 324 | && self.sorted_token_bytes[point].starts_with(suffix) |
| 325 | { |
| 326 | let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); |
| 327 | let encoded = match std::str::from_utf8(&possibility) { |
| 328 | // Morally, this is byte_pair_encode(&possibility, &self.encoder) |
| 329 | // But we might have introduced a regex split which would prevent merges. |
| 330 | // (particularly possible in the presence of unstable regex splits) |
| 331 | // So convert to UTF-8 and do regex splitting. |
| 332 | // E.g. with cl100k_base " !" gets split to " " + " !", |
| 333 | // but byte_pair_encode(" !") != byte_pair_encode(" ") |
| 334 | Ok(s) => self._encode_ordinary_native(s), |
| 335 | |
| 336 | // Technically, whether or not this arm is correct depends on whether there |
| 337 | // would be a regex split before the UTF-8 truncation point. |
| 338 | // Probably niche enough that no one will ever notice (after all, people didn't |
| 339 | // notice all the big holes in the previous unstable token implementation) |
| 340 | Err(_) => byte_pair_encode(&possibility, &self.encoder), |
| 341 | // Something like the following is intriguing but incorrect: |
| 342 | // Err(e) => self._encode_ordinary_native(unsafe { |
| 343 | // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) |
| 344 | // }), |
| 345 | }; |
| 346 | let mut seq = Vec::new(); |
| 347 | let mut seq_len = 0; |
| 348 | for token in encoded { |
| 349 | seq.push(token); |
| 350 | seq_len += self.decoder[&token].len(); |
| 351 | if seq_len >= unstable_bytes.len() { |
| 352 | break; |
| 353 | } |
| 354 | } |
| 355 | completions.insert(seq); |
| 356 | point += 1; |
| 357 | } |
| 358 | } |
| 359 | |
| 360 | // This is also not straightforward. While we generally assume that regex splits are stable, |
| 361 | // unfortunately, they are not. That is, if adding bytes were to make a split appear in |
| 362 | // unstable_bytes, this could make tokens possible which our logic would otherwise think |
| 363 | // would be merged. |
| 364 | // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could |
| 365 | // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. |
| 366 | // Here is a quick and dirty fix: |
| 367 | // This isn't right if we ever remove \s+(?!\S) |
| 368 | if unstable_bytes.len() > 1 { |
| 369 | let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); |
| 370 | if unstable_bytes.len() - last_decoded.1 > 0 |
| 371 | && last_decoded.0.map_or(false, |c| c.is_whitespace()) |
| 372 | { |
| 373 | let mut reencoded = byte_pair_encode( |
| 374 | &unstable_bytes[..unstable_bytes.len() - last_decoded.1], |
| 375 | &self.encoder, |
| 376 | ); |
| 377 | reencoded.extend(byte_pair_encode( |
| 378 | &unstable_bytes[unstable_bytes.len() - last_decoded.1..], |
| 379 | &self.encoder, |
| 380 | )); |
| 381 | completions.insert(reencoded); |
| 382 | } |
| 383 | } |
| 384 | |
| 385 | (tokens, completions) |
| 386 | } |
| 387 | } |
| 388 | |
| 389 | #[pymethods] |
| 390 | impl CoreBPE { |
| 391 | #[new] |
| 392 | fn new( |
| 393 | encoder: HashMap<Vec<u8>, usize>, |
| 394 | special_tokens_encoder: HashMap<String, usize>, |
| 395 | pattern: &str, |
| 396 | ) -> PyResult<Self> { |
| 397 | let regex = Regex::new(pattern) |
| 398 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?; |
| 399 | |
| 400 | let special_regex = { |
| 401 | let _parts = special_tokens_encoder |
| 402 | .keys() |
| 403 | .map(|s| fancy_regex::escape(s)) |
| 404 | .collect::<Vec<_>>(); |
| 405 | Regex::new(&_parts.join("|")) |
| 406 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))? |
| 407 | }; |
| 408 | |
| 409 | let decoder: HashMap<usize, Vec<u8>> = |
| 410 | encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); |
| 411 | |
| 412 | assert!(encoder.len() == decoder.len()); |
| 413 | |
| 414 | let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder |
| 415 | .iter() |
| 416 | .map(|(k, v)| (*v, k.as_bytes().to_vec())) |
| 417 | .collect(); |
| 418 | |
| 419 | // Clone because I don't know how to tell Rust I'm not going to change the map |
| 420 | let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect(); |
| 421 | sorted_token_bytes.sort(); |
| 422 | |
| 423 | Ok(CoreBPE { |
| 424 | encoder, |
| 425 | special_tokens_encoder, |
| 426 | decoder, |
| 427 | special_tokens_decoder, |
| 428 | regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), |
| 429 | special_regex_tls: (0..MAX_NUM_THREADS) |
| 430 | .map(|_| special_regex.clone()) |
| 431 | .collect(), |
| 432 | sorted_token_bytes, |
| 433 | }) |
| 434 | } |
| 435 | |
| 436 | // ==================== |
| 437 | // Encoding |
| 438 | // ==================== |
| 439 | |
| 440 | fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> { |
| 441 | py.allow_threads(|| self._encode_ordinary_native(text)) |
| 442 | } |
| 443 | |
| 444 | fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> { |
| 445 | py.allow_threads(|| self._encode_native(text, &allowed_special).0) |
| 446 | } |
| 447 | |
| 448 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> { |
| 449 | py.allow_threads(|| { |
| 450 | match std::str::from_utf8(bytes) { |
| 451 | Ok(text) => self._encode_ordinary_native(text), |
| 452 | Err(e) => { |
| 453 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 454 | let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); |
| 455 | let (mut tokens, last_piece_token_len) = |
| 456 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 457 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 458 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 459 | // Somewhat niche, but this may not be correct if we'd have had a regex |
| 460 | // split between the valid UTF-8 and the invalid bytes, which is why this |
| 461 | // method is private |
| 462 | let mut unstable_bytes = |
| 463 | self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); |
| 464 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 465 | |
| 466 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 467 | tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); |
| 468 | } |
| 469 | tokens |
| 470 | } |
| 471 | } |
| 472 | }) |
| 473 | } |
| 474 | |
| 475 | fn encode_with_unstable( |
| 476 | &self, |
| 477 | py: Python, |
| 478 | text: &str, |
| 479 | allowed_special: HashSet<&str>, |
| 480 | ) -> Py<PyTuple> { |
| 481 | let (tokens, completions) = |
| 482 | py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); |
| 483 | let py_completions = |
| 484 | PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); |
| 485 | (tokens, py_completions).into_py(py) |
| 486 | } |
| 487 | |
| 488 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> { |
| 489 | if let Some(token) = self.encoder.get(piece).copied() { |
| 490 | return Ok(token); |
| 491 | } |
| 492 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 493 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 494 | return Ok(token); |
| 495 | } |
| 496 | } |
| 497 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 498 | } |
| 499 | |
| 500 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> { |
| 501 | if let Some(token) = self.encoder.get(piece) { |
| 502 | return vec![*token]; |
| 503 | } |
| 504 | byte_pair_encode(piece, &self.encoder) |
| 505 | } |
| 506 | |
| 507 | // ==================== |
| 508 | // Decoding |
| 509 | // ==================== |
| 510 | |
| 511 | fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> { |
| 512 | let bytes = py.allow_threads(|| self._decode_native(&tokens)); |
| 513 | PyBytes::new(py, &bytes).into() |
| 514 | } |
| 515 | |
| 516 | fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> { |
| 517 | if let Some(bytes) = self.decoder.get(&token) { |
| 518 | return Ok(PyBytes::new(py, bytes).into()); |
| 519 | } |
| 520 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 521 | return Ok(PyBytes::new(py, bytes).into()); |
| 522 | } |
| 523 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 524 | } |
| 525 | |
| 526 | // ==================== |
| 527 | // Miscellaneous |
| 528 | // ==================== |
| 529 | |
| 530 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 531 | self.sorted_token_bytes |
| 532 | .iter() |
| 533 | .map(|x| PyBytes::new(py, x).into()) |
| 534 | .collect() |
| 535 | } |
| 536 | } |
| 537 | |
| 538 | #[pymodule] |
| 539 | fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { |
| 540 | m.add_class::<CoreBPE>()?; |
| 541 | Ok(()) |
| 542 | } |
| 543 | |
#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    /// "ab" and "cd" are the only mergeable pairs, so "abcd" splits exactly between them.
    #[test]
    fn very_simple_test() {
        let ranks: HashMap<Vec<u8>, usize> = vec![(b"ab".to_vec(), 1), (b"cd".to_vec(), 2)]
            .into_iter()
            .collect();

        let pieces = byte_pair_split(b"abcd", &ranks);
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }
}
| 560 | |