openai/tiktoken

Public

mirrored from https://github.com/openai/tiktoken Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
3591ff175d6a80efbe4fcc7f0e219ddd4b8c52f1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

src/py.rs

247lines · modecode

1use std::collections::HashSet;
2
3use pyo3::{
4 exceptions,
5 prelude::*,
6 pybacked::PyBackedStr,
7 types::{PyBytes, PyList, PyTuple},
8 PyResult,
9};
10use rustc_hash::FxHashMap as HashMap;
11
12use crate::{byte_pair_encode, CoreBPE, Rank};
13
14#[pymethods]
15impl CoreBPE {
16 #[new]
17 fn py_new(
18 encoder: HashMap<Vec<u8>, Rank>,
19 special_tokens_encoder: HashMap<String, Rank>,
20 pattern: &str,
21 ) -> PyResult<Self> {
22 Self::new_internal(
23 encoder,
24 special_tokens_encoder,
25 pattern,
26 )
27 .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
28 }
29
30 // ====================
31 // Encoding
32 // ====================
33
34 #[pyo3(name = "encode_ordinary")]
35 fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
36 py.allow_threads(|| self.encode_ordinary(text))
37 }
38
39 #[pyo3(name = "encode")]
40 fn py_encode(
41 &self,
42 py: Python,
43 text: &str,
44 allowed_special: HashSet<PyBackedStr>,
45 ) -> Vec<Rank> {
46 py.allow_threads(|| {
47 let allowed_special: HashSet<&str> =
48 allowed_special.iter().map(|s| s.as_ref()).collect();
49 self.encode(text, &allowed_special).0
50 })
51 }
52
53 fn encode_to_tiktoken_buffer(
54 &self,
55 py: Python,
56 text: &str,
57 allowed_special: HashSet<PyBackedStr>,
58 ) -> Py<PyAny> {
59 let tokens = py.allow_threads(|| {
60 let allowed_special: HashSet<&str> =
61 allowed_special.iter().map(|s| s.as_ref()).collect();
62 self.encode(text, &allowed_special).0
63 });
64 let buffer = TiktokenBuffer { tokens };
65 buffer.into_py(py)
66 }
67
68 fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
69 py.allow_threads(|| {
70 match std::str::from_utf8(bytes) {
71 // Straightforward case
72 Ok(text) => self.encode_ordinary(text),
73 // Oops, don't actually have UTF-8. But we need to do the regex splitting in
74 // Unicode space, so we make our best guess at where we would have splits
75 Err(e) => {
76 let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
77 let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
78 let (mut tokens, last_piece_token_len) =
79 self._increase_last_piece_token_len(tokens, last_piece_token_len);
80
81 let mut unstable_bytes;
82 if !tokens.is_empty() && last_piece_token_len > 0 {
83 // Lop off the tokens from the last piece and run BPE on the remaining bytes
84 // This likely matches what models see better, e.g. if you assume we're
85 // dealing with truncated UTF-8 bytes.
86 // Niche, but note this may not be correct if we'd have had a regex
87 // split between the valid UTF-8 and the invalid bytes.
88 unstable_bytes = self
89 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
90 .unwrap();
91 unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
92
93 tokens.truncate(tokens.len() - last_piece_token_len);
94 } else {
95 unstable_bytes = bytes[e.valid_up_to()..].to_vec();
96 }
97
98 if !unstable_bytes.is_empty() {
99 match self.encoder.get(&unstable_bytes) {
100 Some(token) => tokens.push(*token),
101 None => {
102 tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder))
103 }
104 }
105 }
106 tokens
107 }
108 }
109 })
110 }
111
112 #[pyo3(name = "encode_with_unstable")]
113 fn py_encode_with_unstable(
114 &self,
115 py: Python,
116 text: &str,
117 allowed_special: HashSet<PyBackedStr>,
118 ) -> Py<PyTuple> {
119 let (tokens, completions) = py.allow_threads(|| {
120 let allowed_special: HashSet<&str> =
121 allowed_special.iter().map(|s| s.as_ref()).collect();
122 self._encode_unstable_native(text, &allowed_special)
123 });
124 let py_completions = PyList::new_bound(
125 py,
126 completions
127 .iter()
128 .map(|seq| PyList::new_bound(py, &seq[..])),
129 );
130 (tokens, py_completions).into_py(py)
131 }
132
133 fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
134 if let Some(token) = self.encoder.get(piece).copied() {
135 return Ok(token);
136 }
137 if let Ok(piece_str) = std::str::from_utf8(piece) {
138 if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
139 return Ok(token);
140 }
141 }
142 Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
143 }
144
145 fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
146 if let Some(token) = self.encoder.get(piece) {
147 return vec![*token];
148 }
149 byte_pair_encode(piece, &self.encoder)
150 }
151
152 // ====================
153 // Decoding
154 // ====================
155
156 #[pyo3(name = "decode_bytes")]
157 fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
158 match py.allow_threads(|| self.decode_bytes(&tokens)) {
159 Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
160 Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
161 }
162 }
163
164 fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
165 if let Some(bytes) = self.decoder.get(&token) {
166 return Ok(PyBytes::new_bound(py, bytes).into());
167 }
168 if let Some(bytes) = self.special_tokens_decoder.get(&token) {
169 return Ok(PyBytes::new_bound(py, bytes).into());
170 }
171 Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
172 }
173
174 // ====================
175 // Miscellaneous
176 // ====================
177
178 fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
179 self.sorted_token_bytes
180 .iter()
181 .map(|x| PyBytes::new_bound(py, x).into())
182 .collect()
183 }
184}
185
186#[pyclass]
187struct TiktokenBuffer {
188 tokens: Vec<Rank>,
189}
190
191#[pymethods]
192impl TiktokenBuffer {
193 // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
194 unsafe fn __getbuffer__(
195 slf: Bound<'_, Self>,
196 view: *mut pyo3::ffi::Py_buffer,
197 flags: std::os::raw::c_int,
198 ) -> PyResult<()> {
199 if view.is_null() {
200 return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
201 }
202 if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
203 return Err(pyo3::exceptions::PyBufferError::new_err(
204 "Object is not writable",
205 ));
206 }
207
208 (*view).obj = slf.clone().into_any().into_ptr();
209
210 let data = &slf.borrow().tokens;
211 (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
212 (*view).len = (data.len() * std::mem::size_of::<Rank>()) as isize;
213 (*view).readonly = 1;
214 (*view).itemsize = std::mem::size_of::<Rank>() as isize;
215 (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
216 let msg = std::ffi::CString::new("I").unwrap();
217 msg.into_raw()
218 } else {
219 std::ptr::null_mut()
220 };
221 (*view).ndim = 1;
222 (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
223 &mut (*view).len
224 } else {
225 std::ptr::null_mut()
226 };
227 (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES {
228 &mut (*view).itemsize
229 } else {
230 std::ptr::null_mut()
231 };
232 (*view).suboffsets = std::ptr::null_mut();
233 (*view).internal = std::ptr::null_mut();
234
235 Ok(())
236 }
237
238 unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
239 std::mem::drop(std::ffi::CString::from_raw((*view).format));
240 }
241}
242
243#[pymodule]
244fn _tiktoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
245 m.add_class::<CoreBPE>()?;
246 Ok(())
247}
248