openai/tiktoken

Public

mirrored from https://github.com/openai/tiktokenAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
subquad2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

src/lib.rs

675lines · modecode

1use std::collections::HashSet;
2use std::num::NonZeroU64;
3use std::thread;
4
5use fancy_regex::Regex;
6#[cfg(feature = "python")]
7use pyo3::prelude::*;
8use rustc_hash::FxHashMap as HashMap;
9
10#[cfg(feature = "python")]
11mod py;
12
13pub type Rank = u32;
14
15use std::collections::BinaryHeap;
16
17#[derive(Eq, PartialEq, Clone, Copy)]
18struct Merge {
19 start: usize,
20 rank: Rank,
21}
22
23impl Ord for Merge {
24 #[inline]
25 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
26 other
27 .rank
28 .cmp(&self.rank)
29 .then_with(|| other.start.cmp(&self.start))
30 }
31}
32
33impl PartialOrd for Merge {
34 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
35 Some(self.cmp(other))
36 }
37}
38
39struct State {
40 prev: usize,
41 end: usize,
42 next_end: usize,
43 next_rank: Rank,
44 cur_rank: Rank,
45}
46
47fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
48 let mut state = Vec::with_capacity(piece.len());
49 state.push(State {
50 prev: usize::MAX,
51 end: 1,
52 next_end: 2,
53 next_rank: Rank::MAX,
54 cur_rank: Rank::MAX,
55 });
56
57 let mut heap = BinaryHeap::with_capacity(piece.len());
58 for i in 0..piece.len() - 1 {
59 if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
60 heap.push(Merge { start: i, rank });
61 state[i].next_rank = rank;
62 }
63 // note this is happening offset by 1
64 state.push(State {
65 prev: i,
66 end: i + 2,
67 next_end: i + 3,
68 next_rank: Rank::MAX,
69 cur_rank: Rank::MAX,
70 });
71 }
72
73 // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
74 // starts at `start` and ends at `state[start].end` with the (right) token that starts at
75 // `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
76 // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
77 // new potential merges to the heap.
78
79 let potential_merge = {
80 #[inline(always)]
81 |state: &mut Vec<State>,
82 heap: &mut BinaryHeap<Merge>,
83 start: usize,
84 next_end_item: usize| {
85 state[start].next_end = next_end_item;
86 state[start].next_rank = Rank::MAX; // Always invalidate the old merge
87 if next_end_item <= piece.len() {
88 if let Some(&rank) = ranks.get(&piece[start..next_end_item]) {
89 // We have a valid potential merge!
90 heap.push(Merge { start, rank });
91 state[start].next_rank = rank;
92 }
93 }
94 }
95 };
96
97 while let Some(left) = heap.pop() {
98 if left.rank == Rank::MAX {
99 break;
100 }
101 if left.rank != state[left.start].next_rank {
102 continue; // This merge was invalidated, ignore it
103 }
104
105 let left_start = left.start;
106 let right_start = state[left_start].end;
107 let right_end = state[left_start].next_end;
108 debug_assert!(right_end == state[right_start].end);
109 let right_next_end = state[right_start].next_end;
110
111 // Merge left and right into a single token
112 state[left_start].cur_rank = state[left_start].next_rank;
113 state[left_start].end = right_end;
114 potential_merge(&mut state, &mut heap, left_start, right_next_end);
115 if right_end < state.len() {
116 state[right_end].prev = left_start;
117 }
118 // Update the merge that ends at left_start
119 if left_start > 0 {
120 let prev_start = state[left_start].prev;
121 potential_merge(&mut state, &mut heap, prev_start, right_end);
122 }
123 // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
124 state[right_start].next_rank = Rank::MAX;
125 }
126
127 let mut result = Vec::new();
128 let mut i = 0;
129 while i < state.len() {
130 if state[i].cur_rank != Rank::MAX {
131 result.push(state[i].cur_rank);
132 } else {
133 result.push(ranks[&piece[i..state[i].end]]);
134 }
135 i = state[i].end;
136 }
137 result
138}
139
140fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
141 // This is a vector of (start, rank).
142 // The rank is of the pair starting at position start.
143 let mut parts = Vec::with_capacity(piece.len() + 1);
144
145 // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
146 // the way we currently do, this is equivalent. An easy way to break this would be to decouple
147 // merge priority from token index or to prevent specific token merges.
148 let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
149 for i in 0..piece.len() - 1 {
150 let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
151 if rank < min_rank.0 {
152 min_rank = (rank, i);
153 }
154 parts.push((i, rank));
155 }
156 parts.push((piece.len() - 1, Rank::MAX));
157 parts.push((piece.len(), Rank::MAX));
158
159 let get_rank = {
160 #[inline(always)]
161 |parts: &Vec<(usize, Rank)>, i: usize| {
162 if (i + 3) < parts.len() {
163 // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
164 // parts[i + 1], see comment in the main loop.
165 *ranks
166 .get(&piece[parts[i].0..parts[i + 3].0])
167 .unwrap_or(&Rank::MAX)
168 } else {
169 Rank::MAX
170 }
171 }
172 };
173
174 // If you have n parts and m merges, this does O(mn) work.
175 // We could do something with a heap and do O(m log n) work.
176 // n is often very small so considerations like cache-locality outweigh the algorithmic
177 // complexity downsides of the `parts` vector.
178 while min_rank.0 != Rank::MAX {
179 let i = min_rank.1;
180 // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
181 // `parts.remove(i + 1)` will thrash the cache.
182 if i > 0 {
183 parts[i - 1].1 = get_rank(&parts, i - 1);
184 }
185 parts[i].1 = get_rank(&parts, i);
186 parts.remove(i + 1);
187
188 min_rank = (Rank::MAX, usize::MAX);
189 for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
190 if rank < min_rank.0 {
191 min_rank = (rank, i);
192 }
193 }
194 }
195 parts
196}
197
198pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
199 let piece_len = piece.len();
200
201 if piece_len == 1 {
202 return vec![ranks[piece]];
203 }
204 if piece_len < 100 {
205 return _byte_pair_merge(ranks, piece)
206 .windows(2)
207 .map(|part| ranks[&piece[part[0].0..part[1].0]])
208 .collect();
209 }
210 return _byte_pair_merge_large(ranks, piece);
211}
212
213pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
214 assert!(piece.len() > 1);
215 return _byte_pair_merge(ranks, piece)
216 .windows(2)
217 .map(|part| &piece[part[0].0..part[1].0])
218 .collect();
219 // TODO: _byte_pair_merge_large
220}
221
222// Various performance notes:
223//
224// Regex
225// =====
226// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
227// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
228// the usual regex we use.
229//
230// However, given that we're using a regex parse-able by `regex`, there isn't much difference
231// between using the `regex` crate and using the `fancy_regex` crate.
232//
233// There is an important interaction between threading, `regex` and `fancy_regex`.
234// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
235// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
236// old `regex`, we don't hit this, because `find_iter` has a different code path.
237// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
238// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
239// each thread.
240//
241// Threading
242// =========
243// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
244// So goodbye `rayon`! Let thread count etc be in control of our Python users.
245//
// Caching
// =======
// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
// Originally, we had one too! Without it, we were only vaguely faster than Python.
// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
// about interior mutability, memory use, or cloning.
256//
257// Hashing
258// =======
259// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
260// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
261// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
262
/// Mirror of the layout this file assumes for `std::thread::ThreadId`: a single non-zero `u64`.
struct FakeThreadId(NonZeroU64);

