microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b3a300d7bf6f3e99fb907e682f7246ed5ed5805e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_torch_ops.py

65 lines · mode: code

1import unittest
2from onnx import load
3import torch
4import onnxruntime as _ort
5import io
6import numpy
7
8from torch.onnx import register_custom_op_symbolic
9from onnxruntime_customops import (
10 onnx_op,
11 get_library_path as _get_library_path)
12
13
def my_inverse(g, self):
    """Export symbolic that lowers ``torch.inverse`` to the contrib Inverse op.

    ``g`` is the graph object handed in by the ONNX exporter and ``self`` is
    the tensor operand being inverted. Returns whatever ``g.op`` builds for
    the ``ai.onnx.contrib::Inverse`` node.
    """
    contrib_op_name = "ai.onnx.contrib::Inverse"
    return g.op(contrib_op_name, self)
16
17
class CustomInverse(torch.nn.Module):
    """Toy module returning a square matrix plus its inverse.

    ``torch.inverse`` is called on purpose so that the traced graph contains
    ``aten::inverse``, the op the test remaps with
    ``register_custom_op_symbolic('::inverse', ...)``.
    """

    def forward(self, x):
        inverted = torch.inverse(x)
        return inverted + x
21
22
class TestPyTorchCustomOp(unittest.TestCase):
    """Round-trip a PyTorch model through ONNX using a Python custom op.

    ``CustomInverse`` calls ``torch.inverse``; ``my_inverse`` maps that to
    ``ai.onnx.contrib::Inverse`` at export time, and the ``onnx_op``-decorated
    function in ``setUpClass`` supplies the implementation that ONNX Runtime
    runs through the custom-ops library.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # The decorator is what matters here: it registers the Python
        # implementation of the contrib "Inverse" op. The local name
        # ``inverse`` is not used afterwards.
        @onnx_op(op_type="Inverse")
        def inverse(x):
            # the user custom op implementation here:
            return numpy.linalg.inv(x)

    def test_custom_pythonop_pytorch(self):

        # register_custom_op_symbolic(
        #     '<namespace>::inverse', my_inverse, <opset_version>)
        register_custom_op_symbolic('::inverse', my_inverse, 1)
        # The symbolic registry is process-global, so undo the registration
        # when the test finishes to avoid leaking it into other tests.
        # ``unregister_custom_op_symbolic`` is absent from older torch
        # releases, hence the defensive lookup.
        unregister = getattr(torch.onnx, 'unregister_custom_op_symbolic', None)
        if unregister is not None:
            self.addCleanup(unregister, '::inverse', 1)

        # Seeded and diagonally shifted so the input is deterministic and well
        # conditioned: an unseeded torch.randn matrix can be near-singular,
        # which makes the torch and numpy float32 inverses disagree beyond the
        # tolerances below and the test flaky.
        generator = torch.Generator().manual_seed(0)
        x = torch.randn(3, 3, generator=generator) + 3 * torch.eye(3)

        model = CustomInverse()

        # Export model to ONNX
        f = io.BytesIO()
        torch.onnx.export(model, (x,), f)
        onnx_model = load(io.BytesIO(f.getvalue()))
        self.assertIn('domain: "ai.onnx.contrib"', str(onnx_model))

        pt_outputs = model(x)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())

        # Run the exported model with ONNX Runtime. Name the provider
        # explicitly: onnxruntime builds with several execution providers
        # (e.g. GPU builds, 1.9+) refuse to create a session without one.
        # NOTE(review): assumes the Python custom op runs on CPU - confirm.
        ort_sess = _ort.InferenceSession(
            f.getvalue(), so, providers=['CPUExecutionProvider'])
        ort_inputs = {
            meta.name: tensor.cpu().numpy()
            for meta, tensor in zip(ort_sess.get_inputs(), (x,))
        }
        ort_outputs = ort_sess.run(None, ort_inputs)

        # Validate PyTorch and ONNX Runtime results
        numpy.testing.assert_allclose(pt_outputs.cpu().numpy(),
                                      ort_outputs[0], rtol=1e-03, atol=1e-05)
62
63
# Allow running this test module directly as a script, in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
66