microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_string_concat.py
70 lines · mode: code
| 1 | import os |
| 2 | from pathlib import Path |
| 3 | import unittest |
| 4 | import numpy as np |
| 5 | from onnx import helper, onnx_pb as onnx_proto |
| 6 | import onnxruntime as _ort |
| 7 | from onnxruntime_extensions import ( |
| 8 | onnx_op, |
| 9 | enable_custom_op, |
| 10 | PyCustomOpDef, |
| 11 | expand_onnx_inputs, |
| 12 | get_library_path as _get_library_path) |
| 13 | |
| 14 | |
def _create_test_model(input_dims, output_dims):
    """Build an ONNX model that concatenates two string tensors with StringConcat.

    Each input is passed through an ``Identity`` node, and the two results feed
    the ``StringConcat`` custom op from the ``ai.onnx.contrib`` domain.

    Args:
        input_dims: Rank of both ``input_1`` and ``input_2``. Every dimension
            is left unspecified (``None``).
        output_dims: Rank of ``output``. Every dimension is left unspecified.

    Returns:
        The ONNX model proto built by ``helper.make_model``.
    """
    nodes = [
        helper.make_node('Identity', ['input_1'], ['left']),
        helper.make_node('Identity', ['input_2'], ['right']),
        helper.make_node(
            'StringConcat', ['left', 'right'], ['output'],
            domain='ai.onnx.contrib'),
    ]

    input1 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None] * input_dims)
    input2 = helper.make_tensor_value_info(
        'input_2', onnx_proto.TensorProto.STRING, [None] * input_dims)
    output = helper.make_tensor_value_info(
        'output', onnx_proto.TensorProto.STRING, [None] * output_dims)

    graph = helper.make_graph(nodes, 'test0', [input1, input2], [output])
    # Every domain used by a node should have an opset import. StringConcat is
    # in 'ai.onnx.contrib', so declare that domain next to the default one
    # instead of leaving it undeclared.
    return helper.make_model(
        graph,
        opset_imports=[
            helper.make_operatorsetid("", 12),
            helper.make_operatorsetid("ai.onnx.contrib", 1),
        ])
| 33 | |
| 34 | |
def _run_string_concat(input1, input2):
    """Run the StringConcat test model and check it against a Python reference.

    The expected result is the element-wise string concatenation
    ``input1[i] + input2[i]`` over the flattened inputs, reshaped back to
    ``input1.shape``.

    Args:
        input1: numpy array of strings fed to ``input_1``.
        input2: numpy array of strings fed to ``input_2``. It must have the
            same shape as ``input1``.

    Raises:
        ValueError: If ``input1`` and ``input2`` have different shapes.
        AssertionError: If the session does not return exactly one output, or
            that output differs from the expected concatenation.
    """
    if input1.shape != input2.shape:
        # The reference below pairs elements one-to-one. Mismatched shapes
        # would otherwise raise IndexError or be silently truncated.
        raise ValueError(
            f"input shapes differ: {input1.shape} vs {input2.shape}")

    model = _create_test_model(input1.ndim, input1.ndim)

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    # Pin the CPU provider. onnxruntime builds with more than one execution
    # provider refuse to create a session without an explicit provider list.
    sess = _ort.InferenceSession(
        model.SerializeToString(), so, providers=['CPUExecutionProvider'])
    result = sess.run(None, {'input_1': input1, 'input_2': input2})

    # verify
    expected = np.array(
        [left + right for left, right in zip(input1.ravel(), input2.ravel())]
    ).reshape(input1.shape)
    np.testing.assert_equal(len(result), 1)
    np.testing.assert_array_equal(result[0], expected)
| 52 | |
| 53 | |
class TestStringConcat(unittest.TestCase):
    """End-to-end tests for the StringConcat custom operator."""

    def test_string_concat(self):
        """StringConcat must match element-wise Python string concatenation."""
        # (left, right) pairs with matching shapes. They cover single-element,
        # 1-D and 2-D inputs, plus non-ASCII text (CJK, kana, emoji).
        cases = (
            (["a"], ["b"]),
            (["a", "b", "c", "d"], ["d", "c", "b", "a"]),
            ([["a", "b"], ["c", "d"]], [["d", "c"], ["b", "a"]]),
            (["你好"], ["这是一个测试"]),
            (["すみません"], ["これはテストです"]),
            (["👾 🤖 🎃 😺 😸 😹"], ["😻 😼 😽 🙀 😿 😾"]),
            (["龖"], ["龘讋"]),
        )
        for left, right in cases:
            # Under unittest, subTest reports each failing pair separately
            # instead of stopping at the first failure.
            with self.subTest(left=left, right=right):
                _run_string_concat(np.array(left), np.array(right))
| 67 | |
| 68 | |
if __name__ == "__main__":
    # Allow running this file directly: unittest.main() loads and runs the
    # TestCase classes defined in this module.
    unittest.main()