microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_torch_ops.py
92 lines · mode: code
| 1 | import io |
| 2 | import onnx |
| 3 | import numpy |
| 4 | import unittest |
| 5 | import platform |
| 6 | import torch, torchvision |
| 7 | import onnxruntime as _ort |
| 8 | |
| 9 | from onnx import load |
| 10 | from torch.onnx import register_custom_op_symbolic |
| 11 | from onnxruntime_customops import ( |
| 12 | onnx_op, |
| 13 | hook_model_op, |
| 14 | get_library_path as _get_library_path) |
| 15 | |
| 16 | |
def my_inverse(g, self):
    """Symbolic function that lowers ``torch.inverse`` to the contrib ``Inverse`` op.

    Registered with ``register_custom_op_symbolic`` so the exporter calls it
    whenever it meets ``inverse`` in the traced graph.

    Args:
        g: graph context supplied by the exporter; used only for ``g.op``.
        self: the value passed to ``torch.inverse``.

    Returns:
        Whatever ``g.op`` returns for the emitted ``ai.onnx.contrib::Inverse`` node.
    """
    contrib_op = "ai.onnx.contrib::Inverse"
    return g.op(contrib_op, self)
| 19 | |
| 20 | |
class CustomInverse(torch.nn.Module):
    """Parameter-free module computing ``inverse(x) + x``.

    Used to exercise exporting ``torch.inverse`` through a custom symbolic.
    """

    def forward(self, x):
        inverted = torch.inverse(x)
        return inverted + x
| 24 | |
| 25 | |
class TestPyTorchCustomOp(unittest.TestCase):
    """Tests exporting PyTorch models that rely on Python-implemented custom ops.

    ``_hooked`` is a class-level flag set by :meth:`on_hook`, so a test can
    observe that the hook was invoked while ONNX Runtime ran the model.
    """

    _hooked = False

    @classmethod
    def setUpClass(cls):
        """Register the Python implementation of the contrib ``Inverse`` op."""
        super().setUpClass()

        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

    @staticmethod
    def _make_session(model_bytes):
        """Create an ORT session for *model_bytes* with the custom-ops library loaded."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return _ort.InferenceSession(model_bytes, so)

    def test_custom_pythonop_pytorch(self):
        """Export ``torch.inverse`` as a contrib op and compare ORT with PyTorch."""
        # register_custom_op_symbolic(
        #     '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)

        # Shift the random matrix away from singularity. An ill-conditioned
        # draw would let the float32 torch and numpy inverses diverge beyond
        # the tolerances below, which makes the test flaky.
        x = torch.randn(3, 3) + 3 * torch.eye(3)
        model = CustomInverse()

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f)
        model_bytes = f.getvalue()
        onnx_model = load(io.BytesIO(model_bytes))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        pt_outputs = model(x)

        # Run the exported model with ONNX Runtime
        ort_sess = self._make_session(model_bytes)
        ort_inputs = {
            arg.name: tensor.cpu().numpy()
            for arg, tensor in zip(ort_sess.get_inputs(), (x,))
        }
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs.cpu().numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)

    @staticmethod
    def on_hook(*x):
        """Record that the hook ran and return its inputs unchanged."""
        TestPyTorchCustomOp._hooked = True
        return x

    @unittest.skipIf(platform.system() == 'Darwin', "pytorch.onnx crashed for this case!")
    def test_pyop_hooking(self):  # type: () -> None
        """Hook a node of an exported MobileNetV2 and check the Python hook runs."""
        torch_model = torchvision.models.mobilenet_v2(pretrained=False)
        x = torch.rand(1, 3, 224, 224)
        with io.BytesIO() as f:
            torch.onnx.export(torch_model, (x, ), f)
            # Read the bytes before the buffer is closed at the end of the block.
            onnx_model = onnx.load_model_from_string(f.getvalue())

        # Attach on_hook to graph node #5 via hook_model_op.
        hkd_model = hook_model_op(
            onnx_model, onnx_model.graph.node[5].name, TestPyTorchCustomOp.on_hook)

        sess = self._make_session(hkd_model.SerializeToString())
        TestPyTorchCustomOp._hooked = False
        # Query the input name instead of hard-coding 'input.1'; the name the
        # exporter assigns varies across torch versions.
        input_name = sess.get_inputs()[0].name
        sess.run(None, {input_name: x.numpy()})
        self.assertTrue(TestPyTorchCustomOp._hooked)
| 89 | |
| 90 | |
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main()
| 93 | |