openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
src/py.rs
243 lines · mode: code
| 1 | use std::collections::HashSet; |
| 2 | |
| 3 | use pyo3::{ |
| 4 | exceptions, |
| 5 | prelude::*, |
| 6 | pybacked::PyBackedStr, |
| 7 | types::{PyBytes, PyList, PyTuple}, |
| 8 | PyResult, |
| 9 | }; |
| 10 | use rustc_hash::FxHashMap as HashMap; |
| 11 | |
| 12 | use crate::{byte_pair_encode, CoreBPE, Rank}; |
| 13 | |
| 14 | #[pymethods] |
| 15 | impl CoreBPE { |
| 16 | #[new] |
| 17 | fn py_new( |
| 18 | encoder: HashMap<Vec<u8>, Rank>, |
| 19 | special_tokens_encoder: HashMap<String, Rank>, |
| 20 | pattern: &str, |
| 21 | ) -> PyResult<Self> { |
| 22 | Self::new_internal(encoder, special_tokens_encoder, pattern) |
| 23 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string())) |
| 24 | } |
| 25 | |
| 26 | // ==================== |
| 27 | // Encoding |
| 28 | // ==================== |
| 29 | |
| 30 | #[pyo3(name = "encode_ordinary")] |
| 31 | fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> { |
| 32 | py.allow_threads(|| self.encode_ordinary(text)) |
| 33 | } |
| 34 | |
| 35 | #[pyo3(name = "encode")] |
| 36 | fn py_encode( |
| 37 | &self, |
| 38 | py: Python, |
| 39 | text: &str, |
| 40 | allowed_special: HashSet<PyBackedStr>, |
| 41 | ) -> Vec<Rank> { |
| 42 | py.allow_threads(|| { |
| 43 | let allowed_special: HashSet<&str> = |
| 44 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 45 | self.encode(text, &allowed_special).0 |
| 46 | }) |
| 47 | } |
| 48 | |
| 49 | fn encode_to_tiktoken_buffer( |
| 50 | &self, |
| 51 | py: Python, |
| 52 | text: &str, |
| 53 | allowed_special: HashSet<PyBackedStr>, |
| 54 | ) -> Py<PyAny> { |
| 55 | let tokens = py.allow_threads(|| { |
| 56 | let allowed_special: HashSet<&str> = |
| 57 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 58 | self.encode(text, &allowed_special).0 |
| 59 | }); |
| 60 | let buffer = TiktokenBuffer { tokens }; |
| 61 | buffer.into_py(py) |
| 62 | } |
| 63 | |
| 64 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> { |
| 65 | py.allow_threads(|| { |
| 66 | match std::str::from_utf8(bytes) { |
| 67 | // Straightforward case |
| 68 | Ok(text) => self.encode_ordinary(text), |
| 69 | // Oops, don't actually have UTF-8. But we need to do the regex splitting in |
| 70 | // Unicode space, so we make our best guess at where we would have splits |
| 71 | Err(e) => { |
| 72 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 73 | let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new()); |
| 74 | let (mut tokens, last_piece_token_len) = |
| 75 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 76 | |
| 77 | let mut unstable_bytes; |
| 78 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 79 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 80 | // This likely matches what models see better, e.g. if you assume we're |
| 81 | // dealing with truncated UTF-8 bytes. |
| 82 | // Niche, but note this may not be correct if we'd have had a regex |
| 83 | // split between the valid UTF-8 and the invalid bytes. |
| 84 | unstable_bytes = self |
| 85 | .decode_bytes(&tokens[tokens.len() - last_piece_token_len..]) |
| 86 | .unwrap(); |
| 87 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 88 | |
| 89 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 90 | } else { |
| 91 | unstable_bytes = bytes[e.valid_up_to()..].to_vec(); |
| 92 | } |
| 93 | |
| 94 | if !unstable_bytes.is_empty() { |
| 95 | match self.encoder.get(&unstable_bytes) { |
| 96 | Some(token) => tokens.push(*token), |
| 97 | None => { |
| 98 | tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)) |
| 99 | } |
| 100 | } |
| 101 | } |
| 102 | tokens |
| 103 | } |
| 104 | } |
| 105 | }) |
| 106 | } |
| 107 | |
| 108 | #[pyo3(name = "encode_with_unstable")] |
| 109 | fn py_encode_with_unstable( |
| 110 | &self, |
| 111 | py: Python, |
| 112 | text: &str, |
| 113 | allowed_special: HashSet<PyBackedStr>, |
| 114 | ) -> Py<PyTuple> { |
| 115 | let (tokens, completions) = py.allow_threads(|| { |
| 116 | let allowed_special: HashSet<&str> = |
| 117 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 118 | self._encode_unstable_native(text, &allowed_special) |
| 119 | }); |
| 120 | let py_completions = PyList::new_bound( |
| 121 | py, |
| 122 | completions |
| 123 | .iter() |
| 124 | .map(|seq| PyList::new_bound(py, &seq[..])), |
| 125 | ); |
| 126 | (tokens, py_completions).into_py(py) |
| 127 | } |
| 128 | |
| 129 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> { |
| 130 | if let Some(token) = self.encoder.get(piece).copied() { |
| 131 | return Ok(token); |
| 132 | } |
| 133 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 134 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 135 | return Ok(token); |
| 136 | } |
| 137 | } |
| 138 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 139 | } |
| 140 | |
| 141 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> { |
| 142 | if let Some(token) = self.encoder.get(piece) { |
| 143 | return vec![*token]; |
| 144 | } |
| 145 | byte_pair_encode(piece, &self.encoder) |
| 146 | } |
| 147 | |
| 148 | // ==================== |
| 149 | // Decoding |
| 150 | // ==================== |
| 151 | |
| 152 | #[pyo3(name = "decode_bytes")] |
| 153 | fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> { |
| 154 | match py.allow_threads(|| self.decode_bytes(&tokens)) { |
| 155 | Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()), |
| 156 | Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> { |
| 161 | if let Some(bytes) = self.decoder.get(&token) { |
| 162 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 163 | } |
| 164 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 165 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 166 | } |
| 167 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 168 | } |
| 169 | |
| 170 | // ==================== |
| 171 | // Miscellaneous |
| 172 | // ==================== |
| 173 | |
| 174 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 175 | self.sorted_token_bytes |
| 176 | .iter() |
| 177 | .map(|x| PyBytes::new_bound(py, x).into()) |
| 178 | .collect() |
| 179 | } |
| 180 | } |
| 181 | |
| 182 | #[pyclass] |
| 183 | struct TiktokenBuffer { |
| 184 | tokens: Vec<Rank>, |
| 185 | } |
| 186 | |
| 187 | #[pymethods] |
| 188 | impl TiktokenBuffer { |
| 189 | // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25 |
| 190 | unsafe fn __getbuffer__( |
| 191 | slf: Bound<'_, Self>, |
| 192 | view: *mut pyo3::ffi::Py_buffer, |
| 193 | flags: std::os::raw::c_int, |
| 194 | ) -> PyResult<()> { |
| 195 | if view.is_null() { |
| 196 | return Err(pyo3::exceptions::PyBufferError::new_err("View is null")); |
| 197 | } |
| 198 | if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE { |
| 199 | return Err(pyo3::exceptions::PyBufferError::new_err( |
| 200 | "Object is not writable", |
| 201 | )); |
| 202 | } |
| 203 | |
| 204 | (*view).obj = slf.clone().into_any().into_ptr(); |
| 205 | |
| 206 | let data = &slf.borrow().tokens; |
| 207 | (*view).buf = data.as_ptr() as *mut std::os::raw::c_void; |
| 208 | (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize; |
| 209 | (*view).readonly = 1; |
| 210 | (*view).itemsize = std::mem::size_of::<Rank>() as isize; |
| 211 | (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT { |
| 212 | let msg = std::ffi::CString::new("I").unwrap(); |
| 213 | msg.into_raw() |
| 214 | } else { |
| 215 | std::ptr::null_mut() |
| 216 | }; |
| 217 | (*view).ndim = 1; |
| 218 | (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { |
| 219 | &mut (*view).len |
| 220 | } else { |
| 221 | std::ptr::null_mut() |
| 222 | }; |
| 223 | (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES { |
| 224 | &mut (*view).itemsize |
| 225 | } else { |
| 226 | std::ptr::null_mut() |
| 227 | }; |
| 228 | (*view).suboffsets = std::ptr::null_mut(); |
| 229 | (*view).internal = std::ptr::null_mut(); |
| 230 | |
| 231 | Ok(()) |
| 232 | } |
| 233 | |
| 234 | unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) { |
| 235 | std::mem::drop(std::ffi::CString::from_raw((*view).format)); |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | #[pymodule] |
| 240 | fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> { |
| 241 | m.add_class::<CoreBPE>()?; |
| 242 | Ok(()) |
| 243 | } |