microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ba200b4a0e391b45c0df4f9b1a506f0a9f574dd4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_gpt2tok.py

126 lines · mode: code

1import unittest
2import numpy as np
3import onnxruntime as _ort
4
5from pathlib import Path
6from onnx import helper, onnx_pb as onnx_proto
7from transformers import GPT2Tokenizer
8from onnxruntime_extensions import (
9 onnx_op,
10 enable_custom_op,
11 PyCustomOpDef,
12 expand_onnx_inputs,
13 get_library_path as _get_library_path)
14
15
16def _get_file_content(path):
17 with open(path, "rb") as file:
18 return file.read()
19
20
21def _get_test_data_file(*sub_dirs):
22 test_dir = Path(__file__).parent
23 return str(test_dir.joinpath(*sub_dirs))
24
25
def _create_test_model(**kwargs):
    """Build a one-node ONNX model wrapping the ``GPT2Tokenizer`` custom op.

    Keyword Args:
        vocab_file: path to the vocabulary file; its raw bytes become the
            node's ``vocab`` attribute.
        merges_file: path to the BPE merges file; its raw bytes become the
            node's ``merges`` attribute.
        max_length: value stored in the node's ``padding_length`` attribute.

    Returns:
        An ONNX model with one STRING input ``string_input`` (shape ``[None]``)
        and two INT64 outputs, ``input_ids`` and ``attention_mask``
        (shape ``[None, None]``).

    Raises:
        KeyError: if any of the three keyword arguments is missing.
    """
    vocab_file = kwargs["vocab_file"]
    merges_file = kwargs["merges_file"]
    max_length = kwargs["max_length"]

    node = helper.make_node(
        'GPT2Tokenizer',
        ['string_input'],
        ['input_ids', 'attention_mask'],
        name='bpetok',
        domain='ai.onnx.contrib',
        vocab=_get_file_content(vocab_file),
        merges=_get_file_content(merges_file),
        padding_length=max_length)
    input1 = helper.make_tensor_value_info(
        'string_input', onnx_proto.TensorProto.STRING, [None])
    output1 = helper.make_tensor_value_info(
        'input_ids', onnx_proto.TensorProto.INT64, [None, None])
    output2 = helper.make_tensor_value_info(
        'attention_mask', onnx_proto.TensorProto.INT64, [None, None])

    graph = helper.make_graph([node], 'test0', [input1], [output1, output2])
    # The node belongs to the 'ai.onnx.contrib' domain, so that domain needs
    # its own opset import next to the default ONNX one; a model that uses a
    # domain without importing it is rejected by the ONNX checker.
    model = helper.make_model(
        graph,
        opset_imports=[helper.make_operatorsetid('', 12),
                       helper.make_operatorsetid('ai.onnx.contrib', 1)])
    return model
45
46
class MyGPT2Tokenizer:
    """Reference tokenizer built on Hugging Face's ``GPT2Tokenizer``.

    Produces the expected ``input_ids`` / ``attention_mask`` arrays that the
    ONNX ``GPT2Tokenizer`` custom op's outputs are compared against.
    """

    def __init__(self, token_json, merges):
        """Load the tokenizer from the vocabulary file and the BPE merges file."""
        self.tokenizer = GPT2Tokenizer(token_json, merges)
        # GPT-2 defines no padding token; '!' is used here (the original note
        # says it maps to id 0). The right pad token is still an open question.
        self.tokenizer.pad_token = '!'

    def tokenizer_sentence(self, test_sentence, padding_length):
        """Tokenize ``test_sentence`` and return ``(input_ids, attention_mask)``.

        Args:
            test_sentence: text (or list of texts) passed to the tokenizer.
            padding_length: ``-1`` calls the tokenizer with ``padding=True``.
                Any other value pads and truncates to exactly
                ``padding_length`` tokens (``padding="max_length"``,
                ``truncation=True``).

        Returns:
            A tuple of two numpy arrays built from the tokenizer's
            ``input_ids`` and ``attention_mask`` entries.
        """
        # Encode once and read both fields from the same result; the tokenizer
        # previously ran twice per call just to fetch the two keys.
        if padding_length == -1:
            encoded = self.tokenizer(test_sentence, padding=True)
        else:
            encoded = self.tokenizer(test_sentence, padding="max_length", truncation=True,
                                     max_length=padding_length)
        input_ids = np.array(encoded["input_ids"])
        attention_mask = np.array(encoded["attention_mask"])
        return input_ids, attention_mask
