microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
bfbfa5a3044ec8d1312f3782c78ea3b9246bf667

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_bert_tokenizer_decoder.py

65lines · modecode

1from pathlib import Path
2import unittest
3import numpy as np
4import transformers
5from onnxruntime_extensions import PyOrtFunction, BertTokenizerDecoder
6
# Reference Hugging Face tokenizers, created once at import time: the cased
# checkpoint first, then the uncased one.
bert_cased_tokenizer, bert_uncased_tokenizer = (
    transformers.BertTokenizer.from_pretrained(checkpoint)
    for checkpoint in ('bert-base-cased', 'bert-base-uncased')
)
9
10
def _get_test_data_file(*sub_dirs):
    """Return, as a string, the path formed by joining ``sub_dirs`` onto this module's directory."""
    return str(Path(__file__).parent.joinpath(*sub_dirs))
14
15
def _run_basic_case(input, vocab_path):
    """Check that the BertTokenizerDecoder op matches the reference tokenizer's decode.

    ``input`` is encoded with ``bert_cased_tokenizer``; the resulting ids are
    passed to the custom op together with an empty position array, and the
    first element of the op's output is compared with the reference
    tokenizer's decode of the same ids.

    Args:
        input: Text to tokenize. The name shadows the builtin but is kept
            because callers pass it by keyword.
        vocab_path: Vocabulary file handed to the op as ``vocab_file``.

    Raises:
        AssertionError: If the op's output differs from the reference text.
    """
    decoder = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path)
    # Encode once and reuse the ids for both the op input and the expected text.
    token_ids = bert_cased_tokenizer.encode(input)
    ids = np.array(token_ids, dtype=np.int64)
    # Empty (1, 0) int64 array supplied as the op's second input.
    position = np.array([[]], dtype=np.int64)

    decoded = decoder(ids, position)
    np.testing.assert_array_equal(decoded[0], bert_cased_tokenizer.decode(token_ids))
24
25
def _run_indices_case(input, indices, vocab_path):
    """Check the BertTokenizerDecoder op created with ``use_indices=1``.

    ``input`` is encoded with ``bert_cased_tokenizer``; the ids and
    ``indices`` (converted to an int64 array) are passed to the custom op.
    The expected output holds one string per non-empty ``[start, end]``
    entry: the reference decode of ``ids[start:end]``, split on single
    spaces, with the first piece dropped when it starts with ``##``.

    Args:
        input: Text to tokenize. The name shadows the builtin but is kept
            because callers pass it by keyword.
        indices: List of ``[start, end]`` pairs; empty entries are skipped
            when building the expected output.
        vocab_path: Vocabulary file handed to the op as ``vocab_file``.

    Raises:
        AssertionError: If the op's output differs from the expected strings.
    """
    decoder = PyOrtFunction.from_customop(BertTokenizerDecoder, vocab_file=vocab_path, use_indices=1)
    # The encoding does not depend on the index pair, so compute it once.
    token_ids = bert_cased_tokenizer.encode(input)
    ids = np.array(token_ids, dtype=np.int64)
    position = np.array(indices, dtype=np.int64)

    expect_result = []
    for index in indices:
        if len(index) == 0:
            continue
        pieces = bert_cased_tokenizer.decode(token_ids[index[0]:index[1]]).split(' ')
        # Drop a leading word-piece continuation left by slicing mid-word.
        if pieces[0].startswith('##'):
            pieces.pop(0)
        expect_result.append(" ".join(pieces))

    actual = decoder(ids, position)
    # err_msg must be a string; it is passed by keyword so it cannot be
    # confused with verbose. verbose=False matches the original call.
    np.testing.assert_array_equal(
        actual, expect_result,
        err_msg="decoded spans differ for input=%r, indices=%r" % (input, indices),
        verbose=False)
42
43
class TestBertTokenizerDecoder(unittest.TestCase):
    """Runs the module's decode checks against the cased BERT vocabulary file."""

    def test_text_to_case1(self):
        # Path components of the vocabulary file, relative to this module's directory.
        vocab_parts = ('data', 'bert_basic_cased_vocab.txt')

        # Whole-sequence decoding.
        basic_inputs = (
            "Input 'text' must not be empty.",
            "网易云音乐",
            "网 易 云 音 乐",
            "cat is playing toys",
            "cat isnot playing toyssss",
        )
        for text in basic_inputs:
            _run_basic_case(input=text, vocab_path=_get_test_data_file(*vocab_parts))

        # Index-based decoding: first with an empty index entry, then with two spans.
        for spans in ([[]], [[1, 2], [3, 5]]):
            _run_indices_case(input="cat isnot playing toyssss", indices=spans,
                              vocab_path=_get_test_data_file(*vocab_parts))
62
63
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
66