microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_math_ops.py
89 lines · mode: code
| 1 | # coding: utf-8 |
| 2 | import unittest |
| 3 | import numpy as np |
| 4 | from onnx import helper, onnx_pb as onnx_proto |
| 5 | import onnxruntime as _ort |
| 6 | from onnxruntime_extensions import ( |
| 7 | onnx_op, PyCustomOpDef, |
| 8 | get_library_path as _get_library_path) |
| 9 | |
| 10 | |
def _create_test_model_segment_sum(prefix, domain='ai.onnx.contrib'):
    """Build a minimal ONNX model wrapping a ``<prefix>SegmentSum`` node.

    The graph is::

        data        -> Identity -> id1 \\
                                        <prefix>SegmentSum -> z
        segment_ids -> Identity -> id2 /

    :param prefix: prepended to ``SegmentSum`` to form the node's op type,
        e.g. ``''`` or ``'Py'``.
    :param domain: domain of the SegmentSum node; it is also the only entry
        written to the model's ``opset_imports`` (version 1).
    :return: the built ``ModelProto``.
    """
    nodes = [
        helper.make_node('Identity', ['data'], ['id1']),
        helper.make_node('Identity', ['segment_ids'], ['id2']),
        helper.make_node(
            '%sSegmentSum' % prefix, ['id1', 'id2'], ['z'], domain=domain),
    ]

    # A shape of ``[]`` would declare rank-0 (scalar) tensors, contradicting
    # the 2-D data / 1-D ids the tests feed. ``None`` leaves the rank
    # unknown (any data rank is accepted); the ids are a 1-D vector whose
    # length is dynamic.
    input0 = helper.make_tensor_value_info(
        'data', onnx_proto.TensorProto.FLOAT, None)
    input1 = helper.make_tensor_value_info(
        'segment_ids', onnx_proto.TensorProto.INT64, [None])
    output0 = helper.make_tensor_value_info(
        'z', onnx_proto.TensorProto.FLOAT, None)

    graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
    # NOTE(review): only ``domain`` is listed in opset_imports; no version is
    # declared for the default ONNX domain used by the Identity nodes, so the
    # runtime is relied on to supply one -- confirm this is intended.
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model
| 30 | |
| 31 | |
class TestMathOpString(unittest.TestCase):
    """Tests for the ``SegmentSum`` custom op, both the ``SegmentSum`` node
    and the Python-implemented ``PySegmentSum`` node.

    NOTE(review): despite the class name these tests exercise a math op, not
    a string op; the name is kept so existing test selection keeps working.
    """

    @classmethod
    def setUpClass(cls):
        """Register the Python ``PySegmentSum`` custom op before the tests."""
        super().setUpClass()

        @onnx_op(op_type="PySegmentSum",
                 inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_float])
        def segment_sum(data, segment_ids):
            """Sum the rows of ``data`` that share a segment id.

            ``segment_ids`` is assumed sorted, so its last value is the
            largest id and the result has ``segment_ids[-1] + 1`` rows; row
            ``i`` of ``data`` is added into result row ``segment_ids[i]``.
            Ids that never occur leave an all-zero row.

            :raises ValueError: if ``data`` and ``segment_ids`` disagree in
                their leading dimension.
            """
            if len(segment_ids) != len(data):
                # Previously ``zip`` silently dropped the unmatched tail.
                raise ValueError(
                    "segment_ids has %d entries but data has %d rows"
                    % (len(segment_ids), len(data)))
            if len(segment_ids) == 0:
                # No segments at all: avoid indexing ``segment_ids[-1]``.
                return np.zeros((0, ) + data.shape[1:], dtype=data.dtype)
            nb_seg = int(segment_ids[-1]) + 1
            res = np.zeros((nb_seg, ) + data.shape[1:], dtype=data.dtype)
            # Unbuffered scatter-add: repeated ids accumulate, whereas
            # ``res[segment_ids] += data`` would keep only one contribution.
            np.add.at(res, segment_ids, data)
            return res

    def _run_segment_sum(self, prefix):
        """Run a ``<prefix>SegmentSum`` model on a fixed input and check it.

        :param prefix: op-type prefix passed to the model builder.
        :return: ``(data, segment_ids, outputs)`` so callers can add further
            comparisons against the same inputs.
        """
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_segment_sum(prefix)
        self.assertIn('op_type: "%sSegmentSum"' % prefix, str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                        dtype=np.float32)
        segment_ids = np.array([0, 0, 1], dtype=np.int64)
        exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)
        txout = sess.run(None, {'data': data, 'segment_ids': segment_ids})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())
        return data, segment_ids, txout

    def test_segment_sum_cc(self):
        """The ``SegmentSum`` node produces the expected sums."""
        self._run_segment_sum('')

    def test_segment_sum_python(self):
        """The Python ``PySegmentSum`` op produces the expected sums and,
        when TensorFlow is installed, matches ``tf.raw_ops.SegmentSum``."""
        data, segment_ids, txout = self._run_segment_sum('Py')

        try:
            from tensorflow.raw_ops import SegmentSum
        except ImportError:
            # TensorFlow is optional; the cross-check is best-effort.
            pass
        else:
            tfres = SegmentSum(data=data, segment_ids=segment_ids)
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
| 86 | |
| 87 | |
if __name__ == "__main__":
    # Executed only when this file is run directly (not when imported by a
    # test runner): discover and run the TestCase classes in this module.
    unittest.main(module="__main__")
| 90 | |