microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_math_ops.py

89 lines · mode: code

1# coding: utf-8
2import unittest
3import numpy as np
4from onnx import helper, onnx_pb as onnx_proto
5import onnxruntime as _ort
6from onnxruntime_extensions import (
7 onnx_op, PyCustomOpDef,
8 get_library_path as _get_library_path)
9
10
def _create_test_model_segment_sum(prefix, domain='ai.onnx.contrib'):
    """Build a model routing ``data``/``segment_ids`` into a SegmentSum node.

    Both graph inputs pass through an ``Identity`` node before reaching a
    ``<prefix>SegmentSum`` node in ``domain``.

    :param prefix: prepended to the op type, e.g. ``''`` for ``SegmentSum``
        or ``'Py'`` for ``PySegmentSum``.
    :param domain: domain of the SegmentSum node and of its opset import.
    :return: an ONNX ``ModelProto`` with inputs ``data`` (float) and
        ``segment_ids`` (int64) and output ``z`` (float).
    """
    nodes = [
        helper.make_node('Identity', ['data'], ['id1']),
        helper.make_node('Identity', ['segment_ids'], ['id2']),
        helper.make_node(
            '%sSegmentSum' % prefix, ['id1', 'id2'], ['z'], domain=domain),
    ]

    # Leave the shapes unspecified (unknown rank). A shape of ``[]`` would
    # declare rank-0 (scalar) tensors, while the tests feed a 2-D ``data``
    # matrix and a 1-D ``segment_ids`` vector.
    input0 = helper.make_tensor_value_info(
        'data', onnx_proto.TensorProto.FLOAT, None)
    input1 = helper.make_tensor_value_info(
        'segment_ids', onnx_proto.TensorProto.INT64, None)
    output0 = helper.make_tensor_value_info(
        'z', onnx_proto.TensorProto.FLOAT, None)

    graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
    # The Identity nodes belong to the default ONNX domain (''), so import it
    # explicitly next to the custom domain instead of relying on the runtime
    # to fill in a missing opset.
    model = helper.make_model(
        graph,
        opset_imports=[helper.make_operatorsetid('', 12),
                       helper.make_operatorsetid(domain, 1)])
    return model
30
31
class TestMathOpString(unittest.TestCase):
    """Tests for the SegmentSum custom operator.

    Covers both the ``SegmentSum`` kernel loaded from the extensions library
    and a Python ``PySegmentSum`` implementation registered via ``onnx_op``.
    """
    # NOTE(review): the class name mentions "String" although it exercises a
    # math op; it is kept unchanged so existing test selectors keep working.

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        @onnx_op(op_type="PySegmentSum",
                 inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_float])
        def segment_sum(data, segment_ids):
            # segment_ids is assumed sorted, so its last entry is the largest
            # segment id. An empty id vector yields zero segments instead of
            # an IndexError.
            nb_seg = int(segment_ids[-1]) + 1 if len(segment_ids) else 0
            res = np.zeros((nb_seg, ) + data.shape[1:], dtype=data.dtype)
            # Unbuffered scatter-add: rows sharing a segment id accumulate
            # into the same output row.
            np.add.at(res, segment_ids, data)
            return res

    def _run_segment_sum(self, prefix):
        """Run the ``<prefix>SegmentSum`` model on a fixed input and check it.

        :param prefix: op-type prefix forwarded to
            ``_create_test_model_segment_sum``.
        :return: ``(data, segment_ids, output)`` so callers can run extra
            comparisons against the same inputs.
        """
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_segment_sum(prefix)
        self.assertIn('op_type: "%sSegmentSum"' % prefix, str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                        dtype=np.float32)
        segment_ids = np.array([0, 0, 1], dtype=np.int64)
        # Rows 0 and 1 share segment 0; row 2 alone forms segment 1.
        exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)
        txout = sess.run(None, {'data': data, 'segment_ids': segment_ids})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())
        return data, segment_ids, txout[0]

    def test_segment_sum_cc(self):
        self._run_segment_sum('')

    def test_segment_sum_python(self):
        data, segment_ids, result = self._run_segment_sum('Py')

        # Optional cross-check against TensorFlow's SegmentSum, performed
        # only when TensorFlow is installed.
        try:
            from tensorflow.raw_ops import SegmentSum
        except ImportError:
            return
        tfres = SegmentSum(data=data, segment_ids=segment_ids).numpy()
        self.assertEqual(tfres.shape, result.shape)
        self.assertEqual(tfres.tolist(), result.tolist())
86
87
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it.
    unittest.main()
90