microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
bfbfa5a3044ec8d1312f3782c78ea3b9246bf667

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_math_ops.py

88 lines · mode: code

1# coding: utf-8
2import unittest
3import numpy as np
4from onnx import helper, onnx_pb as onnx_proto
5import onnxruntime as _ort
6from onnxruntime_extensions import (
7 onnx_op, PyCustomOpDef, make_onnx_model,
8 get_library_path as _get_library_path)
9
10
def _create_test_model_segment_sum(prefix, domain='ai.onnx.contrib'):
    """Build a small ONNX model that applies the ``<prefix>SegmentSum`` op.

    The graph routes each input through an ``Identity`` node
    (``data -> id1``, ``segment_ids -> id2``) and then applies the custom
    ``<prefix>SegmentSum(id1, id2) -> z`` node in *domain*.

    :param prefix: string prepended to ``SegmentSum`` to form the op type,
        e.g. ``'Py'`` selects ``PySegmentSum``.
    :param domain: operator domain assigned to the custom node.
    :return: the model returned by ``make_onnx_model`` for the graph.
    """
    nodes = [
        helper.make_node('Identity', ['data'], ['id1']),
        helper.make_node('Identity', ['segment_ids'], ['id2']),
        helper.make_node(
            '%sSegmentSum' % prefix, ['id1', 'id2'], ['z'], domain=domain),
    ]

    # shape=None means "unknown rank". shape=[] would declare rank-0
    # (scalar) tensors, which contradicts the multi-dimensional data and
    # the 1-D segment ids this op actually receives.
    input0 = helper.make_tensor_value_info(
        'data', onnx_proto.TensorProto.FLOAT, None)
    input1 = helper.make_tensor_value_info(
        'segment_ids', onnx_proto.TensorProto.INT64, None)
    output0 = helper.make_tensor_value_info(
        'z', onnx_proto.TensorProto.FLOAT, None)

    graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
    return make_onnx_model(graph)
29
30
class TestMathOpString(unittest.TestCase):
    """Run the SegmentSum custom op through onnxruntime.

    ``test_segment_sum_cc`` runs ``SegmentSum`` from the extensions library.
    ``test_segment_sum_python`` runs ``PySegmentSum``, the Python kernel
    registered in ``setUpClass``. When TensorFlow is installed, the
    Python-op result is also cross-checked against it.
    """

    # Shared fixture. Rows 0 and 1 have segment id 0 and row 2 has id 1,
    # so the expected output is (row0 + row1, row2).
    _DATA = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                     dtype=np.float32)
    _SEGMENT_IDS = np.array([0, 0, 1], dtype=np.int64)
    _EXPECTED = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # The decorator registers this kernel under op type "PySegmentSum";
        # the local name is not referenced again.
        @onnx_op(op_type="PySegmentSum",
                 inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_float])
        def segment_sum(data, segment_ids):
            # Output row k is the sum of the rows of `data` whose id is k.
            # Rows are paired with ids positionally, so both must have the
            # same length (np.add.at raises otherwise).
            if segment_ids.size == 0:
                # No ids means no segments. Return zero rows instead of
                # failing on segment_ids[-1].
                return np.zeros((0, ) + data.shape[1:], dtype=data.dtype)
            # Use max() rather than the last id so the row count is right
            # even when the ids are not sorted.
            nb_seg = int(np.max(segment_ids)) + 1
            res = np.zeros((nb_seg, ) + data.shape[1:], dtype=data.dtype)
            # np.add.at is unbuffered, so repeated ids accumulate. With a
            # fancy-indexed `res[ids] += data`, the last write would win.
            np.add.at(res, segment_ids, data)
            return res

    def _run_segment_sum(self, prefix):
        """Run the ``<prefix>SegmentSum`` model on the fixture and check it.

        :param prefix: op-type prefix forwarded to
            ``_create_test_model_segment_sum`` (``''`` or ``'Py'`` here).
        :return: the op's output array, for any further comparisons.
        """
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_segment_sum(prefix)
        self.assertIn('op_type: "%sSegmentSum"' % prefix, str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        txout = sess.run(
            None, {'data': self._DATA, 'segment_ids': self._SEGMENT_IDS})
        self.assertEqual(self._EXPECTED.shape, txout[0].shape)
        self.assertEqual(self._EXPECTED.tolist(), txout[0].tolist())
        return txout[0]

    def test_segment_sum_cc(self):
        self._run_segment_sum('')

    def test_segment_sum_python(self):
        out = self._run_segment_sum('Py')

        # TensorFlow is an optional cross-check. The ORT result has already
        # been verified above, so simply stop when TensorFlow is missing.
        try:
            from tensorflow.raw_ops import SegmentSum
        except ImportError:
            return
        tfres = SegmentSum(data=self._DATA, segment_ids=self._SEGMENT_IDS)
        self.assertEqual(tfres.shape, out.shape)
        self.assertEqual(tfres.numpy().tolist(), out.tolist())
85
86
if __name__ == "__main__":
    # Executing the file directly runs every unittest case it defines.
    unittest.main()
89