openai/tiktoken

Mirrored from https://github.com/openai/tiktoken
0.8.0

src/lib.rs

// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;
use std::num::NonZeroU64;
use std::thread;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyBytes, PyList, PyTuple};
use pyo3::PyResult;
use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

// Note: `piece` must be non-empty, since `piece.len() - 1` below underflows for an empty slice.
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);

    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
                // parts[i + 1], see comment in the main loop.
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // n is often very small so considerations like cache-locality outweigh the algorithmic
    // complexity downsides of the `parts` vector.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
        // `parts.remove(i + 1)` will thrash the cache.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}

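// A minimal reference sketch of the greedy merge above, written for readability rather than
// speed. It is NOT part of the upstream crate: the module `naive_bpe_sketch` and the function
// `naive_byte_pair_split` are hypothetical names, and it uses the std `HashMap` so that it
// stands alone. It repeatedly merges the adjacent pair with the lowest rank (leftmost on ties),
// which is what the loop in `_byte_pair_merge` is intended to do, and could serve as a
// cross-check in tests.
#[allow(dead_code)]
mod naive_bpe_sketch {
    use std::collections::HashMap;

    pub fn naive_byte_pair_split(piece: &[u8], ranks: &HashMap<Vec<u8>, u32>) -> Vec<Vec<u8>> {
        // Start with one part per byte.
        let mut parts: Vec<Vec<u8>> = piece.iter().map(|&b| vec![b]).collect();
        loop {
            // Find the leftmost adjacent pair whose concatenation has the lowest rank.
            let mut best: Option<(u32, usize)> = None;
            for i in 0..parts.len().saturating_sub(1) {
                let merged = [parts[i].as_slice(), parts[i + 1].as_slice()].concat();
                if let Some(&rank) = ranks.get(&merged) {
                    if best.map_or(true, |(r, _)| rank < r) {
                        best = Some((rank, i));
                    }
                }
            }
            match best {
                // Merge parts[i + 1] into parts[i] and look again.
                Some((_, i)) => {
                    let right = parts.remove(i + 1);
                    parts[i].extend(right);
                }
                // No adjacent pair is a known token: we are done.
                None => return parts,
            }
        }
    }
}
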
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    if piece.len() == 1 {
        return vec![ranks[piece]];
    }
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}

pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
    assert!(piece.len() > 1);
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| &piece[part[0].0..part[1].0])
        .collect()
}

// Various performance notes:
//
// Regex
// =====
// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
// the usual regex we use.
//
// However, given that we're using a regex parse-able by `regex`, there isn't much difference
// between using the `regex` crate and using the `fancy_regex` crate.
//
// There is an important interaction between threading, `regex` and `fancy_regex`.
// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
// old `regex`, we don't hit this, because `find_iter` has a different code path.
// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
// Anyway, the way we get around this is by having a (mostly) thread-local clone of the regex for
// each thread.
//
// Threading
// =========
// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
// So goodbye `rayon`! Let thread count etc. be controlled by our Python users.
//
// Caching
// =======
// The reference tokeniser has an LRU cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single-threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
//
// Hashing
// =======
// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.

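// A rough sketch of what "hashing of two-tuples of ints" from the Hashing note could look like.
// Nothing here exists in the upstream crate: the module `pair_hash_sketch`, the function
// `encode_with_pairs`, and the `byte_tokens` / `merges` tables are hypothetical. It assumes that
// merge priority equals the id of the merged token (lower id merges first), which is the same
// assumption `_byte_pair_merge` states for byte-keyed lookups. It is unoptimised and unmeasured.
#[allow(dead_code)]
mod pair_hash_sketch {
    use std::collections::HashMap;

    // `byte_tokens[b]` is the token id of the single byte `b`.
    // `merges[&(left, right)]` is the token id produced by merging `left` and `right`.
    pub fn encode_with_pairs(
        piece: &[u8],
        byte_tokens: &[u32; 256],
        merges: &HashMap<(u32, u32), u32>,
    ) -> Vec<u32> {
        let mut tokens: Vec<u32> = piece.iter().map(|&b| byte_tokens[b as usize]).collect();
        loop {
            // Find the leftmost adjacent pair whose merged token id is lowest.
            let mut best: Option<(u32, usize)> = None;
            for i in 0..tokens.len().saturating_sub(1) {
                if let Some(&merged) = merges.get(&(tokens[i], tokens[i + 1])) {
                    if best.map_or(true, |(m, _)| merged < m) {
                        best = Some((merged, i));
                    }
                }
            }
            match best {
                Some((merged, i)) => {
                    tokens[i] = merged;
                    tokens.remove(i + 1);
                }
                None => return tokens,
            }
        }
    }
}
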
pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}

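// A possible safe alternative to the transmute above, sketched under assumptions: `ThreadId`
// implements `Hash`, so a slot index can be derived without `unsafe`. The module
// `safe_thread_slot_sketch` and the function `thread_slot` are hypothetical names, not part of
// the upstream crate. Trade-off: hashed ids can collide even for a small number of threads,
// whereas the sequential counter read by `hash_current_thread` spreads consecutive thread ids
// over distinct slots, so this is not a drop-in replacement performance-wise.
#[allow(dead_code)]
mod safe_thread_slot_sketch {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::thread;

    // Returns a slot in `0..num_slots` that is stable for the calling thread.
    pub fn thread_slot(num_slots: usize) -> usize {
        let mut hasher = DefaultHasher::new();
        thread::current().id().hash(&mut hasher);
        (hasher.finish() as usize) % num_slots
    }
}
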
#[derive(Debug, Clone)]
struct DecodeKeyError {
    token: Rank,
}

impl std::fmt::Display for DecodeKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Invalid token for decoding: {}", self.token)
    }
}

const MAX_NUM_THREADS: usize = 128;

#[pyclass]
struct CoreBPE {
    encoder: HashMap<Vec<u8>, Rank>,
    special_tokens_encoder: HashMap<String, Rank>,
    decoder: HashMap<Rank, Vec<u8>>,
    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    regex_tls: Vec<Regex>,
    special_regex_tls: Vec<Regex>,
    sorted_token_bytes: Vec<Vec<u8>>,
}

impl CoreBPE {
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    fn _decode_native(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for &token in tokens {
            let token_bytes = match self.decoder.get(&token) {
                Some(bytes) => bytes,
                None => self
                    .special_tokens_decoder
                    .get(&token)
                    .ok_or(DecodeKeyError { token })?,
            };
            ret.extend(token_bytes);
        }
        Ok(ret)
    }

    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            let piece = mat.unwrap().as_str().as_bytes();
            match self.encoder.get(piece) {
                Some(token) => ret.push(*token),
                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
            }
        }
        ret
    }

    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to _encode_ordinary_native
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<Rank>,
        mut last_piece_token_len: usize,
    ) -> (Vec<Rank>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        let unstable_bytes = self
            ._decode_native(&tokens[tokens.len() - last_piece_token_len..])
            .unwrap();
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base "  !" gets split to " " + " !",
                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
                    Ok(s) => self._encode_ordinary_native(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self._encode_ordinary_native(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
}

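// The prefix scans in `_encode_unstable_native` rely on the token bytes being sorted: every
// token that starts with a given prefix sits in one contiguous run, and `partition_point` finds
// where that run begins. A standalone sketch of the idea, assuming a lexicographically sorted
// input; the module `prefix_scan_sketch` and the function `tokens_with_prefix` are hypothetical
// names and not part of the upstream crate.
#[allow(dead_code)]
mod prefix_scan_sketch {
    // Returns the contiguous run of entries in `sorted` that start with `prefix`.
    pub fn tokens_with_prefix<'a>(sorted: &'a [Vec<u8>], prefix: &[u8]) -> &'a [Vec<u8>] {
        // First index whose entry is not lexicographically smaller than `prefix`.
        let start = sorted.partition_point(|x| x.as_slice() < prefix);
        // Entries sharing the prefix follow immediately; count how many there are.
        let len = sorted[start..]
            .iter()
            .take_while(|x| x.starts_with(prefix))
            .count();
        &sorted[start..start + len]
    }
}
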
#[pymethods]
impl CoreBPE {
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> PyResult<Self> {
        let regex = Regex::new(pattern)
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

        let special_regex = {
            let _parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&_parts.join("|"))
                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
        };

        let decoder: HashMap<Rank, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(
            encoder.len() == decoder.len(),
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        );

        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Encoding
    // ====================

    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
        py.allow_threads(|| self._encode_ordinary_native(text))
    }

    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<PyBackedStr>) -> Vec<Rank> {
        py.allow_threads(|| {
            let allowed_special: HashSet<&str> =
                allowed_special.iter().map(|s| s.as_ref()).collect();
            self._encode_native(text, &allowed_special).0
        })
    }

    fn encode_to_tiktoken_buffer(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<PyBackedStr>,
    ) -> Py<PyAny> {
        let tokens = py.allow_threads(|| {
            let allowed_special: HashSet<&str> =
                allowed_special.iter().map(|s| s.as_ref()).collect();
            self._encode_native(text, &allowed_special).0
        });
        let buffer = TiktokenBuffer { tokens };
        buffer.into_py(py)
    }

    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
        py.allow_threads(|| {
            match std::str::from_utf8(bytes) {
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
                        // Somewhat niche, but this may not be correct if we'd have had a regex
                        // split between the valid UTF-8 and the invalid bytes, which is why this
                        // method is private
                        let mut unstable_bytes = self
                            ._decode_native(&tokens[tokens.len() - last_piece_token_len..])
                            .unwrap();
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

                        tokens.truncate(tokens.len() - last_piece_token_len);
                        match self.encoder.get(&unstable_bytes) {
                            Some(token) => tokens.push(*token),
                            None => {
                                tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
                            }
                        }
                    }
                    tokens
                }
            }
        })
    }

    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<PyBackedStr>,
    ) -> Py<PyTuple> {
        let (tokens, completions) = py.allow_threads(|| {
            let allowed_special: HashSet<&str> =
                allowed_special.iter().map(|s| s.as_ref()).collect();
            self._encode_unstable_native(text, &allowed_special)
        });
        let py_completions = PyList::new_bound(
            py,
            completions
                .iter()
                .map(|seq| PyList::new_bound(py, &seq[..])),
        );
        (tokens, py_completions).into_py(py)
    }

    fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
        if let Some(token) = self.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
        if let Some(token) = self.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.encoder)
    }

    // ====================
    // Decoding
    // ====================

    fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
        match py.allow_threads(|| self._decode_native(&tokens)) {
            Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
            Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
        }
    }

    fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(PyBytes::new_bound(py, bytes).into());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new_bound(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new_bound(py, x).into())
            .collect()
    }
}

#[pyclass]
struct TiktokenBuffer {
    tokens: Vec<Rank>,
}

#[pymethods]
impl TiktokenBuffer {
    // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
    unsafe fn __getbuffer__(
        slf: Bound<'_, Self>,
        view: *mut pyo3::ffi::Py_buffer,
        flags: std::os::raw::c_int,
    ) -> PyResult<()> {
        if view.is_null() {
            return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
        }
        if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
            return Err(pyo3::exceptions::PyBufferError::new_err(
                "Object is not writable",
            ));
        }

        (*view).obj = slf.clone().into_any().into_ptr();

        let data = &slf.borrow().tokens;
        (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
        (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize;
        (*view).readonly = 1;
        (*view).itemsize = std::mem::size_of::<Rank>() as isize;
        (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
            let msg = std::ffi::CString::new("I").unwrap();
            msg.into_raw()
        } else {
            std::ptr::null_mut()
        };
        (*view).ndim = 1;
        (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
            &mut (*view).len
        } else {
            std::ptr::null_mut()
        };
        (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
            &mut (*view).itemsize
        } else {
            std::ptr::null_mut()
        };
        (*view).suboffsets = std::ptr::null_mut();
        (*view).internal = std::ptr::null_mut();

        Ok(())
    }

    unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
        std::mem::drop(std::ffi::CString::from_raw((*view).format));
    }
}

#[pymodule]
fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add_class::<CoreBPE>()?;
    Ok(())
}

#[cfg(test)]
mod tests {

    use rustc_hash::FxHashMap as HashMap;

    use crate::{byte_pair_split, Rank};

    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        HashMap::from_iter([(b"ab".to_vec(), 0), (b"cd".to_vec(), 1)])
    }

    #[test]
    fn test_simple_characters() {
        let ranks = setup_ranks();
        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }

    #[test]
    fn test_repeated_characters() {
        let ranks = setup_ranks();
        let res = byte_pair_split(b"abab", &ranks);
        assert_eq!(res, vec![b"ab", b"ab"]);
    }
}