microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b5ba84a185c5c85c436e744c40ac9a8cbd9d3f1f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_torch_ops.py

92lines · modecode

1import io
2import onnx
3import numpy
4import unittest
5import platform
6import torch, torchvision
7import onnxruntime as _ort
8
9from onnx import load
10from torch.onnx import register_custom_op_symbolic
11from onnxruntime_customops import (
12 onnx_op,
13 hook_model_op,
14 get_library_path as _get_library_path)
15
16
def my_inverse(g, self):
    """Symbolic function that emits the custom ``Inverse`` node.

    Args:
        g: Object exposing an ``op`` method; the node is created through it.
        self: Value forwarded as the single input of the node.

    Returns:
        Whatever ``g.op`` returns for ``"ai.onnx.contrib::Inverse"``.
    """
    inverse_node = g.op("ai.onnx.contrib::Inverse", self)
    return inverse_node
19
20
class CustomInverse(torch.nn.Module):
    """Module whose forward pass adds a matrix to its own inverse."""

    def forward(self, x):
        """Return ``torch.inverse(x) + x``."""
        inverted = torch.inverse(x)
        return inverted + x
24
25
class TestPyTorchCustomOp(unittest.TestCase):
    """Exercise Python custom ops on ONNX models exported from PyTorch."""

    # Set to True by `on_hook` so a test can tell that the hook was invoked.
    _hooked = False

    @classmethod
    def setUpClass(cls):
        """Decorate a NumPy matrix inverse with ``onnx_op(op_type="Inverse")``."""

        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

    def test_custom_pythonop_pytorch(self):
        """Export ``CustomInverse`` and compare ONNX Runtime with PyTorch."""
        # register_custom_op_symbolic(
        #     '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)

        x = torch.randn(3, 3)

        # Export model to ONNX; grab the bytes once and release the buffer.
        with io.BytesIO() as f:
            torch.onnx.export(CustomInverse(), (x,), f)
            model_bytes = f.getvalue()
        onnx_model = load(io.BytesIO(model_bytes))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        model = CustomInverse()
        pt_outputs = model(x)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        # Run the exported model with ONNX Runtime
        ort_sess = _ort.InferenceSession(model_bytes, so)
        # Pair each session input with its tensor positionally; avoids
        # shadowing the builtin `input` and repeated get_inputs() calls.
        ort_inputs = {
            meta.name: tensor.cpu().numpy()
            for meta, tensor in zip(ort_sess.get_inputs(), (x,))
        }
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs.cpu().numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)

    @staticmethod
    def on_hook(*x):
        """Record that the hook ran and return the received arguments."""
        TestPyTorchCustomOp._hooked = True
        return x

    @unittest.skipIf(platform.system() == 'Darwin', "pytorch.onnx crashed for this case!")
    def test_pyop_hooking(self) -> None:
        """Hook one node of an exported MobileNetV2 and check the hook fires."""
        # NOTE(review): newer torchvision releases appear to prefer
        # `weights=None` over `pretrained=False` -- confirm before changing.
        torch_model = torchvision.models.mobilenet_v2(pretrained=False)
        x = torch.rand(1, 3, 224, 224)
        with io.BytesIO() as f:
            torch.onnx.export(torch_model, (x, ), f)
            onnx_model = onnx.load_model_from_string(f.getvalue())

        # The node at index 5 is the one being hooked.
        # NOTE(review): the index looks arbitrary -- confirm it is not
        # tied to a specific exporter version.
        hooked_node_name = onnx_model.graph.node[5].name
        hkd_model = hook_model_op(
            onnx_model, hooked_node_name, TestPyTorchCustomOp.on_hook)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        sess = _ort.InferenceSession(hkd_model.SerializeToString(), so)
        TestPyTorchCustomOp._hooked = False
        # Ask the session for its input name rather than hard-coding the
        # exporter-generated 'input.1', matching the other test.
        input_name = sess.get_inputs()[0].name
        sess.run(None, {input_name: x.numpy()})
        self.assertTrue(TestPyTorchCustomOp._hooked)
89
90
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
93