/// Returns the current thread's numeric id as a `usize`, for use as a slot hash.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has a nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array, but it is private.
    // There are only so many ways you can lay out a u64, so just transmute.
    // https://github.com/rust-lang/rust/issues/67939
    //
    // Compile-time checks: both types must be exactly eight bytes or these fail to build.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];

    let current = thread::current().id();
    // SAFETY: relies on `ThreadId` being a thin wrapper around a non-zero `u64`. That is an
    // implementation detail of std rather than a documented guarantee; the size checks above
    // catch a change in size but not a change in representation.
    let FakeThreadId(raw) =
        unsafe { std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(current) };
    raw.get() as usize
}
277
278#[derive(Debug, Clone)]
279pub struct DecodeKeyError {
280 pub token: Rank,
281}
282
283impl std::fmt::Display for DecodeKeyError {
284 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
285 write!(f, "Invalid token for decoding: {}", self.token)
286 }
287}
288
289impl std::error::Error for DecodeKeyError {}
290
/// Error carrying a free-form description of why tokens could not be decoded.
#[derive(Debug, Clone)]
pub struct DecodeError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "Could not decode tokens: {}", message)
    }
}

impl std::error::Error for DecodeError {}
303
/// Number of regex clones kept per `CoreBPE`; a thread picks its clone by taking its thread
/// hash modulo this value.
const MAX_NUM_THREADS: usize = 128;
305
306#[cfg_attr(feature = "python", pyclass)]
307#[derive(Clone)]
308pub struct CoreBPE {
309 encoder: HashMap<Vec<u8>, Rank>,
310 special_tokens_encoder: HashMap<String, Rank>,
311 decoder: HashMap<Rank, Vec<u8>>,
312 special_tokens_decoder: HashMap<Rank, Vec<u8>>,
313 regex_tls: Vec<Regex>,
314 special_regex_tls: Vec<Regex>,
315 sorted_token_bytes: Vec<Vec<u8>>,
316}
317
318impl CoreBPE {
319 fn _get_tl_regex(&self) -> &Regex {
320 // See performance notes above for what this is about
321 // It's also a little janky, please make a better version of it!
322 // However, it's nice that this doesn't leak memory to short-lived threads
323 &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
324 }
325
326 fn _get_tl_special_regex(&self) -> &Regex {
327 &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
328 }
329
330 /// Decodes tokens into a list of bytes.
331 ///
332 /// The bytes are not gauranteed to be a valid utf-8 string.
333 fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
334 let mut ret = Vec::with_capacity(tokens.len() * 2);
335 for &token in tokens {
336 let token_bytes = match self.decoder.get(&token) {
337 Some(bytes) => bytes,
338 None => self
339 .special_tokens_decoder
340 .get(&token)
341 .ok_or(DecodeKeyError { token })?,
342 };
343 ret.extend(token_bytes);
344 }
345 Ok(ret)
346 }
347
348 pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
349 // This is the core of the encoding logic; the other functions in here
350 // just make things complicated :-)
351 let regex = self._get_tl_regex();
352 let mut ret = vec![];
353 for mat in regex.find_iter(text) {
354 let piece = mat.unwrap().as_str().as_bytes();
355 match self.encoder.get(piece) {
356 Some(token) => ret.push(*token),
357 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
358 }
359 }
360 ret
361 }
362
363 pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
364 let special_regex = self._get_tl_special_regex();
365 let regex = self._get_tl_regex();
366 let mut ret = vec![];
367
368 let mut start = 0;
369 let mut last_piece_token_len = 0;
370 loop {
371 let mut next_special;
372 let mut start_find = start;
373 loop {
374 // Find the next allowed special token, if any
375 next_special = special_regex.find_from_pos(text, start_find).unwrap();
376 match next_special {
377 Some(m) => {
378 if allowed_special.contains(&text[m.start()..m.end()]) {
379 break;
380 }
381 start_find = m.start() + 1;
382 }
383 None => break,
384 }
385 }
386 let end = next_special.map_or(text.len(), |m| m.start());
387
388 // Okay, here we go, compare this logic to encode_ordinary
389 for mat in regex.find_iter(&text[start..end]) {
390 let piece = mat.unwrap().as_str().as_bytes();
391 if let Some(token) = self.encoder.get(piece) {
392 last_piece_token_len = 1;
393 ret.push(*token);
394 continue;
395 }
396 let tokens = byte_pair_encode(piece, &self.encoder);
397 last_piece_token_len = tokens.len();
398 ret.extend(&tokens);
399 }
400
401 match next_special {
402 // And here we push the special token
403 Some(m) => {
404 let piece = m.as_str();
405 let token = self.special_tokens_encoder[piece];
406 ret.push(token);
407 start = m.end();
408 last_piece_token_len = 0;
409 }
410 None => break,
411 }
412 }
413
414 // last_piece_token_len is how many tokens came from the last regex split. This is used
415 // for determining unstable tokens, since you can't merge across (stable) regex splits
416 (ret, last_piece_token_len)
417 }
418
419 fn _increase_last_piece_token_len(
420 &self,
421 tokens: Vec<Rank>,
422 mut last_piece_token_len: usize,
423 ) -> (Vec<Rank>, usize) {
424 // Unfortunately, the locations where our regex splits can be unstable.
425 // For the purposes of determining unstable tokens, unstable regex splitting
426 // is only a problem if a split that was present disappears, since this can
427 // lead to merging of tokens otherwise thought to be stable.
428 // cl100k_base makes our life hard by including the \s*[\r\n]+
429 // pattern. This can e.g. cause "\n" + " " to become "\n \n".
430 // Here is a quick and dirty fix:
431 {
432 let token_is_all_space = |token| {
433 self.decoder
434 .get(token)
435 .map(|token_bytes| {
436 token_bytes
437 .iter()
438 .rev()
439 .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
440 })
441 .unwrap_or(false)
442 };
443 if last_piece_token_len > 0
444 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
445 {
446 while (last_piece_token_len < tokens.len())
447 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
448 {
449 last_piece_token_len += 1;
450 }
451 }
452 }
453 debug_assert!(last_piece_token_len <= tokens.len());
454
455 (tokens, last_piece_token_len)
456 }
457
458 pub fn _encode_unstable_native(
459 &self,
460 text: &str,
461 allowed_special: &HashSet<&str>,
462 ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
463 let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
464 if last_piece_token_len == 0 {
465 // If last_piece_token_len is zero, the last token was a special token and we have
466 // no unstable bytes
467 return (tokens, HashSet::new());
468 }
469 let (mut tokens, last_piece_token_len) =
470 self._increase_last_piece_token_len(tokens, last_piece_token_len);
471
472 let unstable_bytes = self
473 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
474 .unwrap();
475 tokens.truncate(tokens.len() - last_piece_token_len);
476
477 // TODO: we should try harder to find additional stable tokens
478 // This would reduce the amount of retokenising when determining completions
479 // Refer to the logic in an older version of this file
480
481 let mut completions = HashSet::new();
482 if unstable_bytes.is_empty() {
483 return (tokens, completions);
484 }
485
486 // This is the easy bit. Just find all single tokens that start with unstable_bytes
487 // (including tokens that exactly match unstable_bytes)
488 // Separating this from the loop below helps with performance in a common case.
489 let mut point = self
490 .sorted_token_bytes
491 .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
492 while point < self.sorted_token_bytes.len()
493 && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
494 {
495 completions.insert(vec![
496 self.encoder[self.sorted_token_bytes[point].as_slice()],
497 ]);
498 point += 1;
499 }
500
501 // Now apply even more brute force. At every (other) possible position for the straddling
502 // token, concatenate additional bytes from that token (if any) to unstable_bytes,
503 // and retokenise the whole thing and see what we get.
504 for i in 1..unstable_bytes.len() {
505 let prefix = &unstable_bytes[..i];
506 let suffix = &unstable_bytes[i..];
507 let mut point = self
508 .sorted_token_bytes
509 .partition_point(|x| x.as_slice() < suffix);
510 // TODO: Perf optimisation if suffix starts with " "?
511 while point < self.sorted_token_bytes.len()
512 && self.sorted_token_bytes[point].starts_with(suffix)
513 {
514 let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
515 let encoded = match std::str::from_utf8(&possibility) {
516 // Morally, this is byte_pair_encode(&possibility, &self.encoder)
517 // But we might have introduced a regex split which would prevent merges.
518 // (particularly possible in the presence of unstable regex splits)
519 // So convert to UTF-8 and do regex splitting.
520 // E.g. with cl100k_base " !" gets split to " " + " !",
521 // but byte_pair_encode(" !") != byte_pair_encode(" ")
522 Ok(s) => self.encode_ordinary(s),
523
524 // Technically, whether or not this arm is correct depends on whether there
525 // would be a regex split before the UTF-8 truncation point.
526 // Probably niche enough that no one will ever notice (after all, people didn't
527 // notice all the big holes in the previous unstable token implementation)
528 Err(_) => byte_pair_encode(&possibility, &self.encoder),
529 // Something like the following is intriguing but incorrect:
530 // Err(e) => self.encode_ordinary(unsafe {
531 // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
532 // }),
533 };
534 let mut seq = Vec::new();
535 let mut seq_len = 0;
536 for token in encoded {
537 seq.push(token);
538 seq_len += self.decoder[&token].len();
539 if seq_len >= unstable_bytes.len() {
540 break;
541 }
542 }
543 completions.insert(seq);
544 point += 1;
545 }
546 }
547
548 // This is also not straightforward. While we generally assume that regex splits are stable,
549 // unfortunately, they are not. That is, if adding bytes were to make a split appear in
550 // unstable_bytes, this could make tokens possible which our logic would otherwise think
551 // would be merged.
552 // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
553 // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
554 // Here is a quick and dirty fix:
555 // This isn't right if we ever remove \s+(?!\S)
556 if unstable_bytes.len() > 1 {
557 let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
558 if unstable_bytes.len() - last_decoded.1 > 0
559 && last_decoded.0.map_or(false, |c| c.is_whitespace())
560 {
561 let mut reencoded = byte_pair_encode(
562 &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
563 &self.encoder,
564 );
565 reencoded.extend(byte_pair_encode(
566 &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
567 &self.encoder,
568 ));
569 completions.insert(reencoded);
570 }
571 }
572
573 (tokens, completions)
574 }
575
576 pub fn new<E, SE, NSE>(
577 encoder: E,
578 special_tokens_encoder: SE,
579 pattern: &str,
580 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
581 where
582 E: IntoIterator<Item = (Vec<u8>, Rank)>,
583 SE: IntoIterator<Item = (String, Rank)>,
584 NSE: IntoIterator<Item = (String, (Rank, Rank))>,
585 {
586 Self::new_internal(
587 HashMap::from_iter(encoder),
588 HashMap::from_iter(special_tokens_encoder),
589 pattern,
590 )
591 }
592
593 fn new_internal(
594 encoder: HashMap<Vec<u8>, Rank>,
595 special_tokens_encoder: HashMap<String, Rank>,
596 pattern: &str,
597 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
598 let regex = Regex::new(pattern)?;
599
600 let special_regex = {
601 let parts = special_tokens_encoder
602 .keys()
603 .map(|s| fancy_regex::escape(s))
604 .collect::<Vec<_>>();
605 Regex::new(&parts.join("|"))?
606 };
607
608 let decoder: HashMap<Rank, Vec<u8>> =
609 encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
610
611 assert!(
612 encoder.len() == decoder.len(),
613 "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
614 );
615
616 let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
617 .iter()
618 .map(|(k, v)| (*v, k.as_bytes().to_vec()))
619 .collect();
620
621 // Clone because I don't know how to tell Rust I'm not going to change the map
622 let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
623 sorted_token_bytes.sort();
624
625 Ok(Self {
626 encoder,
627 special_tokens_encoder,
628 decoder,
629 special_tokens_decoder,
630 regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
631 special_regex_tls: (0..MAX_NUM_THREADS)
632 .map(|_| special_regex.clone())
633 .collect(),
634 sorted_token_bytes,
635 })
636 }
637
638 pub fn special_tokens(&self) -> HashSet<&str> {
639 self.special_tokens_encoder
640 .keys()
641 .map(|s| s.as_str())
642 .collect()
643 }
644
645 pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
646 let allowed_special = self.special_tokens();
647 self.encode(text, &allowed_special).0
648 }
649}
650
#[cfg(test)]
mod tests {

    use rustc_hash::FxHashMap as HashMap;

    use crate::{byte_pair_split, Rank};

    /// Merge table with just two pairs: "ab" (rank 0) and "cd" (rank 1).
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut table = HashMap::default();
        table.insert(b"ab".to_vec(), 0);
        table.insert(b"cd".to_vec(), 1);
        table
    }

    /// Two distinct known pairs are each merged into one part.
    #[test]
    fn test_simple_characters() {
        let table = setup_ranks();
        let parts = byte_pair_split(b"abcd", &table);
        assert_eq!(parts, vec![b"ab", b"cd"]);
    }

    /// The same known pair appearing twice is merged at both positions.
    #[test]
    fn test_repeated_characters() {
        let table = setup_ranks();
        let parts = byte_pair_split(b"abab", &table);
        assert_eq!(parts, vec![b"ab", b"ab"]);
    }
}
676