microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
92f6b51106c9e9143c452e537cb5e41d2dcaa266

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_gpt2tok.py

113 lines · mode code

1import unittest
2import numpy as np
3import onnxruntime as _ort
4
5from pathlib import Path
6from onnx import helper, onnx_pb as onnx_proto
7from transformers import GPT2Tokenizer
8from onnxruntime_customops import (
9 onnx_op,
10 enable_custom_op,
11 PyCustomOpDef,
12 expand_onnx_inputs,
13 get_library_path as _get_library_path)
14
15
16def _get_file_content(path):
17 with open(path, "rb") as file:
18 return file.read()
19
20
21def _get_test_data_file(*sub_dirs):
22 test_dir = Path(__file__).parent
23 return str(test_dir.joinpath(*sub_dirs))
24
25
def _create_test_model():
    """Create a minimal ONNX model that passes its input straight through.

    The graph (named ``'test0'``) holds a single ``Identity`` node mapping
    ``input_1`` to ``output_1``. Both tensors are declared as INT64 with two
    unspecified dimensions, and the model imports opset 12 of the default
    (empty-string) domain.

    Returns:
        The model object produced by ``helper.make_model``.
    """
    nodes = [helper.make_node(
        'Identity', ['input_1'], ['output_1'])]

    input1 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.INT64, [None, None])
    output1 = helper.make_tensor_value_info(
        'output_1', onnx_proto.TensorProto.INT64, [None, None])

    graph = helper.make_graph(nodes, 'test0', [input1], [output1])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid('', 12)])
38
39
def _bind_tokenizer(model, **kwargs):
    """Put a ``GPT2Tokenizer`` custom-op node in front of ``model``'s input.

    A node of type ``GPT2Tokenizer`` (domain ``ai.onnx.contrib``, name
    ``'bpetok'``) is created that reads ``string_input`` and writes
    ``input_1``. The raw bytes of the vocabulary and merges files are stored
    on the node as its ``vocab`` and ``merges`` attributes. The node and a
    one-dimensional STRING value info for ``string_input`` are handed to
    ``expand_onnx_inputs``.

    Args:
        model: The ONNX model whose ``input_1`` is to be fed by the tokenizer.
        **kwargs: Must contain ``vocab_file`` and ``merges_file``, the paths
            of the tokenizer's vocabulary and merges files.

    Returns:
        Whatever ``expand_onnx_inputs`` returns -- presumably the model with
        ``string_input`` as its new graph input (confirm against that
        helper's documentation).

    Raises:
        KeyError: If ``vocab_file`` or ``merges_file`` is missing from
            ``kwargs``.
    """
    vocab_file = kwargs["vocab_file"]
    merges_file = kwargs["merges_file"]

    tokenizer_node = helper.make_node(
        'GPT2Tokenizer', ['string_input'], ['input_1'],
        vocab=_get_file_content(vocab_file),
        merges=_get_file_content(merges_file),
        name='bpetok',
        domain='ai.onnx.contrib')
    string_input = helper.make_tensor_value_info(
        'string_input', onnx_proto.TensorProto.STRING, [None])

    return expand_onnx_inputs(
        model, 'input_1',
        [tokenizer_node],
        [string_input],
    )
52
53
class TestGPT2Tokenizer(unittest.TestCase):
    """Check the ``GPT2Tokenizer`` custom op against ``transformers``.

    Every sentence is tokenized twice: once by an ONNX Runtime session that
    runs the model produced by ``_bind_tokenizer`` and once directly by the
    ``GPT2Tokenizer`` instance built in ``setUpClass``. The two id sequences
    must match exactly.
    """

    # Set to True by the Python custom op defined in ``setUpClass`` whenever
    # it is called; the tests use it to tell whether that op was executed.
    pyop_invoked = False

    # First sentence of each test; the ``pyop_invoked`` flag is checked right
    # after it has been tokenized.
    _FIRST_SENTENCE = "I can feel the magic, can you?"

    # Remaining sentences, shared by both tests so that both code paths are
    # always exercised with identical inputs (ASCII, CJK, punctuation,
    # Latin Extended, Arabic-script and symbol characters).
    _MORE_SENTENCES = (
        "Hey Cortana",
        "你好123。david",
        "women'thinsulate 3 button leather car co",
        "#$%^&()!@?><L:{}\\[];',./`ǠǡǢǣǤǥǦǧǨ",
        "ڠڡڢڣڤڥڦڧڨکڪګڬڭڮگ",
        "⛀⛁⛂⛃⛄⛅⛆⛇⛈⛉⛊⛋⛌⛍⛎⛏",
    )

    @classmethod
    def setUpClass(cls):
        """Load the reference tokenizer, build the model, register the op."""
        tokjson = _get_test_data_file('data', 'gpt2.vocab')
        merges = tokjson.replace('.vocab', '.merges.txt')
        cls.tokenizer = GPT2Tokenizer(tokjson, merges)

        model = _create_test_model()
        cls.binded_model = _bind_tokenizer(model, vocab_file=tokjson, merges_file=merges)

        @onnx_op(op_type="GPT2Tokenizer",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_int64])
        def bpe_tokenizer(s):
            # The user custom op implementation here.
            TestGPT2Tokenizer.pyop_invoked = True
            return np.array(
                [TestGPT2Tokenizer.tokenizer.encode(st_) for st_ in s])

    def _run_tokenizer(self, pyop_flag, test_sentence):
        """Tokenize one sentence through ONNX Runtime and verify the ids.

        Args:
            pyop_flag: Value passed to ``enable_custom_op`` before the session
                is created. The tests expect the Python op to run only when
                this is ``True``.
            test_sentence: The text to tokenize.
        """
        enable_custom_op(pyop_flag)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        sess = _ort.InferenceSession(self.binded_model.SerializeToString(), so)
        input_text = np.array([test_sentence])
        txtout = sess.run(None, {'string_input': input_text})

        np.testing.assert_array_equal(txtout[0], np.array([self.tokenizer.encode(test_sentence)]))
        # Drop the session before the options object it was created from.
        del sess
        del so

    def _check_sentences(self, pyop_flag):
        """Run all sentences with ``pyop_flag`` and check which op ran.

        After the first sentence ``pyop_invoked`` must equal ``pyop_flag``:
        untouched for the native op, set for the Python op.
        """
        TestGPT2Tokenizer.pyop_invoked = False

        self._run_tokenizer(pyop_flag, self._FIRST_SENTENCE)
        self.assertIs(TestGPT2Tokenizer.pyop_invoked, pyop_flag)

        for sentence in self._MORE_SENTENCES:
            self._run_tokenizer(pyop_flag, sentence)

    def test_tokenizer(self):
        """The op must match the reference with Python ops disabled."""
        self._check_sentences(False)

    def test_tokenizer_pyop(self):
        """The op must match the reference with Python ops enabled."""
        self._check_sentences(True)
110
111
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()