microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b3a300d7bf6f3e99fb907e682f7246ed5ed5805e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_sentencepiece_ops.py

353lines · modecode

1# coding: utf-8
2import unittest
3import os
4import base64
5import numpy as np
6from numpy.testing import assert_almost_equal
7from onnx import helper, onnx_pb as onnx_proto
8import onnxruntime as _ort
9from onnxruntime_customops import (
10 onnx_op, PyCustomOpDef,
11 get_library_path as _get_library_path)
12import tensorflow as tf
13from tensorflow_text import SentencepieceTokenizer
14
15
def load_piece(name):
    """Load a base64-encoded fixture stored beside this test module.

    The fixture path is ``data/<module stem>_<name>.txt``, resolved
    relative to the directory that contains this file, where
    ``<module stem>`` is this file's name without its extension.

    :param name: suffix that selects the fixture file
    :return: tuple ``(raw, b64)``; ``raw`` is a 1-D ``np.uint8`` array
        holding the decoded bytes and ``b64`` is the same payload
        re-encoded with :func:`base64.b64encode` (``bytes``)
    """
    folder, this_file = os.path.split(__file__)
    stem = os.path.splitext(this_file)[0]
    fullname = os.path.join(folder, "data", "%s_%s.txt" % (stem, name))
    # Read bytes directly: base64 text is plain ASCII, so going through a
    # locale-dependent text decode and back to bytes adds nothing.
    with open(fullname, "rb") as f:
        decoded = base64.decodebytes(f.read())
    # frombuffer returns a read-only view; copy() yields an independent,
    # writable array without building an intermediate list of ints.
    raw = np.frombuffer(decoded, dtype=np.uint8).copy()
    return raw, base64.b64encode(decoded)
27
28
def _create_test_model_sentencepiece(
        prefix, model_b64, domain='ai.onnx.contrib'):
    """Build an ONNX model holding a single ``SentencepieceTokenizer`` node.

    :param prefix: string prepended to the operator type, e.g. ``'Py'``
        yields ``PySentencepieceTokenizer`` and ``''`` yields
        ``SentencepieceTokenizer``
    :param model_b64: when ``None`` the graph gets an extra leading
        ``model`` input (uint8 tensor); otherwise the value is stored as
        the node's ``model`` attribute and no ``model`` input is declared
    :param domain: operator domain used both for the node and for the
        opset import of the model
    :return: the ``ModelProto`` produced by ``helper.make_model``
    """
    mkv = helper.make_tensor_value_info
    tensor_type = onnx_proto.TensorProto

    # Graph inputs common to both variants, in operator argument order.
    inputs = [
        mkv('inputs', tensor_type.STRING, [None]),
        mkv('nbest_size', tensor_type.INT64, [None]),
        mkv('alpha', tensor_type.FLOAT, [None]),
        mkv('add_bos', tensor_type.BOOL, [None]),
        mkv('add_eos', tensor_type.BOOL, [None]),
        mkv('reverse', tensor_type.BOOL, [None]),
    ]

    node_attributes = {}
    if model_b64 is None:
        # The model is fed at run time as the first input.
        inputs.insert(0, mkv('model', tensor_type.UINT8, [None]))
    else:
        # The model is embedded in the node itself.
        node_attributes['model'] = model_b64

    nodes = [helper.make_node(
        '%sSentencepieceTokenizer' % prefix,
        inputs=[value_info.name for value_info in inputs],
        outputs=['out0', 'out1'],
        name='SentencepieceTokenizeOpName',
        # Use the caller's domain so the node always matches the opset
        # import declared below.
        domain=domain,
        **node_attributes)]

    graph = helper.make_graph(
        nodes, 'test0', inputs, [
            mkv('out0', tensor_type.INT32, [None]),
            mkv('out1', tensor_type.INT64, [None]),
        ])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
