openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
src/py.rs
255 lines · mode: code
| 1 | use std::collections::HashSet; |
| 2 | |
| 3 | use pyo3::{ |
| 4 | IntoPyObjectExt, PyResult, exceptions, |
| 5 | prelude::*, |
| 6 | pybacked::PyBackedStr, |
| 7 | types::{PyBytes, PyList}, |
| 8 | }; |
| 9 | use rustc_hash::FxHashMap as HashMap; |
| 10 | |
| 11 | use crate::{CoreBPE, Rank, byte_pair_encode}; |
| 12 | |
| 13 | #[pymethods] |
| 14 | impl CoreBPE { |
| 15 | #[new] |
| 16 | fn py_new( |
| 17 | encoder: HashMap<Vec<u8>, Rank>, |
| 18 | special_tokens_encoder: HashMap<String, Rank>, |
| 19 | pattern: &str, |
| 20 | ) -> PyResult<Self> { |
| 21 | Self::new_internal(encoder, special_tokens_encoder, pattern) |
| 22 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string())) |
| 23 | } |
| 24 | |
| 25 | // ==================== |
| 26 | // Encoding |
| 27 | // ==================== |
| 28 | |
| 29 | #[pyo3(name = "encode_ordinary")] |
| 30 | fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> { |
| 31 | py.allow_threads(|| self.encode_ordinary(text)) |
| 32 | } |
| 33 | |
| 34 | #[pyo3(name = "encode")] |
| 35 | fn py_encode( |
| 36 | &self, |
| 37 | py: Python, |
| 38 | text: &str, |
| 39 | allowed_special: HashSet<PyBackedStr>, |
| 40 | ) -> PyResult<Vec<Rank>> { |
| 41 | py.allow_threads(|| { |
| 42 | let allowed_special: HashSet<&str> = |
| 43 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 44 | match self.encode(text, &allowed_special) { |
| 45 | Ok((tokens, _)) => Ok(tokens), |
| 46 | Err(e) => Err(PyErr::new::<exceptions::PyValueError, _>(e.message)), |
| 47 | } |
| 48 | }) |
| 49 | } |
| 50 | |
| 51 | fn encode_to_tiktoken_buffer( |
| 52 | &self, |
| 53 | py: Python, |
| 54 | text: &str, |
| 55 | allowed_special: HashSet<PyBackedStr>, |
| 56 | ) -> PyResult<Py<PyAny>> { |
| 57 | let tokens_res = py.allow_threads(|| { |
| 58 | let allowed_special: HashSet<&str> = |
| 59 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 60 | self.encode(text, &allowed_special) |
| 61 | }); |
| 62 | |
| 63 | let tokens = match tokens_res { |
| 64 | Ok((tokens, _)) => tokens, |
| 65 | Err(e) => return Err(PyErr::new::<exceptions::PyValueError, _>(e.message)), |
| 66 | }; |
| 67 | |
| 68 | let buffer = TiktokenBuffer { tokens }; |
| 69 | buffer.into_py_any(py) |
| 70 | } |
| 71 | |
| 72 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> { |
| 73 | py.allow_threads(|| { |
| 74 | match std::str::from_utf8(bytes) { |
| 75 | // Straightforward case |
| 76 | Ok(text) => self.encode_ordinary(text), |
| 77 | // Oops, don't actually have UTF-8. But we need to do the regex splitting in |
| 78 | // Unicode space, so we make our best guess at where we would have splits |
| 79 | Err(e) => { |
| 80 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 81 | let (tokens, last_piece_token_len) = |
| 82 | self.encode(text, &HashSet::new()).unwrap(); |
| 83 | let (mut tokens, last_piece_token_len) = |
| 84 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 85 | |
| 86 | let mut unstable_bytes; |
| 87 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 88 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 89 | // This likely matches what models see better, e.g. if you assume we're |
| 90 | // dealing with truncated UTF-8 bytes. |
| 91 | // Niche, but note this may not be correct if we'd have had a regex |
| 92 | // split between the valid UTF-8 and the invalid bytes. |
| 93 | unstable_bytes = self |
| 94 | .decode_bytes(&tokens[tokens.len() - last_piece_token_len..]) |
| 95 | .unwrap(); |
| 96 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 97 | |
| 98 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 99 | } else { |
| 100 | unstable_bytes = bytes[e.valid_up_to()..].to_vec(); |
| 101 | } |
| 102 | |
| 103 | if !unstable_bytes.is_empty() { |
| 104 | match self.encoder.get(&unstable_bytes) { |
| 105 | Some(token) => tokens.push(*token), |
| 106 | None => { |
| 107 | tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)) |
| 108 | } |
| 109 | } |
| 110 | } |
| 111 | tokens |
| 112 | } |
| 113 | } |
| 114 | }) |
| 115 | } |
| 116 | |
| 117 | #[pyo3(name = "encode_with_unstable")] |
| 118 | fn py_encode_with_unstable( |
| 119 | &self, |
| 120 | py: Python, |
| 121 | text: &str, |
| 122 | allowed_special: HashSet<PyBackedStr>, |
| 123 | ) -> PyResult<(Vec<Rank>, Py<PyList>)> { |
| 124 | let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| { |
| 125 | let allowed_special: HashSet<&str> = |
| 126 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 127 | self._encode_unstable_native(text, &allowed_special) |
| 128 | }); |
| 129 | let py_completions = PyList::new(py, completions.into_iter())?; |
| 130 | Ok((tokens, py_completions.into())) |
| 131 | } |
| 132 | |
| 133 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> { |
| 134 | if let Some(token) = self.encoder.get(piece).copied() { |
| 135 | return Ok(token); |
| 136 | } |
| 137 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 138 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 139 | return Ok(token); |
| 140 | } |
| 141 | } |
| 142 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 143 | } |
| 144 | |
| 145 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> { |
| 146 | if let Some(token) = self.encoder.get(piece) { |
| 147 | return vec![*token]; |
| 148 | } |
| 149 | byte_pair_encode(piece, &self.encoder) |
| 150 | } |
| 151 | |
| 152 | // ==================== |
| 153 | // Decoding |
| 154 | // ==================== |
| 155 | |
| 156 | #[pyo3(name = "decode_bytes")] |
| 157 | fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> { |
| 158 | match py.allow_threads(|| self.decode_bytes(&tokens)) { |
| 159 | Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()), |
| 160 | Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), |
| 161 | } |
| 162 | } |
| 163 | |
| 164 | fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> { |
| 165 | if let Some(bytes) = self.decoder.get(&token) { |
| 166 | return Ok(PyBytes::new(py, bytes).into()); |
| 167 | } |
| 168 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 169 | return Ok(PyBytes::new(py, bytes).into()); |
| 170 | } |
| 171 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 172 | } |
| 173 | |
| 174 | // ==================== |
| 175 | // Miscellaneous |
| 176 | // ==================== |
| 177 | |
| 178 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 179 | self.sorted_token_bytes |
| 180 | .iter() |
| 181 | .map(|x| PyBytes::new(py, x).into()) |
| 182 | .collect() |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | #[pyclass] |
| 187 | struct TiktokenBuffer { |
| 188 | tokens: Vec<Rank>, |
| 189 | } |
| 190 | |
| 191 | #[pymethods] |
| 192 | impl TiktokenBuffer { |
| 193 | // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25 |
| 194 | unsafe fn __getbuffer__( |
| 195 | slf: Bound<'_, Self>, |
| 196 | view: *mut pyo3::ffi::Py_buffer, |
| 197 | flags: std::os::raw::c_int, |
| 198 | ) -> PyResult<()> { |
| 199 | if view.is_null() { |
| 200 | return Err(pyo3::exceptions::PyBufferError::new_err("View is null")); |
| 201 | } |
| 202 | if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE { |
| 203 | return Err(pyo3::exceptions::PyBufferError::new_err( |
| 204 | "Object is not writable", |
| 205 | )); |
| 206 | } |
| 207 | unsafe { |
| 208 | let view_ref = &mut *view; |
| 209 | view_ref.obj = slf.clone().into_any().into_ptr(); |
| 210 | |
| 211 | let data = &slf.borrow().tokens; |
| 212 | view_ref.buf = data.as_ptr() as *mut std::os::raw::c_void; |
| 213 | view_ref.len = (data.len() * std::mem::size_of::<Rank>()) as isize; |
| 214 | view_ref.readonly = 1; |
| 215 | view_ref.itemsize = std::mem::size_of::<Rank>() as isize; |
| 216 | view_ref.format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT { |
| 217 | let msg = std::ffi::CString::new("I").unwrap(); |
| 218 | msg.into_raw() |
| 219 | } else { |
| 220 | std::ptr::null_mut() |
| 221 | }; |
| 222 | view_ref.ndim = 1; |
| 223 | view_ref.shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { |
| 224 | &mut view_ref.len |
| 225 | } else { |
| 226 | std::ptr::null_mut() |
| 227 | }; |
| 228 | view_ref.strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES { |
| 229 | &mut view_ref.itemsize |
| 230 | } else { |
| 231 | std::ptr::null_mut() |
| 232 | }; |
| 233 | view_ref.suboffsets = std::ptr::null_mut(); |
| 234 | view_ref.internal = std::ptr::null_mut(); |
| 235 | } |
| 236 | |
| 237 | Ok(()) |
| 238 | } |
| 239 | |
| 240 | unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) { |
| 241 | // Note that Py_buffer doesn't have a Drop impl |
| 242 | unsafe { |
| 243 | let view_ref = &mut *view; |
| 244 | if !view_ref.format.is_null() { |
| 245 | std::mem::drop(std::ffi::CString::from_raw(view_ref.format)); |
| 246 | } |
| 247 | } |
| 248 | } |
| 249 | } |
| 250 | |
| 251 | #[pymodule] |
| 252 | fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> { |
| 253 | m.add_class::<CoreBPE>()?; |
| 254 | Ok(()) |
| 255 | } |
| 256 | |