microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_string_length.py
65 lines · mode: code
| 1 | import os |
| 2 | from pathlib import Path |
| 3 | import unittest |
| 4 | import numpy as np |
| 5 | from onnx import helper, onnx_pb as onnx_proto |
| 6 | import onnxruntime as _ort |
| 7 | from onnxruntime_extensions import ( |
| 8 | onnx_op, |
| 9 | enable_custom_op, |
| 10 | PyCustomOpDef, |
| 11 | expand_onnx_inputs, |
| 12 | get_library_path as _get_library_path) |
| 13 | |
| 14 | |
def _create_test_model(intput_dims, output_dims):
    """Build a two-node ONNX model that runs a string tensor through StringLength.

    The graph is ``input_1 -> Identity -> StringLength (ai.onnx.contrib) -> customout``.

    Args:
        intput_dims: rank of the STRING graph input ``input_1``; every
            dimension is left symbolic (``None``). The misspelled parameter
            name is kept so existing keyword callers keep working.
        output_dims: rank of the INT64 graph output ``customout``; every
            dimension is left symbolic (``None``).

    Returns:
        The model built by ``helper.make_model``, importing the default ONNX
        domain at opset 12.
    """
    nodes = [
        helper.make_node('Identity', ['input_1'], ['identity1']),
        helper.make_node(
            'StringLength', ['identity1'], ['customout'], domain='ai.onnx.contrib'),
    ]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None] * intput_dims)
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.INT64, [None] * output_dims)

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    # Only the default domain is listed here; the 'ai.onnx.contrib' op is
    # presumably resolved through the custom-op library the caller
    # registers on the session — TODO confirm.
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid("", 12)])
| 30 | |
| 31 | |
def _run_string_length(input):
    """Run the StringLength custom op on ``input`` and compare it with Python ``len``.

    Args:
        input: numpy array of strings, of any rank. The name shadows the
            builtin, but callers pass it by keyword, so it is kept.

    Raises:
        AssertionError: if the op's output differs from the per-element
            ``len`` of ``input`` (same shape), or if its dtype is not int64.
    """
    model = _create_test_model(input.ndim, input.ndim)

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    # Pin the CPU provider. Since onnxruntime 1.9, a build with several
    # execution providers raises ValueError unless `providers` is given.
    sess = _ort.InferenceSession(
        model.SerializeToString(), so, providers=['CPUExecutionProvider'])
    result = sess.run(None, {'input_1': input})

    # Expected value: Python's len() of each element (a count of Unicode code
    # points), reshaped back to the input's shape.
    expected = np.array(
        [len(elem) for elem in input.flatten()], dtype=np.int64).reshape(input.shape)
    np.testing.assert_array_equal(result, [expected])
    # assert_array_equal does not compare dtypes; the graph declares INT64.
    np.testing.assert_equal(result[0].dtype, np.dtype(np.int64))
| 43 | |
| 44 | |
class TestStringLength(unittest.TestCase):
    """End-to-end checks of the StringLength custom op against Python ``len``."""

    # Sample inputs, each turned into a numpy array before being run:
    # - plain ASCII and punctuation/whitespace;
    # - a 2-D array, which checks that the output keeps the input's shape;
    # - several non-ASCII scripts, kaomoji and emoji, whose characters are
    #   multi-byte in UTF-8.
    _CASES = (
        ["a", "ab", "abc", "abcd"],
        [" \t\n", ",.", "~!@#$%^&*()_+{}|:\"<>?[]\\;',./", "1234567890"],
        [["we", "test", "whether"], ["it", "could", "output"], ["the", "same", "shape"]],
        ["d(・`ω´・d*)", "(σ`д′)σ", "(;゚∀゚)=3ハァハァ", "(´థ౪థ)σ", "||ヽ(* ̄▽ ̄*)ノミ|Ю"],
        ["你", "好", "这是一个", "测试"],
        ["すみません、これはテストです。"],
        ["Bonjour, c'est un test."],
        ["Hallo, das ist ein Test."],
        [" مرحبا هذا هو اختبار "],
        ["👾 🤖 🎃 😺 😸 😹 😻 😼 😽 🙀 😿 😾"],
        ["龖龘讋"],
    )

    def test_vector_to_(self):
        for case in self._CASES:
            # subTest reports each failing sample with its input, instead of
            # stopping at the first mismatch.
            with self.subTest(input=case):
                _run_string_length(input=np.array(case))


if __name__ == "__main__":
    unittest.main()