91
92
def _create_test_model_ragged_to_sparse(
        prefix, model_b64, domain='ai.onnx.contrib'):
    """Build an ONNX model chaining a tokenizer with ``RaggedTensorToSparse``.

    Two graph layouts are produced depending on *model_b64*:

    * ``None``: ``<prefix>SentencepieceTokenizer`` takes the model as
      its first input and feeds ``<prefix>RaggedTensorToSparse``, which
      receives ``tokout1`` and ``tokout0`` and emits ``out0``, ``out1``
      and ``out2``.
    * otherwise: the value is stored as the tokenizer's ``model``
      attribute; ``RaggedTensorToSparse`` (no prefix) receives only
      ``tokout1`` and emits ``out0`` and ``out2``, while ``out1`` is an
      ``Identity`` copy of ``tokout0``.

    :param prefix: string prepended to the custom operator types
    :param model_b64: ``None`` or the value of the ``model`` attribute
    :param domain: operator domain used for the custom nodes and for the
        opset import of the model
    :return: the ``ModelProto`` produced by ``helper.make_model``
    """
    mkv = helper.make_tensor_value_info
    tensor_type = onnx_proto.TensorProto

    # Graph inputs common to both variants, in operator argument order.
    inputs = [
        mkv('inputs', tensor_type.STRING, [None]),
        mkv('nbest_size', tensor_type.INT64, [None]),
        mkv('alpha', tensor_type.FLOAT, [None]),
        mkv('add_bos', tensor_type.BOOL, [None]),
        mkv('add_eos', tensor_type.BOOL, [None]),
        mkv('reverse', tensor_type.BOOL, [None]),
    ]

    tokenizer_attributes = {}
    if model_b64 is None:
        # The model is fed at run time as the first input.
        inputs.insert(0, mkv('model', tensor_type.UINT8, [None]))
    else:
        # The model is embedded in the tokenizer node itself.
        tokenizer_attributes['model'] = model_b64

    nodes = [helper.make_node(
        '%sSentencepieceTokenizer' % prefix,
        inputs=[value_info.name for value_info in inputs],
        outputs=['tokout0', 'tokout1'],
        name='SentencepieceTokenizeOpName',
        # Use the caller's domain so custom nodes always match the opset
        # import declared below.
        domain=domain,
        **tokenizer_attributes)]

    if model_b64 is None:
        nodes.append(helper.make_node(
            '%sRaggedTensorToSparse' % prefix,
            inputs=['tokout1', 'tokout0'],
            outputs=['out0', 'out1', 'out2'],
            name='RaggedTensorToSparse',
            domain=domain,
        ))
    else:
        # NOTE(review): 'n_els' is not consumed by any other node in this
        # graph; the node is kept to leave the generated model unchanged.
        nodes.append(helper.make_node(
            'Shape', inputs=['tokout1'], outputs=['n_els']))

        nodes.append(helper.make_node(
            'RaggedTensorToSparse',
            inputs=['tokout1'],
            outputs=['out0', 'out2'],
            name='RaggedTensorToSparse',
            domain=domain,
        ))

        # The token values are forwarded untouched as the second output.
        nodes.append(helper.make_node(
            'Identity', inputs=['tokout0'], outputs=['out1']))

    graph = helper.make_graph(
        nodes, 'test0', inputs, [
            mkv('out0', tensor_type.INT64, [None]),
            mkv('out1', tensor_type.INT32, [None]),
            mkv('out2', tensor_type.INT64, [None]),
        ])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
