microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_sentencepiece_ops.py
434 lines · mode: code
| 1 | # coding: utf-8 |
| 2 | import unittest |
| 3 | import os |
| 4 | import base64 |
| 5 | import numpy as np |
| 6 | from numpy.testing import assert_almost_equal |
| 7 | from onnx import helper, onnx_pb as onnx_proto |
| 8 | import onnxruntime as _ort |
| 9 | from onnxruntime_extensions import ( |
| 10 | onnx_op, PyCustomOpDef, |
| 11 | get_library_path as _get_library_path) |
| 12 | import tensorflow as tf |
| 13 | from tensorflow_text import SentencepieceTokenizer |
| 14 | |
| 15 | |
def load_piece(name):
    """Load a base64-encoded fixture stored beside this test module.

    The fixture is read from ``data/<module-stem>_<name>.txt`` inside the
    directory that contains this file.

    :param name: suffix identifying the fixture (e.g. ``'model__6'``).
    :return: ``(array, b64)`` where ``array`` is a 1-D ``uint8`` numpy array
        holding the decoded bytes and ``b64`` is the same bytes re-encoded
        with ``base64.b64encode`` (a single line, no embedded newlines).
    """
    here, module_file = os.path.split(__file__)
    stem = os.path.splitext(module_file)[0]
    path = os.path.join(here, "data", "%s_%s.txt" % (stem, name))
    with open(path, "r") as handle:
        encoded_text = handle.read()
    # decodebytes accepts the multi-line layout of the file; b64encode then
    # yields a normalised single-line encoding of the same payload.
    raw = base64.decodebytes(encoded_text.encode())
    return np.frombuffer(raw, dtype=np.uint8).copy(), base64.b64encode(raw)
