microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
3e82549bcb53bc93d04c25d64f16a52d495725c1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_onnxprocess.py

86 lines · mode: code

import io
import os
import tempfile
import unittest

import numpy as np
import onnx
import torchvision

from onnxruntime_extensions import eager_op, hook_model_op, PyOp
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
9
10
11class TestTorchE2E(unittest.TestCase):
12 @classmethod
13 def setUpClass(cls):
14 cls.mobilenet = torchvision.models.mobilenet_v2(pretrained=True)
15 cls.argmax_input = None
16
17 @staticmethod
18 def on_hook(*x):
19 TestTorchE2E.argmax_input = x[0]
20 return x
21
22 def test_range(self):
23 num = 10
24
25 f = io.BytesIO()
26 with trace_for_onnx(num, names=['count']) as tc_sess:
27 num_in = tc_sess.get_inputs()[0]
28 done = torch.tensor(True)
29 st_0 = torch.tensor(0)
30 cfg = torch.control_flow()
31 for _ in cfg.loop(num_in, done, st_0):
32 iter_num, *v = _
33 cfg.flow_output(done, st_0, iter_num + 0)
34
35 *_, rout = cfg.finalize()
36 tc_sess.save_as_onnx(f, rout)
37
38 m = onnx.load_model_from_string(f.getvalue())
39 onnx.save_model(m, 'temp_range.onnx')
40 fu_m = eager_op.EagerOp.from_model(m)
41 result = fu_m(num)
42 np.testing.assert_array_equal(result, np.array(range(num)))
43
44 def test_sequence(self):
45 input_text = ['test sentence', 'sentence 2']
46 f = io.BytesIO()
47 with trace_for_onnx(input_text, names=['in_text']) as tc_sess:
48 tc_inputs = tc_sess.get_inputs()[0]
49 batchsize = tc_inputs.size()[0]
50 shape = [batchsize, 2]
51 fuse_output = torch.zeros(*shape).size()
52 tc_sess.save_as_onnx(f, fuse_output)
53
54 m = onnx.load_model_from_string(f.getvalue())
55 onnx.save_model(m, 'temp_test00.onnx')
56 fu_m = eager_op.EagerOp.from_model(m)
57 result = fu_m(input_text)
58 np.testing.assert_array_equal(result, [2, 2])
59
60 def test_imagenet_postprocess(self):
61 mb_core_path = 'temp_mobilev2.onnx'
62 mb_full_path = 'temp_mobilev2_full.onnx'
63 dummy_input = torch.randn(10, 3, 224, 224)
64 np_input = dummy_input.numpy()
65 torch.onnx.export(self.mobilenet, dummy_input, mb_core_path, opset_version=11)
66 mbnet2 = pyfunc_from_model(mb_core_path)
67
68 with trace_for_onnx(dummy_input, names=['b10_input']) as tc_sess:
69 scores = mbnet2(*tc_sess.get_inputs())
70 probabilities = torch.softmax(scores, dim=1)
71 batch_top1 = probabilities.argmax(dim=1)
72
73 np_argmax = probabilities.numpy() # for the result comparison
74 np_output = batch_top1.numpy()
75
76 tc_sess.save_as_onnx(mb_full_path, batch_top1)
77
78 hkdmdl = hook_model_op(onnx.load_model(mb_full_path), 'argmax', self.on_hook, [PyOp.dt_float])
79 mbnet2_full = eager_op.EagerOp.from_model(hkdmdl)
80 batch_top1_2 = mbnet2_full(np_input)
81 np.testing.assert_allclose(np_argmax, self.argmax_input, rtol=1e-5)
82 np.testing.assert_array_equal(batch_top1_2, np_output)
83
84
if __name__ == "__main__":
    # Executed directly as a script rather than imported by a test runner:
    # let unittest collect and run the TestCase classes defined in this module.
    unittest.main()
87