openai/tiktoken

Public

mirrored from https://github.com/openai/tiktoken · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0.11.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

src/py.rs

255 lines · mode: code

1use std::collections::HashSet;
2
3use pyo3::{
4 IntoPyObjectExt, PyResult, exceptions,
5 prelude::*,
6 pybacked::PyBackedStr,
7 types::{PyBytes, PyList},
8};
9use rustc_hash::FxHashMap as HashMap;
10
11use crate::{CoreBPE, Rank, byte_pair_encode};
12
13#[pymethods]
14impl CoreBPE {
15 #[new]
16 fn py_new(
17 encoder: HashMap<Vec<u8>, Rank>,
18 special_tokens_encoder: HashMap<String, Rank>,
19 pattern: &str,
20 ) -> PyResult<Self> {
21 Self::new_internal(encoder, special_tokens_encoder, pattern)
22 .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
23 }
24
25 // ====================
26 // Encoding
27 // ====================
28
29 #[pyo3(name = "encode_ordinary")]
30 fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
31 py.allow_threads(|| self.encode_ordinary(text))
32 }
33
34 #[pyo3(name = "encode")]
35 fn py_encode(
36 &self,
37 py: Python,
38 text: &str,
39 allowed_special: HashSet<PyBackedStr>,
40 ) -> PyResult<Vec<Rank>> {
41 py.allow_threads(|| {
42 let allowed_special: HashSet<&str> =
43 allowed_special.iter().map(|s| s.as_ref()).collect();
44 match self.encode(text, &allowed_special) {
45 Ok((tokens, _)) => Ok(tokens),
46 Err(e) => Err(PyErr::new::<exceptions::PyValueError, _>(e.message)),
47 }
48 })
49 }
50
51 fn encode_to_tiktoken_buffer(
52 &self,
53 py: Python,
54 text: &str,
55 allowed_special: HashSet<PyBackedStr>,
56 ) -> PyResult<Py<PyAny>> {
57 let tokens_res = py.allow_threads(|| {
58 let allowed_special: HashSet<&str> =
59 allowed_special.iter().map(|s| s.as_ref()).collect();
60 self.encode(text, &allowed_special)
61 });
62
63 let tokens = match tokens_res {
64 Ok((tokens, _)) => tokens,
65 Err(e) => return Err(PyErr::new::<exceptions::PyValueError, _>(e.message)),
66 };
67
68 let buffer = TiktokenBuffer { tokens };
69 buffer.into_py_any(py)
70 }
71
72 fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
73 py.allow_threads(|| {
74 match std::str::from_utf8(bytes) {
75 // Straightforward case
76 Ok(text) => self.encode_ordinary(text),
77 // Oops, don't actually have UTF-8. But we need to do the regex splitting in
78 // Unicode space, so we make our best guess at where we would have splits
79 Err(e) => {
80 let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
81 let (tokens, last_piece_token_len) =
82 self.encode(text, &HashSet::new()).unwrap();
83 let (mut tokens, last_piece_token_len) =
84 self._increase_last_piece_token_len(tokens, last_piece_token_len);
85
86 let mut unstable_bytes;
87 if !tokens.is_empty() && last_piece_token_len > 0 {
88 // Lop off the tokens from the last piece and run BPE on the remaining bytes
89 // This likely matches what models see better, e.g. if you assume we're
90 // dealing with truncated UTF-8 bytes.
91 // Niche, but note this may not be correct if we'd have had a regex
92 // split between the valid UTF-8 and the invalid bytes.
93 unstable_bytes = self
94 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
95 .unwrap();
96 unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
97
98 tokens.truncate(tokens.len() - last_piece_token_len);
99 } else {
100 unstable_bytes = bytes[e.valid_up_to()..].to_vec();
101 }
102
103 if !unstable_bytes.is_empty() {
104 match self.encoder.get(&unstable_bytes) {
105 Some(token) => tokens.push(*token),
106 None => {
107 tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
108 }
109 }
110 }
111 tokens
112 }
113 }
114 })
115 }
116
117 #[pyo3(name = "encode_with_unstable")]
118 fn py_encode_with_unstable(
119 &self,
120 py: Python,
121 text: &str,
122 allowed_special: HashSet<PyBackedStr>,
123 ) -> PyResult<(Vec<Rank>, Py<PyList>)> {
124 let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| {
125 let allowed_special: HashSet<&str> =
126 allowed_special.iter().map(|s| s.as_ref()).collect();
127 self._encode_unstable_native(text, &allowed_special)
128 });
129 let py_completions = PyList::new(py, completions.into_iter())?;
130 Ok((tokens, py_completions.into()))
131 }
132
133 fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
134 if let Some(token) = self.encoder.get(piece).copied() {
135 return Ok(token);
136 }
137 if let Ok(piece_str) = std::str::from_utf8(piece) {
138 if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
139 return Ok(token);
140 }
141 }
142 Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
143 }
144
145 fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
146 if let Some(token) = self.encoder.get(piece) {
147 return vec![*token];
148 }
149 byte_pair_encode(piece, &self.encoder)
150 }
151
152 // ====================
153 // Decoding
154 // ====================
155
156 #[pyo3(name = "decode_bytes")]
157 fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
158 match py.allow_threads(|| self.decode_bytes(&tokens)) {
159 Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()),
160 Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
161 }
162 }
163
164 fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
165 if let Some(bytes) = self.decoder.get(&token) {
166 return Ok(PyBytes::new(py, bytes).into());
167 }
168 if let Some(bytes) = self.special_tokens_decoder.get(&token) {
169 return Ok(PyBytes::new(py, bytes).into());
170 }
171 Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
172 }
173
174 // ====================
175 // Miscellaneous
176 // ====================
177
178 fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
179 self.sorted_token_bytes
180 .iter()
181 .map(|x| PyBytes::new(py, x).into())
182 .collect()
183 }
184}
185
186#[pyclass]
187struct TiktokenBuffer {
188 tokens: Vec<Rank>,
189}
190
191#[pymethods]
192impl TiktokenBuffer {
193 // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
194 unsafe fn __getbuffer__(
195 slf: Bound<'_, Self>,
196 view: *mut pyo3::ffi::Py_buffer,
197 flags: std::os::raw::c_int,
198 ) -> PyResult<()> {
199 if view.is_null() {
200 return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
201 }
202 if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
203 return Err(pyo3::exceptions::PyBufferError::new_err(
204 "Object is not writable",
205 ));
206 }
207 unsafe {
208 let view_ref = &mut *view;
209 view_ref.obj = slf.clone().into_any().into_ptr();
210
211 let data = &slf.borrow().tokens;
212 view_ref.buf = data.as_ptr() as *mut std::os::raw::c_void;
213 view_ref.len = (data.len() * std::mem::size_of::<Rank>()) as isize;
214 view_ref.readonly = 1;
215 view_ref.itemsize = std::mem::size_of::<Rank>() as isize;
216 view_ref.format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
217 let msg = std::ffi::CString::new("I").unwrap();
218 msg.into_raw()
219 } else {
220 std::ptr::null_mut()
221 };
222 view_ref.ndim = 1;
223 view_ref.shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
224 &mut view_ref.len
225 } else {
226 std::ptr::null_mut()
227 };
228 view_ref.strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
229 &mut view_ref.itemsize
230 } else {
231 std::ptr::null_mut()
232 };
233 view_ref.suboffsets = std::ptr::null_mut();
234 view_ref.internal = std::ptr::null_mut();
235 }
236
237 Ok(())
238 }
239
240 unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
241 // Note that Py_buffer doesn't have a Drop impl
242 unsafe {
243 let view_ref = &mut *view;
244 if !view_ref.format.is_null() {
245 std::mem::drop(std::ffi::CString::from_raw(view_ref.format));
246 }
247 }
248 }
249}
250
251#[pymodule]
252fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
253 m.add_class::<CoreBPE>()?;
254 Ok(())
255}
256