microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ba200b4a0e391b45c0df4f9b1a506f0a9f574dd4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_length.py

65 lines · mode: code

1import os
2from pathlib import Path
3import unittest
4import numpy as np
5from onnx import helper, onnx_pb as onnx_proto
6import onnxruntime as _ort
7from onnxruntime_extensions import (
8 onnx_op,
9 enable_custom_op,
10 PyCustomOpDef,
11 expand_onnx_inputs,
12 get_library_path as _get_library_path)
13
14
def _create_test_model(intput_dims, output_dims):
    """Build a two-node ONNX model: Identity followed by the custom StringLength op.

    Args:
        intput_dims: Rank of the STRING graph input ``input_1``; every
            dimension is left symbolic (``None``). The misspelled name is kept
            so any keyword callers keep working.
        output_dims: Rank of the INT64 graph output ``customout``; every
            dimension is left symbolic (``None``).

    Returns:
        An ONNX ``ModelProto`` importing opset 12 of the default domain.
    """
    nodes = [
        helper.make_node('Identity', ['input_1'], ['identity1']),
        # StringLength lives in the 'ai.onnx.contrib' domain, not the standard
        # ONNX operator set; it is expected to come from the custom-op library
        # registered on the session options by the caller.
        helper.make_node(
            'StringLength', ['identity1'], ['customout'], domain='ai.onnx.contrib'),
    ]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None] * intput_dims)
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.INT64, [None] * output_dims)

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid("", 12)])
30
31
def _run_string_length(input):
    """Run the StringLength model through onnxruntime on *input* and verify it.

    The expected result is Python's ``len()`` of every element (a count of
    Unicode code points), with the same shape as *input*.

    Args:
        input: numpy array of strings, of any rank. The name shadows the
            ``input`` builtin; it is kept because callers pass it by keyword.

    Raises:
        AssertionError: if the op's output differs from the expectation in
            values, shape, or dtype.
    """
    model = _create_test_model(input.ndim, input.ndim)

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    # ORT >= 1.9 requires an explicit providers list when the build enables
    # more than one execution provider (e.g. CUDA builds). Pin the CPU provider
    # so the test runs the same way on every build (assumes the extension's
    # custom ops are available on CPU -- confirm for new ops).
    sess = _ort.InferenceSession(
        model.SerializeToString(), so, providers=['CPUExecutionProvider'])
    # The graph has exactly one output ('customout').
    (actual,) = sess.run(None, {'input_1': input})

    # verify
    expected = np.array(
        [len(elem) for elem in input.ravel()], dtype=np.int64).reshape(input.shape)
    np.testing.assert_array_equal(actual, expected)
    # assert_array_equal ignores dtype; the graph output is declared INT64.
    np.testing.assert_equal(actual.dtype, expected.dtype)
44
class TestStringLength(unittest.TestCase):
    """End-to-end checks of the StringLength custom op on assorted string inputs.

    Each case is run through onnxruntime and compared against Python's
    ``len()`` per element. Because ``len()`` counts code points, the multi-byte
    UTF-8 cases (kaomoji, CJK, Arabic, emoji) check that the op counts
    characters rather than bytes.
    """

    # Inputs fed to the op; each entry becomes one numpy array (1-D or 2-D).
    _CASES = (
        ["a", "ab", "abc", "abcd"],
        [" \t\n", ",.", "~!@#$%^&*()_+{}|:\"<>?[]\\;',./", "1234567890"],
        [["we", "test", "whether"], ["it", "could", "output"], ["the", "same", "shape"]],
        ["d(・`ω´・d*)", "(σ`д′)σ", "(;゚∀゚)=3ハァハァ", "(´థ౪థ)σ", "||ヽ(* ̄▽ ̄*)ノミ|Ю"],
        ["你", "好", "这是一个", "测试", ],
        ["すみません、これはテストです。"],
        ["Bonjour, c'est un test."],
        ["Hallo, das ist ein Test."],
        [" مرحبا هذا هو اختبار "],
        ["👾 🤖 🎃 😺 😸 😹 😻 😼 😽 🙀 😿 😾"],
        ["龖龘讋"],
        # Boundary: the empty string must have length 0.
        ["", "a"],
    )

    def test_vector_to_(self):
        # subTest reports every failing case instead of stopping at the first.
        for case in self._CASES:
            with self.subTest(case=case):
                _run_string_length(input=np.array(case))
62
63
if __name__ == "__main__":
    # When executed as a script, discover and run this module's TestCase
    # classes; module="__main__" is unittest.main's default, spelled out.
    unittest.main(module="__main__")