openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
src/py.rs
247 lines · mode: code
| 1 | use std::collections::HashSet; |
| 2 | |
| 3 | use pyo3::{ |
| 4 | exceptions, |
| 5 | prelude::*, |
| 6 | pybacked::PyBackedStr, |
| 7 | types::{PyBytes, PyList, PyTuple}, |
| 8 | PyResult, |
| 9 | }; |
| 10 | use rustc_hash::FxHashMap as HashMap; |
| 11 | |
| 12 | use crate::{byte_pair_encode, CoreBPE, Rank}; |
| 13 | |
| 14 | #[pymethods] |
| 15 | impl CoreBPE { |
| 16 | #[new] |
| 17 | fn py_new( |
| 18 | encoder: HashMap<Vec<u8>, Rank>, |
| 19 | special_tokens_encoder: HashMap<String, Rank>, |
| 20 | pattern: &str, |
| 21 | ) -> PyResult<Self> { |
| 22 | Self::new_internal( |
| 23 | encoder, |
| 24 | special_tokens_encoder, |
| 25 | pattern, |
| 26 | ) |
| 27 | .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string())) |
| 28 | } |
| 29 | |
| 30 | // ==================== |
| 31 | // Encoding |
| 32 | // ==================== |
| 33 | |
| 34 | #[pyo3(name = "encode_ordinary")] |
| 35 | fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> { |
| 36 | py.allow_threads(|| self.encode_ordinary(text)) |
| 37 | } |
| 38 | |
| 39 | #[pyo3(name = "encode")] |
| 40 | fn py_encode( |
| 41 | &self, |
| 42 | py: Python, |
| 43 | text: &str, |
| 44 | allowed_special: HashSet<PyBackedStr>, |
| 45 | ) -> Vec<Rank> { |
| 46 | py.allow_threads(|| { |
| 47 | let allowed_special: HashSet<&str> = |
| 48 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 49 | self.encode(text, &allowed_special).0 |
| 50 | }) |
| 51 | } |
| 52 | |
| 53 | fn encode_to_tiktoken_buffer( |
| 54 | &self, |
| 55 | py: Python, |
| 56 | text: &str, |
| 57 | allowed_special: HashSet<PyBackedStr>, |
| 58 | ) -> Py<PyAny> { |
| 59 | let tokens = py.allow_threads(|| { |
| 60 | let allowed_special: HashSet<&str> = |
| 61 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 62 | self.encode(text, &allowed_special).0 |
| 63 | }); |
| 64 | let buffer = TiktokenBuffer { tokens }; |
| 65 | buffer.into_py(py) |
| 66 | } |
| 67 | |
| 68 | fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> { |
| 69 | py.allow_threads(|| { |
| 70 | match std::str::from_utf8(bytes) { |
| 71 | // Straightforward case |
| 72 | Ok(text) => self.encode_ordinary(text), |
| 73 | // Oops, don't actually have UTF-8. But we need to do the regex splitting in |
| 74 | // Unicode space, so we make our best guess at where we would have splits |
| 75 | Err(e) => { |
| 76 | let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; |
| 77 | let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new()); |
| 78 | let (mut tokens, last_piece_token_len) = |
| 79 | self._increase_last_piece_token_len(tokens, last_piece_token_len); |
| 80 | |
| 81 | let mut unstable_bytes; |
| 82 | if !tokens.is_empty() && last_piece_token_len > 0 { |
| 83 | // Lop off the tokens from the last piece and run BPE on the remaining bytes |
| 84 | // This likely matches what models see better, e.g. if you assume we're |
| 85 | // dealing with truncated UTF-8 bytes. |
| 86 | // Niche, but note this may not be correct if we'd have had a regex |
| 87 | // split between the valid UTF-8 and the invalid bytes. |
| 88 | unstable_bytes = self |
| 89 | .decode_bytes(&tokens[tokens.len() - last_piece_token_len..]) |
| 90 | .unwrap(); |
| 91 | unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); |
| 92 | |
| 93 | tokens.truncate(tokens.len() - last_piece_token_len); |
| 94 | } else { |
| 95 | unstable_bytes = bytes[e.valid_up_to()..].to_vec(); |
| 96 | } |
| 97 | |
| 98 | if !unstable_bytes.is_empty() { |
| 99 | match self.encoder.get(&unstable_bytes) { |
| 100 | Some(token) => tokens.push(*token), |
| 101 | None => { |
| 102 | tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)) |
| 103 | } |
| 104 | } |
| 105 | } |
| 106 | tokens |
| 107 | } |
| 108 | } |
| 109 | }) |
| 110 | } |
| 111 | |
| 112 | #[pyo3(name = "encode_with_unstable")] |
| 113 | fn py_encode_with_unstable( |
| 114 | &self, |
| 115 | py: Python, |
| 116 | text: &str, |
| 117 | allowed_special: HashSet<PyBackedStr>, |
| 118 | ) -> Py<PyTuple> { |
| 119 | let (tokens, completions) = py.allow_threads(|| { |
| 120 | let allowed_special: HashSet<&str> = |
| 121 | allowed_special.iter().map(|s| s.as_ref()).collect(); |
| 122 | self._encode_unstable_native(text, &allowed_special) |
| 123 | }); |
| 124 | let py_completions = PyList::new_bound( |
| 125 | py, |
| 126 | completions |
| 127 | .iter() |
| 128 | .map(|seq| PyList::new_bound(py, &seq[..])), |
| 129 | ); |
| 130 | (tokens, py_completions).into_py(py) |
| 131 | } |
| 132 | |
| 133 | fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> { |
| 134 | if let Some(token) = self.encoder.get(piece).copied() { |
| 135 | return Ok(token); |
| 136 | } |
| 137 | if let Ok(piece_str) = std::str::from_utf8(piece) { |
| 138 | if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { |
| 139 | return Ok(token); |
| 140 | } |
| 141 | } |
| 142 | Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned())) |
| 143 | } |
| 144 | |
| 145 | fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> { |
| 146 | if let Some(token) = self.encoder.get(piece) { |
| 147 | return vec![*token]; |
| 148 | } |
| 149 | byte_pair_encode(piece, &self.encoder) |
| 150 | } |
| 151 | |
| 152 | // ==================== |
| 153 | // Decoding |
| 154 | // ==================== |
| 155 | |
| 156 | #[pyo3(name = "decode_bytes")] |
| 157 | fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> { |
| 158 | match py.allow_threads(|| self.decode_bytes(&tokens)) { |
| 159 | Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()), |
| 160 | Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))), |
| 161 | } |
| 162 | } |
| 163 | |
| 164 | fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> { |
| 165 | if let Some(bytes) = self.decoder.get(&token) { |
| 166 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 167 | } |
| 168 | if let Some(bytes) = self.special_tokens_decoder.get(&token) { |
| 169 | return Ok(PyBytes::new_bound(py, bytes).into()); |
| 170 | } |
| 171 | Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string())) |
| 172 | } |
| 173 | |
| 174 | // ==================== |
| 175 | // Miscellaneous |
| 176 | // ==================== |
| 177 | |
| 178 | fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> { |
| 179 | self.sorted_token_bytes |
| 180 | .iter() |
| 181 | .map(|x| PyBytes::new_bound(py, x).into()) |
| 182 | .collect() |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | #[pyclass] |
| 187 | struct TiktokenBuffer { |
| 188 | tokens: Vec<Rank>, |
| 189 | } |
| 190 | |
| 191 | #[pymethods] |
| 192 | impl TiktokenBuffer { |
| 193 | // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25 |
| 194 | unsafe fn __getbuffer__( |
| 195 | slf: Bound<'_, Self>, |
| 196 | view: *mut pyo3::ffi::Py_buffer, |
| 197 | flags: std::os::raw::c_int, |
| 198 | ) -> PyResult<()> { |
| 199 | if view.is_null() { |
| 200 | return Err(pyo3::exceptions::PyBufferError::new_err("View is null")); |
| 201 | } |
| 202 | if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE { |
| 203 | return Err(pyo3::exceptions::PyBufferError::new_err( |
| 204 | "Object is not writable", |
| 205 | )); |
| 206 | } |
| 207 | |
| 208 | (*view).obj = slf.clone().into_any().into_ptr(); |
| 209 | |
| 210 | let data = &slf.borrow().tokens; |
| 211 | (*view).buf = data.as_ptr() as *mut std::os::raw::c_void; |
| 212 | (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize; |
| 213 | (*view).readonly = 1; |
| 214 | (*view).itemsize = std::mem::size_of::<Rank>() as isize; |
| 215 | (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT { |
| 216 | let msg = std::ffi::CString::new("I").unwrap(); |
| 217 | msg.into_raw() |
| 218 | } else { |
| 219 | std::ptr::null_mut() |
| 220 | }; |
| 221 | (*view).ndim = 1; |
| 222 | (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND { |
| 223 | &mut (*view).len |
| 224 | } else { |
| 225 | std::ptr::null_mut() |
| 226 | }; |
| 227 | (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES { |
| 228 | &mut (*view).itemsize |
| 229 | } else { |
| 230 | std::ptr::null_mut() |
| 231 | }; |
| 232 | (*view).suboffsets = std::ptr::null_mut(); |
| 233 | (*view).internal = std::ptr::null_mut(); |
| 234 | |
| 235 | Ok(()) |
| 236 | } |
| 237 | |
| 238 | unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) { |
| 239 | std::mem::drop(std::ffi::CString::from_raw((*view).format)); |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | #[pymodule] |
| 244 | fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> { |
| 245 | m.add_class::<CoreBPE>()?; |
| 246 | Ok(()) |
| 247 | } |
| 248 | |