microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
1f0c76cefaa9359b9fc74ff4b9ed93214486205a

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/cmd.py

57lines · modecode

1import os
2import argparse
3import onnx
4import numpy
5
6from onnx import onnx_pb, save_tensor, numpy_helper
7from ._ortapi2 import OrtPyFunction
8
9
class ORTExtCommands:
    """Implementations of the ORT Extensions command line commands.

    The constructor only stores its two settings:
      * ``_model``: path of the ONNX model that ``run`` loads.
      * ``_testdata_dir``: directory that receives a dump of the inputs and
        the model after ``run``; a falsy value disables the dump.
    """

    def __init__(self, model='model.onnx', testdata_dir=None) -> None:
        self._model = model
        self._testdata_dir = testdata_dir

    @staticmethod
    def _numpy_dtype_for(elem_type):
        """Return the numpy dtype for a numeric ONNX element type, else None."""
        # Built on demand rather than at class-definition time so that merely
        # defining the class does not evaluate anything from onnx.
        tensor_proto = onnx_pb.TensorProto
        numeric_dtypes = {
            tensor_proto.FLOAT: numpy.float32,
            tensor_proto.DOUBLE: numpy.float64,
            tensor_proto.FLOAT16: numpy.float16,
            tensor_proto.INT8: numpy.int8,
            tensor_proto.INT16: numpy.int16,
            tensor_proto.INT32: numpy.int32,
            tensor_proto.INT64: numpy.int64,
            tensor_proto.UINT8: numpy.uint8,
            tensor_proto.UINT16: numpy.uint16,
            tensor_proto.UINT32: numpy.uint32,
            tensor_proto.UINT64: numpy.uint64,
        }
        # STRING, BOOL and non-tensor inputs are deliberately absent: their
        # arguments are handed to the model exactly as numpy.asarray made them.
        return numeric_dtypes.get(elem_type)

    def run(self, *args):
        """
        Run an onnx model with the arguments as its inputs

        Each positional argument is converted with ``numpy.asarray`` and, when
        the matching model input declares a numeric element type, cast to that
        type. Command line arguments arrive as strings, so without the cast
        only string inputs could be fed; originally only FLOAT inputs were
        cast, which left every other numeric input receiving a string array.

        The model output is printed. When a test data directory was
        configured, every input is also written to
        ``<testdata_dir>/test_data_set_0/input_<i>.pb`` and the model to
        ``<testdata_dir>/model.onnx``.

        Raises:
            IndexError: if an input that needs a numeric cast has no
                corresponding positional argument.
        """
        op_func = OrtPyFunction.from_model(self._model)
        np_args = [numpy.asarray(_x) for _x in args]
        for _idx, _sch in enumerate(op_func.inputs):
            np_dtype = self._numpy_dtype_for(_sch.type.tensor_type.elem_type)
            if np_dtype is not None:
                np_args[_idx] = np_args[_idx].astype(np_dtype)

        print(op_func(*np_args))
        if self._testdata_dir:
            # Accept "~"-prefixed paths typed on the command line.
            testdir = os.path.expanduser(self._testdata_dir)
            target_dir = os.path.join(testdir, 'test_data_set_0')
            os.makedirs(target_dir, exist_ok=True)
            for _idx, _x in enumerate(np_args):
                fn = os.path.join(target_dir, "input_{}.pb".format(_idx))
                # Each tensor is stored under the name of the model input it
                # was fed to.
                save_tensor(numpy_helper.from_array(_x, op_func.inputs[_idx].name), fn)
            onnx.save_model(op_func.onnx_model, os.path.join(testdir, 'model.onnx'))

    def selfcheck(self, *args):
        """Print a fixed status line; any arguments are accepted and ignored."""
        print("The extensions loaded, status: OK.")
37
38
def main(argv=None):
    """Parse the command line and dispatch to the requested command.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, so
            existing ``main()`` callers are unaffected; passing a list lets
            the entry point be driven programmatically.

    Exits through argparse (``SystemExit``) on ``--help`` or invalid
    arguments.
    """
    parser = argparse.ArgumentParser(description="ORT Extension commands")
    parser.add_argument("command", choices=["run", "selfcheck"])
    parser.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
    parser.add_argument("--testdata-dir", help="Path to the test data directory")
    # NOTE: REMAINDER collects everything that follows the command, so
    # --model / --testdata-dir have to be given before the command name.
    parser.add_argument("args", nargs=argparse.REMAINDER, help="Additional arguments")

    args = parser.parse_args(argv)

    ort_commands = ORTExtCommands(model=args.model, testdata_dir=args.testdata_dir)

    if args.command == "run":
        ort_commands.run(*args.args)
    elif args.command == "selfcheck":
        ort_commands.selfcheck(*args.args)
54
55
# Script entry point: dispatch only when this module is executed directly,
# never as a side effect of importing it.
if __name__ == "__main__":
    main()
58