| 27 | |
| 28 | |
def _create_test_model_sentencepiece(
        prefix, model_b64, domain='ai.onnx.contrib'):
    """Build a one-node ONNX model around ``<prefix>SentencepieceTokenizer``.

    :param prefix: op-type prefix; ``'Py'`` selects the Python ops the test
        class registers, ``''`` the op type without a prefix.
    :param model_b64: when ``None`` the sentencepiece model is fed at run
        time through an extra ``uint8`` graph input named ``model`` (placed
        first); otherwise it is stored as the node's ``model`` attribute.
    :param domain: custom-op domain, used both for the node and for the
        model's opset import so the two always agree.
    :return: an ``onnx.ModelProto`` with outputs ``out0`` (``int32``) and
        ``out1`` (``int64``).
    """
    mkv = helper.make_tensor_value_info
    node_inputs = ['inputs', 'nbest_size', 'alpha',
                   'add_bos', 'add_eos', 'reverse']
    inputs = [
        mkv('inputs', onnx_proto.TensorProto.STRING, [None]),
        mkv('nbest_size', onnx_proto.TensorProto.INT64, [None]),
        mkv('alpha', onnx_proto.TensorProto.FLOAT, [None]),
        mkv('add_bos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('add_eos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('reverse', onnx_proto.TensorProto.BOOL, [None]),
    ]
    attributes = {}
    if model_b64 is None:
        # The serialized model becomes the first input of both node and graph.
        node_inputs.insert(0, 'model')
        inputs.insert(0, mkv('model', onnx_proto.TensorProto.UINT8, [None]))
    else:
        attributes['model'] = model_b64

    nodes = [helper.make_node(
        '%sSentencepieceTokenizer' % prefix,
        inputs=node_inputs,
        outputs=['out0', 'out1'],
        name='SentencepieceTokenizeOpName',
        domain=domain,
        **attributes)]

    graph = helper.make_graph(
        nodes, 'test0', inputs, [
            mkv('out0', onnx_proto.TensorProto.INT32, [None]),
            mkv('out1', onnx_proto.TensorProto.INT64, [None])
        ])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
| 91 | |
| 92 | |
def _create_test_model_ragged_to_sparse(
        prefix, model_b64, domain='ai.onnx.contrib'):
    """Build a tokenizer -> RaggedTensorToSparse ONNX model.

    With ``model_b64 is None`` the sentencepiece model is an extra ``uint8``
    graph input named ``model``, and ``<prefix>RaggedTensorToSparse``
    consumes ``(splits, values)`` and produces all three outputs.

    Otherwise the model is stored as the tokenizer's ``model`` attribute.
    ``RaggedTensorToSparse`` (op type used without ``prefix``) then consumes
    only the row splits and produces ``out0``/``out2``, while the token ids
    reach ``out1`` through an ``Identity`` node.

    :param prefix: op-type prefix for the custom ops.
    :param model_b64: model for the tokenizer attribute, or ``None`` to feed
        it as a graph input.
    :param domain: custom-op domain used for the custom nodes and for the
        model's opset import.
    :return: an ``onnx.ModelProto`` with outputs ``out0`` (``int64``),
        ``out1`` (``int32``) and ``out2`` (``int64``).
    """
    mkv = helper.make_tensor_value_info
    tokenizer_inputs = ['inputs', 'nbest_size', 'alpha',
                        'add_bos', 'add_eos', 'reverse']
    inputs = [
        mkv('inputs', onnx_proto.TensorProto.STRING, [None]),
        mkv('nbest_size', onnx_proto.TensorProto.INT64, [None]),
        mkv('alpha', onnx_proto.TensorProto.FLOAT, [None]),
        mkv('add_bos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('add_eos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('reverse', onnx_proto.TensorProto.BOOL, [None]),
    ]
    tokenizer_attrs = {}

    if model_b64 is None:
        tokenizer_inputs.insert(0, 'model')
        inputs.insert(0, mkv('model', onnx_proto.TensorProto.UINT8, [None]))
        # tokout1 holds the row splits (int64), tokout0 the flat ids (int32).
        tail_nodes = [helper.make_node(
            '%sRaggedTensorToSparse' % prefix,
            inputs=['tokout1', 'tokout0'],
            outputs=['out0', 'out1', 'out2'],
            name='RaggedTensorToSparse',
            domain=domain,
        )]
    else:
        tokenizer_attrs['model'] = model_b64
        tail_nodes = [
            helper.make_node(
                'RaggedTensorToSparse',
                inputs=['tokout1'],
                outputs=['out0', 'out2'],
                name='RaggedTensorToSparse',
                domain=domain,
            ),
            # NOTE(review): Identity lives in the default ONNX domain, which
            # is not listed in opset_imports below — confirm the runtime
            # supplies a default opset.
            helper.make_node(
                'Identity', inputs=['tokout0'], outputs=['out1']),
        ]

    tokenizer = helper.make_node(
        '%sSentencepieceTokenizer' % prefix,
        inputs=tokenizer_inputs,
        outputs=['tokout0', 'tokout1'],
        name='SentencepieceTokenizeOpName',
        domain=domain,
        **tokenizer_attrs)

    # The tokenizer must come first so the node list stays topologically sorted.
    graph = helper.make_graph(
        [tokenizer] + tail_nodes, 'test0', inputs, [
            mkv('out0', onnx_proto.TensorProto.INT64, [None]),
            mkv('out1', onnx_proto.TensorProto.INT32, [None]),
            mkv('out2', onnx_proto.TensorProto.INT64, [None])
        ])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
| 178 | |
| 179 | |
def _create_test_model_ragged_to_dense(
        prefix, model_b64, domain='ai.onnx.contrib'):
    """Build a tokenizer -> Cast -> RaggedTensorToDense ONNX model.

    Unlike the sibling builders there is no ``model`` graph input. The
    sentencepiece model is always stored as the tokenizer's ``model``
    attribute, so callers must pass ``model_b64``.

    :param prefix: op-type prefix for the custom ops.
    :param model_b64: serialized sentencepiece model stored on the tokenizer
        node.
    :param domain: custom-op domain used for the custom nodes and for the
        model's opset import.
    :return: an ``onnx.ModelProto`` with outputs ``out0`` (``int64`` dense
        token ids) and ``out1`` (flat ``int32`` token ids).
    """
    mkv = helper.make_tensor_value_info
    inputs = [
        mkv('inputs', onnx_proto.TensorProto.STRING, [None]),
        mkv('nbest_size', onnx_proto.TensorProto.INT64, [None]),
        mkv('alpha', onnx_proto.TensorProto.FLOAT, [None]),
        mkv('add_bos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('add_eos', onnx_proto.TensorProto.BOOL, [None]),
        mkv('reverse', onnx_proto.TensorProto.BOOL, [None]),
    ]

    # Fill value used by the dense op; the dense test expects short rows
    # padded with -1.
    default_value = helper.make_tensor(
        "default_value", onnx_proto.TensorProto.INT64, [1, ], [-1])
    # Empty int64 tensor fed to the dense op's first input; presumably a
    # placeholder the op ignores (TODO confirm against the op's schema).
    unused = helper.make_tensor(
        "unused", onnx_proto.TensorProto.INT64, [0, ], [])

    nodes = [
        helper.make_node(
            '%sSentencepieceTokenizer' % prefix,
            inputs=['inputs', 'nbest_size', 'alpha',
                    'add_bos', 'add_eos', 'reverse'],
            outputs=['tokout0', 'tokout1'],
            model=model_b64,
            name='SentencepieceTokenizeOpName',
            domain=domain,
        ),
        # Token ids are produced as int32 (see out1); the dense op is fed
        # int64 values to match default_value.
        helper.make_node(
            'Cast', inputs=['tokout0'], outputs=['tokout064'],
            to=onnx_proto.TensorProto.INT64),
        helper.make_node(
            '%sRaggedTensorToDense' % prefix,
            inputs=['unused', 'tokout064', 'default_value', 'tokout1'],
            outputs=['out0'],
            name='RaggedTensorToDense',
            domain=domain,
        ),
        # NOTE(review): Cast/Identity are default-domain ops but only the
        # custom domain is imported below — confirm a default opset is
        # supplied by the runtime.
        helper.make_node('Identity', inputs=['tokout0'], outputs=['out1']),
    ]

    graph = helper.make_graph(
        nodes, 'test0', inputs, [
            mkv('out0', onnx_proto.TensorProto.INT64, [None]),
            mkv('out1', onnx_proto.TensorProto.INT32, [None]),
        ], [default_value, unused])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
| 235 | |
| 236 | |
class TestPythonOpSentencePiece(unittest.TestCase):
    """Check the sentencepiece custom ops against a TensorFlow-Text reference.

    ``setUpClass`` registers Python implementations of
    ``PySentencepieceTokenizer`` and ``PyRaggedTensorToSparse`` built on
    TensorFlow / TensorFlow-Text. Each test runs an ONNX model through
    onnxruntime, with the extensions library registered, and compares its
    outputs either with those reference functions or with hard-coded
    expectations.
    """

    # Sentences tokenized by every test.
    _SENTENCES = ("Hello world", "Hello world louder")

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        @onnx_op(op_type="PySentencepieceTokenizer",
                 inputs=[PyCustomOpDef.dt_uint8,   # 0: model
                         PyCustomOpDef.dt_string,  # 1: inputs
                         PyCustomOpDef.dt_int64,   # 2: nbest_size
                         PyCustomOpDef.dt_float,   # 3: alpha
                         PyCustomOpDef.dt_bool,    # 4: add_bos
                         PyCustomOpDef.dt_bool,    # 5: add_eos
                         PyCustomOpDef.dt_bool],   # 6: reverse
                 outputs=[PyCustomOpDef.dt_int32,
                          PyCustomOpDef.dt_int64])
        def sentence_piece_tokenizer_op(model, inputs, nbest_size,
                                        alpha, add_bos, add_eos, reverse):
            """Implements `text.SentencepieceTokenizer
            <https://github.com/tensorflow/text/blob/master/docs/
            api_docs/python/text/SentencepieceTokenizer.md>`_.

            ``model`` is the serialized model as a ``uint8`` array. Only the
            first element of ``nbest_size``, ``alpha``, ``add_bos``,
            ``add_eos`` and ``reverse`` is read. Returns the flat token ids
            and the row splits of the resulting ragged tensor.
            """
            tokenizer = SentencepieceTokenizer(
                model=model.tobytes(),
                reverse=reverse[0],
                add_bos=add_bos[0],
                add_eos=add_eos[0],
                alpha=alpha[0],
                nbest_size=nbest_size[0])
            ragged_tensor = tokenizer.tokenize(inputs)
            output_values = ragged_tensor.flat_values.numpy()
            output_splits = ragged_tensor.nested_row_splits[0].numpy()
            return output_values, output_splits

        # staticmethod stops lookup through ``self`` from binding the op as a
        # method, whatever kind of callable ``onnx_op`` returns.
        cls.SentencepieceTokenizer = staticmethod(sentence_piece_tokenizer_op)

        @onnx_op(op_type="PyRaggedTensorToSparse",
                 inputs=[PyCustomOpDef.dt_int64,    # 0: row splits
                         PyCustomOpDef.dt_int32],   # 1: flat values
                 outputs=[PyCustomOpDef.dt_int64,   # 0: sparse indices
                          PyCustomOpDef.dt_int32,   # 1: sparse values
                          PyCustomOpDef.dt_int64])  # 2: dense shape
        def ragged_tensor_to_sparse(nested_splits, dense_values):
            """Convert a one-level ragged tensor to (indices, values, shape)."""
            sparse_indices, sparse_values, sparse_dense_shape = \
                tf.raw_ops.RaggedTensorToSparse(
                    rt_nested_splits=[nested_splits],
                    rt_dense_values=dense_values)
            return (sparse_indices.numpy(),
                    sparse_values.numpy(),
                    sparse_dense_shape.numpy())

        cls.RaggedTensorToSparse = staticmethod(ragged_tensor_to_sparse)

    @staticmethod
    def _session_options():
        """Return SessionOptions with the extensions custom-op library loaded."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return so

    @classmethod
    def _make_inputs(cls, model, nbest_size=0, alpha=0, bools=0):
        """Return the feed dict for the tokenizer graphs.

        ``bools`` is a bit mask: bit 0 -> add_bos, bit 1 -> add_eos,
        bit 2 -> reverse. The dict contains ``model``. Graphs that embed the
        model as an attribute must be fed ``_without_model(...)`` instead.
        """
        return dict(
            model=model,
            # Plain ``object``: the ``np.object`` alias was removed in
            # NumPy 1.24.
            inputs=np.array(cls._SENTENCES, dtype=object),
            nbest_size=np.array([nbest_size], dtype=np.int64),
            alpha=np.array([alpha], dtype=np.float32),
            add_bos=np.array([bools & 1], dtype=np.bool_),
            add_eos=np.array([bools & 2], dtype=np.bool_),
            reverse=np.array([bools & 4], dtype=np.bool_))

    @staticmethod
    def _without_model(inputs):
        """Return a copy of ``inputs`` without the ``model`` entry."""
        return {key: value for key, value in inputs.items() if key != 'model'}

    def _tokenizer_sessions(self, cc_model):
        """Return (Python-op, native-op) tokenizer sessions.

        ``cc_model`` is stored as the native tokenizer's ``model`` attribute.
        """
        so = self._session_options()
        py_onnx_model = _create_test_model_sentencepiece('Py', None)
        self.assertIn(
            'op_type: "PySentencepieceTokenizer"', str(py_onnx_model))
        cc_onnx_model = _create_test_model_sentencepiece('', cc_model)
        self.assertIn('op_type: "SentencepieceTokenizer"', str(cc_onnx_model))
        py_sess = _ort.InferenceSession(py_onnx_model.SerializeToString(), so)
        cc_sess = _ort.InferenceSession(cc_onnx_model.SerializeToString(), so)
        return py_sess, cc_sess

    def _assert_tokenizers_agree(self, py_sess, cc_sess, inputs):
        """Check that both sessions reproduce the TF reference tokenization."""
        exp = self.SentencepieceTokenizer(**inputs)
        py_txout = py_sess.run(None, inputs)
        cc_txout = cc_sess.run(None, self._without_model(inputs))
        for i in range(2):
            assert_almost_equal(exp[i], py_txout[i])
            assert_almost_equal(exp[i], cc_txout[i])

    def test_string_ragged_string_to_sparse_python(self):
        """Python tokenizer + Python RaggedTensorToSparse match TF."""
        so = self._session_options()
        model, _ = load_piece('model__6')
        onnx_model = _create_test_model_ragged_to_sparse('Py', None)
        self.assertIn('op_type: "PyRaggedTensorToSparse"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)

        inputs = self._make_inputs(model)
        txout = sess.run(None, inputs)
        temp = self.SentencepieceTokenizer(**inputs)
        exp = self.RaggedTensorToSparse(temp[1], temp[0])
        self.assertEqual(len(txout), len(exp))
        for expected, actual in zip(exp, txout):
            assert_almost_equal(expected, actual)

    def test_string_ragged_string_to_sparse_cc(self):
        """Native tokenizer + native RaggedTensorToSparse match TF."""
        so = self._session_options()
        model, model_b64 = load_piece('model__6')
        onnx_model = _create_test_model_ragged_to_sparse('', model_b64)
        self.assertIn('op_type: "RaggedTensorToSparse"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)

        inputs = self._make_inputs(model)
        temp = self.SentencepieceTokenizer(**inputs)
        exp = self.RaggedTensorToSparse(temp[1], temp[0])
        txout = sess.run(None, self._without_model(inputs))
        self.assertEqual(len(txout), len(exp))
        for expected, actual in zip(exp, txout):
            assert_almost_equal(expected, actual)

    def test_string_ragged_string_to_dense_cc(self):
        """Native RaggedTensorToDense pads short rows with the default -1."""
        so = self._session_options()
        model, model_b64 = load_piece('model__6')
        onnx_model = _create_test_model_ragged_to_dense('', model_b64)
        self.assertIn('op_type: "RaggedTensorToDense"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)

        txout = sess.run(None, self._without_model(self._make_inputs(model)))
        assert_almost_equal(
            txout[0], np.array([[17486, 1017, -1, -1],
                                [17486, 1017, 155, 21869]], dtype=np.int64))
        assert_almost_equal(
            txout[1], np.array([17486, 1017, 17486, 1017, 155, 21869],
                               dtype=np.int32))

    def test_string_sentencepiece_tokenizer(self):
        """Python and native tokenizers match TF across option combinations."""
        model, model_b64 = load_piece('model__6')
        py_sess, cc_sess = self._tokenizer_sessions(model_b64)

        for alpha in [0, 0.5]:
            for nbest_size in [0, 1]:
                for bools in range(8):
                    with self.subTest(
                            alpha=alpha, nbest_size=nbest_size, bools=bools):
                        inputs = self._make_inputs(
                            model, nbest_size=nbest_size,
                            alpha=alpha, bools=bools)
                        self._assert_tokenizers_agree(
                            py_sess, cc_sess, inputs)

    def test_string_sentencepiece_tokenizer_bin(self):
        """The native tokenizer also accepts the raw (non-base64) model bytes."""
        model, _ = load_piece('model__6')
        py_sess, cc_sess = self._tokenizer_sessions(bytes(model))
        self._assert_tokenizers_agree(
            py_sess, cc_sess, self._make_inputs(model))
| 431 | |
| 432 | |
if "__main__" == __name__:
    # Executed only when this module is run directly; hands control to
    # unittest's command-line runner for the test cases in this module.
    unittest.main()
| 435 | |