178
179
class TestPythonOpSentencePiece(unittest.TestCase):
    """Compare the custom ops against reference Python implementations.

    ``setUpClass`` registers two Python custom operators
    (``PySentencepieceTokenizer`` and ``PyRaggedTensorToSparse``) backed
    by TensorFlow / tensorflow_text.  Each test runs ONNX models through
    onnxruntime and checks the outputs against those reference
    implementations.
    """

    @classmethod
    def setUpClass(cls):

        @onnx_op(op_type="PySentencepieceTokenizer",
                 inputs=[PyCustomOpDef.dt_uint8,   # 0: model
                         PyCustomOpDef.dt_string,  # 1: input
                         PyCustomOpDef.dt_int64,   # 2: nbest_size
                         PyCustomOpDef.dt_float,   # 3: alpha
                         PyCustomOpDef.dt_bool,    # 4: add_bos
                         PyCustomOpDef.dt_bool,    # 5: add_eos
                         PyCustomOpDef.dt_bool],   # 6: reverse
                 outputs=[PyCustomOpDef.dt_int32,
                          PyCustomOpDef.dt_int64])
        def sentence_piece_tokenizer_op(model, inputs, nbest_size,
                                        alpha, add_bos, add_eos, reverse):
            """Implements `text.SentencepieceTokenizer
            <https://github.com/tensorflow/text/blob/master/docs/
            api_docs/python/text/SentencepieceTokenizer.md>`_."""
            # The custom op implementation.
            tokenizer = SentencepieceTokenizer(
                model=model.tobytes(),
                reverse=reverse[0],
                add_bos=add_bos[0],
                add_eos=add_eos[0],
                alpha=alpha[0],
                nbest_size=nbest_size[0])
            ragged_tensor = tokenizer.tokenize(inputs)
            output_values = ragged_tensor.flat_values.numpy()
            output_splits = ragged_tensor.nested_row_splits[0].numpy()
            return output_values, output_splits

        cls.SentencepieceTokenizer = sentence_piece_tokenizer_op

        @onnx_op(op_type="PyRaggedTensorToSparse",
                 inputs=[PyCustomOpDef.dt_int64,
                         PyCustomOpDef.dt_int32],
                 outputs=[PyCustomOpDef.dt_int64,
                          PyCustomOpDef.dt_int32,
                          PyCustomOpDef.dt_int64])
        def ragged_tensor_to_sparse(nested_splits, dense_values):
            """Reference implementation via ``tf.raw_ops.RaggedTensorToSparse``."""
            sparse_indices, sparse_values, sparse_dense_shape = \
                tf.raw_ops.RaggedTensorToSparse(
                    rt_nested_splits=[nested_splits],
                    rt_dense_values=dense_values)
            return (sparse_indices.numpy(),
                    sparse_values.numpy(),
                    sparse_dense_shape.numpy())

        cls.RaggedTensorToSparse = ragged_tensor_to_sparse

    @staticmethod
    def _make_inputs(model, nbest_size=0, alpha=0, bools=0):
        """Return the feed dictionary shared by every test.

        :param model: uint8 array holding the sentencepiece model
        :param nbest_size: value of the ``nbest_size`` input
        :param alpha: value of the ``alpha`` input
        :param bools: bit field; bit 0 is ``add_bos``, bit 1 is
            ``add_eos`` and bit 2 is ``reverse``
        """
        return dict(
            model=model,
            # Builtin ``object`` replaces the ``np.object`` alias, which
            # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
            inputs=np.array(
                ["Hello world", "Hello world louder"], dtype=object),
            nbest_size=np.array([nbest_size], dtype=np.int64),
            alpha=np.array([alpha], dtype=np.float32),
            add_bos=np.array([bools & 1], dtype=np.bool_),
            add_eos=np.array([bools & 2], dtype=np.bool_),
            reverse=np.array([bools & 4], dtype=np.bool_))

    def test_string_ragged_string_to_sparse_python(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        model, _ = load_piece('model__6')
        onnx_model = _create_test_model_ragged_to_sparse('Py', None)
        self.assertIn('op_type: "PyRaggedTensorToSparse"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)

        inputs = self._make_inputs(model)
        txout = sess.run(None, inputs)
        temp = self.SentencepieceTokenizer(**inputs)
        # The reference op takes the row splits first, then the values.
        exp = self.RaggedTensorToSparse(temp[1], temp[0])
        for i in range(0, 3):
            assert_almost_equal(exp[i], txout[i])

    def test_string_ragged_string_to_sparse_cc(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        model, model_b64 = load_piece('model__6')
        onnx_model = _create_test_model_ragged_to_sparse('', model_b64)
        self.assertIn('op_type: "RaggedTensorToSparse"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)

        inputs = self._make_inputs(model)
        temp = self.SentencepieceTokenizer(**inputs)
        exp = self.RaggedTensorToSparse(temp[1], temp[0])
        # This graph embeds the model as an attribute, so it has no
        # 'model' input to feed.
        del inputs['model']
        txout = sess.run(None, inputs)
        for i in range(0, 3):
            assert_almost_equal(exp[i], txout[i])

    def test_string_sentencepiece_tokenizer(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        model, model_b64 = load_piece('model__6')
        py_onnx_model = _create_test_model_sentencepiece('Py', None)
        self.assertIn(
            'op_type: "PySentencepieceTokenizer"', str(py_onnx_model))
        cc_onnx_model = _create_test_model_sentencepiece('', model_b64)
        self.assertIn('op_type: "SentencepieceTokenizer"', str(cc_onnx_model))
        py_sess = _ort.InferenceSession(py_onnx_model.SerializeToString(), so)
        cc_sess = _ort.InferenceSession(cc_onnx_model.SerializeToString(), so)

        # Sweep alpha, nbest_size and all 8 add_bos/add_eos/reverse
        # combinations.
        for alpha in [0, 0.5]:
            for nbest_size in [0, 1]:
                for bools in range(0, 8):
                    with self.subTest(
                            alpha=alpha, nbest_size=nbest_size, bools=bools):
                        inputs = self._make_inputs(
                            model, nbest_size=nbest_size,
                            alpha=alpha, bools=bools)
                        exp = self.SentencepieceTokenizer(**inputs)
                        py_txout = py_sess.run(None, inputs)
                        # The second graph has no 'model' input.
                        del inputs['model']
                        cc_txout = cc_sess.run(None, inputs)
                        for i in range(0, 2):
                            assert_almost_equal(exp[i], py_txout[i])
                            assert_almost_equal(exp[i], cc_txout[i])

    def test_string_sentencepiece_tokenizer_bin(self):
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        model, _ = load_piece('model__6')
        # Here the attribute receives the raw model bytes instead of the
        # base64-encoded form.
        modelb = bytes(model)
        py_onnx_model = _create_test_model_sentencepiece('Py', None)
        self.assertIn(
            'op_type: "PySentencepieceTokenizer"', str(py_onnx_model))
        cc_onnx_model = _create_test_model_sentencepiece('', modelb)
        self.assertIn('op_type: "SentencepieceTokenizer"', str(cc_onnx_model))
        py_sess = _ort.InferenceSession(py_onnx_model.SerializeToString(), so)
        cc_sess = _ort.InferenceSession(cc_onnx_model.SerializeToString(), so)

        inputs = self._make_inputs(model, nbest_size=0, alpha=0, bools=0)
        exp = self.SentencepieceTokenizer(**inputs)
        py_txout = py_sess.run(None, inputs)
        # The second graph has no 'model' input.
        del inputs['model']
        cc_txout = cc_sess.run(None, inputs)
        for i in range(0, 2):
            assert_almost_equal(exp[i], py_txout[i])
            assert_almost_equal(exp[i], cc_txout[i])
350
351
if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()
354