openai/tiktoken

Public

Mirrored from https://github.com/openai/tiktoken · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
subquad2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

src/py.rs

243 lines · mode: code

1use std::collections::HashSet;
2
3use pyo3::{
4 exceptions,
5 prelude::*,
6 pybacked::PyBackedStr,
7 types::{PyBytes, PyList, PyTuple},
8 PyResult,
9};
10use rustc_hash::FxHashMap as HashMap;
11
12use crate::{byte_pair_encode, CoreBPE, Rank};
13
14#[pymethods]
15impl CoreBPE {
16 #[new]
17 fn py_new(
18 encoder: HashMap<Vec<u8>, Rank>,
19 special_tokens_encoder: HashMap<String, Rank>,
20 pattern: &str,
21 ) -> PyResult<Self> {
22 Self::new_internal(encoder, special_tokens_encoder, pattern)
23 .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
24 }
25
26 // ====================
27 // Encoding
28 // ====================
29
30 #[pyo3(name = "encode_ordinary")]
31 fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
32 py.allow_threads(|| self.encode_ordinary(text))
33 }
34
35 #[pyo3(name = "encode")]
36 fn py_encode(
37 &self,
38 py: Python,
39 text: &str,
40 allowed_special: HashSet<PyBackedStr>,
41 ) -> Vec<Rank> {
42 py.allow_threads(|| {
43 let allowed_special: HashSet<&str> =
44 allowed_special.iter().map(|s| s.as_ref()).collect();
45 self.encode(text, &allowed_special).0
46 })
47 }
48
49 fn encode_to_tiktoken_buffer(
50 &self,
51 py: Python,
52 text: &str,
53 allowed_special: HashSet<PyBackedStr>,
54 ) -> Py<PyAny> {
55 let tokens = py.allow_threads(|| {
56 let allowed_special: HashSet<&str> =
57 allowed_special.iter().map(|s| s.as_ref()).collect();
58 self.encode(text, &allowed_special).0
59 });
60 let buffer = TiktokenBuffer { tokens };
61 buffer.into_py(py)
62 }
63
64 fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
65 py.allow_threads(|| {
66 match std::str::from_utf8(bytes) {
67 // Straightforward case
68 Ok(text) => self.encode_ordinary(text),
69 // Oops, don't actually have UTF-8. But we need to do the regex splitting in
70 // Unicode space, so we make our best guess at where we would have splits
71 Err(e) => {
72 let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
73 let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
74 let (mut tokens, last_piece_token_len) =
75 self._increase_last_piece_token_len(tokens, last_piece_token_len);
76
77 let mut unstable_bytes;
78 if !tokens.is_empty() && last_piece_token_len > 0 {
79 // Lop off the tokens from the last piece and run BPE on the remaining bytes
80 // This likely matches what models see better, e.g. if you assume we're
81 // dealing with truncated UTF-8 bytes.
82 // Niche, but note this may not be correct if we'd have had a regex
83 // split between the valid UTF-8 and the invalid bytes.
84 unstable_bytes = self
85 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
86 .unwrap();
87 unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
88
89 tokens.truncate(tokens.len() - last_piece_token_len);
90 } else {
91 unstable_bytes = bytes[e.valid_up_to()..].to_vec();
92 }
93
94 if !unstable_bytes.is_empty() {
95 match self.encoder.get(&unstable_bytes) {
96 Some(token) => tokens.push(*token),
97 None => {
98 tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
99 }
100 }
101 }
102 tokens
103 }
104 }
105 })
106 }
107
108 #[pyo3(name = "encode_with_unstable")]
109 fn py_encode_with_unstable(
110 &self,
111 py: Python,
112 text: &str,
113 allowed_special: HashSet<PyBackedStr>,
114 ) -> Py<PyTuple> {
115 let (tokens, completions) = py.allow_threads(|| {
116 let allowed_special: HashSet<&str> =
117 allowed_special.iter().map(|s| s.as_ref()).collect();
118 self._encode_unstable_native(text, &allowed_special)
119 });
120 let py_completions = PyList::new_bound(
121 py,
122 completions
123 .iter()
124 .map(|seq| PyList::new_bound(py, &seq[..])),
125 );
126 (tokens, py_completions).into_py(py)
127 }
128
129 fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
130 if let Some(token) = self.encoder.get(piece).copied() {
131 return Ok(token);
132 }
133 if let Ok(piece_str) = std::str::from_utf8(piece) {
134 if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
135 return Ok(token);
136 }
137 }
138 Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
139 }
140
141 fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
142 if let Some(token) = self.encoder.get(piece) {
143 return vec![*token];
144 }
145 byte_pair_encode(piece, &self.encoder)
146 }
147
148 // ====================
149 // Decoding
150 // ====================
151
152 #[pyo3(name = "decode_bytes")]
153 fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
154 match py.allow_threads(|| self.decode_bytes(&tokens)) {
155 Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
156 Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
157 }
158 }
159
160 fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
161 if let Some(bytes) = self.decoder.get(&token) {
162 return Ok(PyBytes::new_bound(py, bytes).into());
163 }
164 if let Some(bytes) = self.special_tokens_decoder.get(&token) {
165 return Ok(PyBytes::new_bound(py, bytes).into());
166 }
167 Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
168 }
169
170 // ====================
171 // Miscellaneous
172 // ====================
173
174 fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
175 self.sorted_token_bytes
176 .iter()
177 .map(|x| PyBytes::new_bound(py, x).into())
178 .collect()
179 }
180}
181
182#[pyclass]
183struct TiktokenBuffer {
184 tokens: Vec<Rank>,
185}
186
187#[pymethods]
188impl TiktokenBuffer {
189 // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
190 unsafe fn __getbuffer__(
191 slf: Bound<'_, Self>,
192 view: *mut pyo3::ffi::Py_buffer,
193 flags: std::os::raw::c_int,
194 ) -> PyResult<()> {
195 if view.is_null() {
196 return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
197 }
198 if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
199 return Err(pyo3::exceptions::PyBufferError::new_err(
200 "Object is not writable",
201 ));
202 }
203
204 (*view).obj = slf.clone().into_any().into_ptr();
205
206 let data = &slf.borrow().tokens;
207 (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
208 (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize;
209 (*view).readonly = 1;
210 (*view).itemsize = std::mem::size_of::<Rank>() as isize;
211 (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
212 let msg = std::ffi::CString::new("I").unwrap();
213 msg.into_raw()
214 } else {
215 std::ptr::null_mut()
216 };
217 (*view).ndim = 1;
218 (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
219 &mut (*view).len
220 } else {
221 std::ptr::null_mut()
222 };
223 (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
224 &mut (*view).itemsize
225 } else {
226 std::ptr::null_mut()
227 };
228 (*view).suboffsets = std::ptr::null_mut();
229 (*view).internal = std::ptr::null_mut();
230
231 Ok(())
232 }
233
234 unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
235 std::mem::drop(std::ffi::CString::from_raw((*view).format));
236 }
237}
238
239#[pymodule]
240fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
241 m.add_class::<CoreBPE>()?;
242 Ok(())
243}