65
66
class TestGPT2Tokenizer(unittest.TestCase):
    """Compare the ONNX ``GPT2Tokenizer`` op's outputs with the Hugging Face tokenizer."""

    @classmethod
    def setUpClass(cls):
        """Locate the test vocab/merges files and build the reference tokenizer.

        Also defines ``bpe_tokenizer``, decorated with ``onnx_op`` for op type
        ``GPT2Tokenizer``; it delegates to ``cls.tokenizer``.
        """
        cls.tokjson = _get_test_data_file('data', 'gpt2.vocab')
        cls.merges = _get_test_data_file('data', 'gpt2.merges.txt')
        cls.tokenizer = MyGPT2Tokenizer(cls.tokjson, cls.merges)

        # NOTE(review): bpe_tokenizer is never referenced again here; presumably
        # the onnx_op decorator registers it as a Python implementation of the
        # op as a side effect. Confirm against onnxruntime_extensions.
        @onnx_op(op_type="GPT2Tokenizer",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_int64],
                 attrs=["padding_length"])
        def bpe_tokenizer(s, **kwargs):
            padding_length = kwargs["padding_length"]
            input_ids, attention_mask = cls.tokenizer.tokenizer_sentence(s, padding_length)
            return input_ids, attention_mask

    def _run_tokenizer(self, test_sentence, padding_length=-1):
        """Run ``test_sentence`` through ONNX Runtime and compare with the reference.

        Args:
            test_sentence: list of strings fed as the model's ``string_input``.
            padding_length: stored as the op's ``padding_length`` attribute
                and forwarded to the reference tokenizer (where ``-1`` selects
                ``padding=True``).
        """
        model = _create_test_model(vocab_file=self.tokjson, merges_file=self.merges, max_length=padding_length)
        so = _ort.SessionOptions()
        # Register the onnxruntime-extensions custom-op library with the session.
        so.register_custom_ops_library(_get_library_path())
        sess = _ort.InferenceSession(model.SerializeToString(), so)
        input_text = np.array(test_sentence)
        input_ids, attention_mask = sess.run(None, {'string_input': input_text})
        expect_input_ids, expect_attention_mask = self.tokenizer.tokenizer_sentence(test_sentence, padding_length)
        np.testing.assert_array_equal(expect_input_ids, input_ids)
        np.testing.assert_array_equal(expect_attention_mask, attention_mask)

    def test_tokenizer(self):
        """Check every case with ``enable_custom_op(False)`` in effect."""
        # NOTE(review): this presumably disables the Python op registered in
        # setUpClass so the native kernel is exercised; confirm the semantics.
        enable_custom_op(False)
        # The flag is process-global. Restore it even if an assertion fails, so
        # later tests do not inherit the disabled state.
        self.addCleanup(enable_custom_op, True)

        cases = [
            (["I can feel the magic, can you?"], -1),
            (["Hey Cortana"], -1),
            (["你好123。david"], -1),
            (["爱你一三一四"], -1),
            (["women'thinsulate 3 button leather car co"], -1),
            (["#$%^&()!@?><L:{}\\[];',./`ǠǡǢǣǤǥǦǧǨ"], -1),
            (["ڠڡڢڣڤڥڦڧڨکڪګڬڭڮگ"], -1),
            (["⛀⛁⛂⛃⛄⛅⛆⛇⛈⛉⛊⛋⛌⛍⛎⛏"], -1),
            (["I can feel the magic, can you?", "Yes I do."], -1),
            (["I can feel the magic, can you?", "Yes I do."], 100),
        ]
        for sentences, padding_length in cases:
            # subTest reports each failing case on its own instead of stopping
            # at the first mismatch.
            with self.subTest(sentences=sentences, padding_length=padding_length):
                self._run_tokenizer(sentences, padding_length)
112
113
114# def test_tokenizer_pyop(self):
115# self._run_tokenizer(["I can feel the magic, can you?"])
116# self._run_tokenizer(["Hey Cortana"])
117# self._run_tokenizer(["你好123。david"])
118# self._run_tokenizer(["爱你一三一四"])
119# self._run_tokenizer(["women'thinsulate 3 button leather car co"])
120# self._run_tokenizer(["#$%^&()!@?><L:{}\\[];',./`ǠǡǢǣǤǥǦǧǨ"])
121# self._run_tokenizer(["ڠڡڢڣڤڥڦڧڨکڪګڬڭڮگ"])
122# self._run_tokenizer(["⛀⛁⛂⛃⛄⛅⛆⛇⛈⛉⛊⛋⛌⛍⛎⛏"])
123
124
if __name__ == "__main__":
    # When executed directly (not imported), load and run this module's tests.
    unittest.main(module="__main__")
127