microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
test/test_torch_ops.py
65 lines · mode: code
| 1 | import unittest |
| 2 | from onnx import load |
| 3 | import torch |
| 4 | import onnxruntime as _ort |
| 5 | import io |
| 6 | import numpy |
| 7 | |
| 8 | from torch.onnx import register_custom_op_symbolic |
| 9 | from onnxruntime_customops import ( |
| 10 | onnx_op, |
| 11 | get_library_path as _get_library_path) |
| 12 | |
| 13 | |
def my_inverse(g, self):
    """Symbolic function that lowers ``inverse`` to the contrib ``Inverse`` op.

    Registered with ``register_custom_op_symbolic`` so the exported graph
    contains an ``Inverse`` node in the ``ai.onnx.contrib`` domain.

    :param g: graph context supplied by the exporter; its ``op`` method
        creates the new node.
    :param self: the operand being inverted (forwarded unchanged as the
        node's single input).
    :return: the value returned by ``g.op`` for the created node.
    """
    contrib_op_name = "ai.onnx.contrib::Inverse"
    return g.op(contrib_op_name, self)
| 16 | |
| 17 | |
class CustomInverse(torch.nn.Module):
    """Toy module returning ``inverse(x) + x``; used to exercise custom-op export."""

    def forward(self, x):
        # Call torch.inverse specifically: the test registers its custom
        # symbolic under '::inverse', and keeping the operand order
        # (inverse first, then x) keeps the exported Add node identical.
        inverted = torch.inverse(x)
        return inverted + x
| 21 | |
| 22 | |
class TestPyTorchCustomOp(unittest.TestCase):
    """Export a PyTorch model with a custom op and check ONNX Runtime matches PyTorch."""

    @classmethod
    def setUpClass(cls):
        """Define the Python implementation of the contrib ``Inverse`` op."""
        super().setUpClass()

        # NOTE: the local name is never referenced again; presumably
        # @onnx_op registers the function for the "Inverse" op as a side
        # effect of decoration.
        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

    def test_custom_pythonop_pytorch(self):
        """Export ``CustomInverse`` via the custom symbolic and compare ORT with PyTorch."""
        # register_custom_op_symbolic(
        #     '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)

        # An unseeded torch.randn matrix may be (nearly) singular, which makes
        # the inverse comparison below flaky. Seed the generator and build a
        # symmetric positive-definite matrix (A @ A.T is PSD; adding 3*I puts
        # every eigenvalue at >= 3) so the input is deterministic and safely
        # invertible.
        torch.manual_seed(0)
        a = torch.randn(3, 3)
        x = a @ a.T + 3 * torch.eye(3)

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(CustomInverse(), (x,), f)
        model_bytes = f.getvalue()
        onnx_model = load(io.BytesIO(model_bytes))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        model = CustomInverse()
        pt_outputs = model(x)

        # Load the extensions' custom-op library so ORT can resolve the
        # ai.onnx.contrib node.
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        # Run the exported model with ONNX Runtime
        ort_sess = _ort.InferenceSession(model_bytes, so)
        ort_inputs = {
            node_arg.name: tensor.cpu().numpy()
            for node_arg, tensor in zip(ort_sess.get_inputs(), (x,))
        }
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs.cpu().numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)
| 62 | |
| 63 | |
if __name__ == "__main__":
    # Running this file directly discovers and runs the tests defined above.
    unittest.main(module="__main__")
| 66 | |