microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_concat.py

70 lines · mode: code

1import os
2from pathlib import Path
3import unittest
4import numpy as np
5from onnx import helper, onnx_pb as onnx_proto
6import onnxruntime as _ort
7from onnxruntime_extensions import (
8 onnx_op,
9 enable_custom_op,
10 PyCustomOpDef,
11 expand_onnx_inputs,
12 get_library_path as _get_library_path)
13
14
def _create_test_model(input_dims, output_dims):
    """Build an ONNX model that feeds two string tensors into ``StringConcat``.

    Each input is passed through an ``Identity`` node, and the two results are
    fed to the custom ``StringConcat`` op in the ``ai.onnx.contrib`` domain.

    :param input_dims: rank of both ``input_1`` and ``input_2``; every
        dimension is left symbolic (``None``).
    :param output_dims: rank of ``output``; every dimension is symbolic.
    :return: the ``ModelProto`` produced by ``helper.make_model``, importing
        the default-domain opset 12.
    """
    # Build the node list in one literal instead of chained slice assignments,
    # which silently depended on the list's length at each step.
    nodes = [
        helper.make_node('Identity', ['input_1'], ['left']),
        helper.make_node('Identity', ['input_2'], ['right']),
        helper.make_node(
            'StringConcat', ['left', 'right'], ['output'],
            domain='ai.onnx.contrib'),
    ]

    input1 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None] * input_dims)
    input2 = helper.make_tensor_value_info(
        'input_2', onnx_proto.TensorProto.STRING, [None] * input_dims)
    output = helper.make_tensor_value_info(
        'output', onnx_proto.TensorProto.STRING, [None] * output_dims)

    graph = helper.make_graph(nodes, 'test0', [input1, input2], [output])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid("", 12)])
    return model
33
34
def _run_string_concat(input1, input2):
    """Run ``StringConcat`` through onnxruntime and check it element-wise.

    :param input1: string array (or array-like) fed as ``input_1``.
    :param input2: string array (or array-like) fed as ``input_2``; must have
        the same shape as ``input1``.
    :raises ValueError: if the two inputs have different shapes.
    :raises AssertionError: if the operator output differs from the
        element-wise concatenation ``input1[i] + input2[i]``.
    """
    input1 = np.asarray(input1)
    input2 = np.asarray(input2)
    # The reference below pairs elements one-to-one. A shape mismatch would
    # otherwise show up as an IndexError or as silent truncation.
    if input1.shape != input2.shape:
        raise ValueError(
            f"input shapes differ: {input1.shape} vs {input2.shape}")

    model = _create_test_model(input1.ndim, input1.ndim)

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    # ORT >= 1.9 builds with several execution providers refuse to create a
    # session without an explicit providers list. The custom op runs on CPU.
    sess = _ort.InferenceSession(
        model.SerializeToString(), so, providers=['CPUExecutionProvider'])
    # The model declares exactly one graph output.
    (actual,) = sess.run(None, {'input_1': input1, 'input_2': input2})

    # Build the reference with plain Python string concatenation.
    expected = np.array(
        [left + right for left, right in zip(input1.ravel(), input2.ravel())]
    ).reshape(input1.shape)
    np.testing.assert_array_equal(actual, expected)
52
53
class TestStringConcat(unittest.TestCase):
    """End-to-end checks for the ``StringConcat`` custom operator."""

    # (left, right) pairs. They cover 1-D single and multi-element inputs, a
    # 2-D input, and multi-byte text (Chinese, Japanese, emoji, rare CJK).
    _CASES = [
        (["a"], ["b"]),
        (["a", "b", "c", "d"], ["d", "c", "b", "a"]),
        ([["a", "b"], ["c", "d"]], [["d", "c"], ["b", "a"]]),
        (["你好"], ["这是一个测试"]),
        (["すみません"], ["これはテストです"]),
        (["👾 🤖 🎃 😺 😸 😹"], ["😻 😼 😽 🙀 😿 😾"]),
        (["龖"], ["龘讋"]),
    ]

    def test_string_concat(self):
        for left, right in self._CASES:
            # subTest reports each failing case separately instead of
            # stopping at the first one.
            with self.subTest(left=left, right=right):
                _run_string_concat(np.array(left), np.array(right))
67
68
if __name__ == "__main__":
    # Allow running this file directly, in addition to test-runner discovery.
    unittest.main(module="__main__")