microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_onnxprocess.py
86 lines · mode: code
import io
import os
import shutil
import tempfile
import unittest

import onnx
import torchvision
import numpy as np

from onnxruntime_extensions import eager_op, hook_model_op, PyOp
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
| 9 | |
| 10 | |
class TestTorchE2E(unittest.TestCase):
    """End-to-end tests that trace torch-style code into ONNX models.

    Each test traces a small computation with ``trace_for_onnx``, serializes
    the resulting ONNX model, runs it through ``eager_op.EagerOp`` and compares
    the output with the expected values.
    """

    @classmethod
    def setUpClass(cls):
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is for the versions already in use.
        cls.mobilenet = torchvision.models.mobilenet_v2(pretrained=True)
        # Input captured from the hooked 'argmax' node; written by on_hook().
        cls.argmax_input = None
        # Scratch directory for intermediate .onnx files, so the tests do not
        # leave artifacts behind in the current working directory. Created
        # after the model load so a failed download leaks nothing.
        cls._tmp_dir = tempfile.mkdtemp(prefix='test_onnxprocess_')

    @classmethod
    def tearDownClass(cls):
        # Best-effort cleanup: a leftover temp file must not turn a passing
        # run into an error.
        shutil.rmtree(cls._tmp_dir, ignore_errors=True)

    @classmethod
    def _tmp_path(cls, name):
        """Return the path of file *name* inside the class scratch directory."""
        return os.path.join(cls._tmp_dir, name)

    @staticmethod
    def on_hook(*x):
        """Record the first input of the hooked node and pass all inputs through.

        Registered through ``hook_model_op`` on the 'argmax' node in
        ``test_imagenet_postprocess``.
        """
        TestTorchE2E.argmax_input = x[0]
        return x

    def test_range(self):
        """A traced loop that emits its iteration counter reproduces range(num)."""
        num = 10

        f = io.BytesIO()
        with trace_for_onnx(num, names=['count']) as tc_sess:
            num_in = tc_sess.get_inputs()[0]
            done = torch.tensor(True)
            st_0 = torch.tensor(0)
            cfg = torch.control_flow()
            for iter_num, *_ in cfg.loop(num_in, done, st_0):
                # NOTE(review): '+ 0' presumably materialises a fresh tensor
                # instead of emitting the loop counter itself; confirm before
                # simplifying it away.
                cfg.flow_output(done, st_0, iter_num + 0)

            *_, rout = cfg.finalize()
            tc_sess.save_as_onnx(f, rout)

        m = onnx.load_model_from_string(f.getvalue())
        onnx.save_model(m, self._tmp_path('temp_range.onnx'))
        fu_m = eager_op.EagerOp.from_model(m)
        result = fu_m(num)
        np.testing.assert_array_equal(result, np.array(range(num)))

    def test_sequence(self):
        """The traced batch size of a string-list input drives a zeros() shape."""
        input_text = ['test sentence', 'sentence 2']
        f = io.BytesIO()
        with trace_for_onnx(input_text, names=['in_text']) as tc_sess:
            tc_inputs = tc_sess.get_inputs()[0]
            batchsize = tc_inputs.size()[0]
            shape = [batchsize, 2]
            fuse_output = torch.zeros(*shape).size()
            tc_sess.save_as_onnx(f, fuse_output)

        m = onnx.load_model_from_string(f.getvalue())
        onnx.save_model(m, self._tmp_path('temp_test00.onnx'))
        fu_m = eager_op.EagerOp.from_model(m)
        result = fu_m(input_text)
        # Two input strings -> batch size 2, and the second dimension is fixed at 2.
        np.testing.assert_array_equal(result, [2, 2])

    def test_imagenet_postprocess(self):
        """MobileNetV2 + softmax + argmax traced into one model matches the traced values."""
        mb_core_path = self._tmp_path('temp_mobilev2.onnx')
        mb_full_path = self._tmp_path('temp_mobilev2_full.onnx')
        dummy_input = torch.randn(10, 3, 224, 224)
        np_input = dummy_input.numpy()
        torch.onnx.export(self.mobilenet, dummy_input, mb_core_path, opset_version=11)
        mbnet2 = pyfunc_from_model(mb_core_path)

        with trace_for_onnx(dummy_input, names=['b10_input']) as tc_sess:
            scores = mbnet2(*tc_sess.get_inputs())
            probabilities = torch.softmax(scores, dim=1)
            batch_top1 = probabilities.argmax(dim=1)

            # NumPy values of the traced tensors, used as references below.
            np_probabilities = probabilities.numpy()
            np_output = batch_top1.numpy()

            tc_sess.save_as_onnx(mb_full_path, batch_top1)

        # Clear any previously captured value, so a hook that never fires is
        # reported explicitly instead of being compared against stale data.
        TestTorchE2E.argmax_input = None
        hkdmdl = hook_model_op(onnx.load_model(mb_full_path), 'argmax', self.on_hook, [PyOp.dt_float])
        mbnet2_full = eager_op.EagerOp.from_model(hkdmdl)
        batch_top1_2 = mbnet2_full(np_input)
        self.assertIsNotNone(self.argmax_input, "the 'argmax' hook was not invoked")
        # The argmax node's input should be the softmax probabilities.
        np.testing.assert_allclose(np_probabilities, self.argmax_input, rtol=1e-5)
        np.testing.assert_array_equal(batch_top1_2, np_output)
| 83 | |
| 84 | |
if __name__ == '__main__':
    # Running this file directly executes the TestCase classes defined above.
    unittest.main()
| 87 | |