openai/tiktoken
Public · mirrored from https://github.com/openai/tiktoken · Available
tests/test_offsets.py
79lines · modecode
| 1 | from typing import Callable |
| 2 | |
| 3 | import hypothesis |
| 4 | import pytest |
| 5 | from hypothesis import strategies as st |
| 6 | |
| 7 | import tiktoken |
| 8 | |
| 9 | from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES |
| 10 | |
| 11 | |
| 12 | def _common_prefix_len(a, b): |
| 13 | i = 0 |
| 14 | while i < len(a) and i < len(b) and a[i] == b[i]: |
| 15 | i += 1 |
| 16 | return i |
| 17 | |
| 18 | |
def _token_offsets_reference(enc, tokens):
    """Compute one offset per token the slow, straightforward way.

    The whole sequence is decoded once with ``errors="strict"``. Then, for
    every index ``i``, the prefix ``tokens[:i]`` is decoded with
    ``errors="ignore"`` and the offset of token ``i`` is taken to be the
    number of leading characters that this prefix text shares with the full
    text.

    Returns a list of ``len(tokens)`` integers.
    """
    # NOTE(review): with errors="strict" this presumably raises when the full
    # token sequence does not decode cleanly -- confirm against Encoding.decode.
    full_text = enc.decode(tokens, errors="strict")
    return [
        _common_prefix_len(full_text, enc.decode(tokens[:end], errors="ignore"))
        for end in range(len(tokens))
    ]
| 26 | |
| 27 | |
@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
@hypothesis.given(data=st.data())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
    """Property test: ``decode_with_offsets`` matches the reference offsets.

    Draws a list of 1 to 20 token ids that exist in the encoding, round-trips
    them through decode/encode, and compares the offsets returned by
    ``decode_with_offsets`` with ``_token_offsets_reference``.
    """
    enc = make_enc()

    # Gather every id that is actually assigned to a token into one set, once.
    # Testing ``x in dict.values()`` inside the filter would linearly scan both
    # tables for every candidate integer that hypothesis generates.
    valid_token_ids = set(enc._special_tokens.values())
    valid_token_ids.update(enc._mergeable_ranks.values())

    tokens_st = st.lists(
        st.integers(0, enc.n_vocab - 1).filter(lambda x: x in valid_token_ids),
        min_size=1,
        max_size=20,
    )
    tokens = data.draw(tokens_st)

    # This is a dumb hack to make sure that our tokens are a valid UTF-8 string
    # We could potentially drop this, see the TODO in decode_with_offsets
    tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
    assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)
| 47 | |
| 48 | |
def test_basic_offsets():
    """Check ``decode_with_offsets`` on fixed prompts with hand-written offsets.

    Every case asserts that the decoded text equals the prompt and that the
    returned offsets equal the expected list, using the ``cl100k_base``
    encoding.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    # Each entry: (prompt, extra keyword arguments for ``encode``, expected offsets).
    cases = [
        ("hello world", {}, [0, 5]),
        (
            "hello world<|endoftext|> green cow",
            {"allowed_special": "all"},
            [0, 5, 11, 24, 30],
        ),
        (
            "我非常渴望与人工智能一起工作",
            {},
            [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13],
        ),
        # contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
        # in which \xe0 is the start of a 3-byte UTF-8 character
        (
            "நடிகர் சூர்யா",
            {},
            [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12],
        ),
        # contains the interesting token b'\xa0\xe9\x99\xa4'
        # in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
        (" Ġ除", {}, [0, 1]),
    ]

    for prompt, encode_kwargs, expected_offsets in cases:
        decoded, offsets = enc.decode_with_offsets(enc.encode(prompt, **encode_kwargs))
        assert decoded == prompt
        assert offsets == expected_offsets
| 80 | |