microsoft/onnxruntime-extensions
Public, mirrored from https://github.com/microsoft/onnxruntime-extensions. Available
test/test_bert_tokenizer_decoder.py
65 lines · mode: code
| 1 | from pathlib import Path |
| 2 | import unittest |
| 3 | import numpy as np |
| 4 | import transformers |
| 5 | from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder |
| 6 | |
# Reference tokenizers from Hugging Face `transformers`. The helpers below use
# bert_cased_tokenizer to produce the token ids fed to the custom op and the
# expected decoded strings. bert_uncased_tokenizer is not used in this file's
# visible tests. NOTE(review): from_pretrained presumably downloads or reads a
# cached model on first use — tests need network access or a warm cache.
bert_cased_tokenizer, bert_uncased_tokenizer = (
    transformers.BertTokenizer.from_pretrained(model_name)
    for model_name in ('bert-base-cased', 'bert-base-uncased')
)
| 9 | |
| 10 | |
def _get_test_data_file(*sub_dirs):
    """Return, as a string, the path of *sub_dirs* joined under this file's directory.

    Args:
        *sub_dirs: Path components appended in order (e.g. ``'data', 'vocab.txt'``).
            With no components, the directory containing this file is returned.

    Returns:
        str: The joined path. It is relative or absolute depending on how
        ``__file__`` was set.
    """
    target = Path(__file__).parent
    for part in sub_dirs:
        target = target / part
    return str(target)
| 14 | |
| 15 | |
def _run_basic_case(input, vocab_path):
    """Decode the token ids of *input* with the custom op and compare to Hugging Face.

    The text is encoded with the module-level ``bert_cased_tokenizer``. The ids
    are fed to a ``BertTokenizerDecoder`` built from *vocab_path*, together with
    an empty position tensor. The op's first output must equal
    ``bert_cased_tokenizer.decode`` of the same ids.

    Args:
        input: Text to round-trip. The name shadows the builtin but is kept
            because callers pass it by keyword.
        vocab_path: Path to the vocabulary file given to the custom op.

    Raises:
        AssertionError: If the decoded strings differ.
    """
    t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
    # Encode once: the same ids drive both the op and the reference decode.
    token_ids = bert_cased_tokenizer.encode(input)
    ids = np.array(token_ids, dtype=np.int64)
    # Shape (1, 0) position tensor. NOTE(review): presumably means "decode the
    # whole sequence" when use_indices is off — confirm against the op.
    position = np.array([[]], dtype=np.int64)

    result = t2stc(ids, position)
    np.testing.assert_array_equal(result[0], bert_cased_tokenizer.decode(token_ids))
| 24 | |
| 25 | |
def _run_indices_case(input, indices, vocab_path):
    """Check span decoding (``use_indices=1``) of the custom op against Hugging Face.

    For each ``[start, end]`` pair in *indices*, the expected string is
    ``bert_cased_tokenizer.decode`` of ``ids[start:end]``. A leading word that
    starts with ``##`` is dropped from it. Empty entries (e.g. ``[[]]``)
    contribute no expected string.

    Args:
        input: Text to encode with ``bert_cased_tokenizer``. The name shadows
            the builtin but is kept because callers pass it by keyword.
        indices: List of ``[start, end]`` token spans, or ``[[]]`` for none.
        vocab_path: Vocabulary file passed to the custom op.

    Raises:
        AssertionError: If the op output differs from the expected span strings.
    """
    t2stc = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path, use_indices=1)
    # Encode once; every span below slices this same token sequence.
    token_ids = bert_cased_tokenizer.encode(input)
    ids = np.array(token_ids, dtype=np.int64)
    position = np.array(indices, dtype=np.int64)

    expect_result = []
    for index in indices:
        if len(index) > 0:
            words = bert_cased_tokenizer.decode(token_ids[index[0]:index[1]]).split(' ')
            # A span that starts mid-word decodes to a leading "##" piece. The
            # reference drops it, presumably mirroring what the op emits.
            if words[0].startswith('##'):
                words.pop(0)
            expect_result.append(" ".join(words))

    result = t2stc(ids, position)
    # Pass err_msg by keyword: extra positional arguments to
    # assert_array_equal bind to err_msg/verbose, not to any option flag.
    np.testing.assert_array_equal(
        result, expect_result,
        err_msg=f"input={input!r} indices={indices!r}")
| 42 | |
| 43 | |
class TestBertTokenizerDecoder(unittest.TestCase):
    """Round-trip tests for the ``BertTokenizerDecoder`` custom op."""

    def test_text_to_case1(self):
        """Decode whole sequences first, then decode selected token spans."""
        vocab = _get_test_data_file('data', 'bert_basic_cased_vocab.txt')

        # Full-sequence decoding: ASCII text with quotes and punctuation, CJK
        # text with and without spaces, and words that presumably split into
        # word pieces ("isnot", "toyssss").
        for text in ("Input 'text' must not be empty.",
                     "网易云音乐",
                     "网 易 云 音 乐",
                     "cat is playing toys",
                     "cat isnot playing toyssss"):
            _run_basic_case(input=text, vocab_path=vocab)

        # Span decoding: first with no spans, then with two token ranges.
        sentence = "cat isnot playing toyssss"
        for spans in ([[]], [[1, 2], [3, 5]]):
            _run_indices_case(input=sentence, indices=spans, vocab_path=vocab)
| 62 | |
| 63 | |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
| 66 | |