microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
67c77d9fbc31110f20c226267331b4297c2d379c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/cmd.py

40 lines · mode: code

1import os
2import fire
3import onnx
4import numpy
5
6from onnx import onnx_pb, save_tensor, numpy_helper
7from ._ortapi2 import OrtPyFunction
8
9
class ORTExtCommands:
    """Commands exposed on the command line through ``fire``.

    :param model: path of the ONNX model that ``run`` loads and executes.
    :param testdata_dir: optional directory; when set, ``run`` also writes the
        inputs it used and a copy of the model there.
    """

    def __init__(self, model='model.onnx', testdata_dir=None) -> None:
        self._model = model
        self._testdata_dir = testdata_dir

    def run(self, *args):
        """
        Run an onnx model with the arguments as its inputs.

        Each positional argument is converted with ``numpy.asarray`` and paired
        with the model input at the same position. Inputs declared as FLOAT are
        always cast to float32. Inputs declared as DOUBLE or INT64 are cast to
        float64 / int64 when that cast is loss-free. The model output is
        printed. If ``testdata_dir`` was given, each input is saved as
        ``<testdata_dir>/test_data_set_0/input_<i>.pb`` and the model as
        ``<testdata_dir>/model.onnx``.
        """
        op_func = OrtPyFunction.from_model(self._model)
        np_args = [numpy.asarray(_x) for _x in args]

        # numpy.asarray infers dtypes from the Python values: an int may become
        # int32 on some platforms (e.g. Windows with NumPy < 2), and an int
        # passed for a DOUBLE input stays an integer array. onnxruntime needs
        # the exact element type declared by the model, so widen such arrays.
        # Only 'safe' casts are applied, so values are never truncated.
        widen_to = {
            onnx_pb.TensorProto.DOUBLE: numpy.float64,
            onnx_pb.TensorProto.INT64: numpy.int64,
        }
        # zip() stops at the shorter sequence, so a missing argument is left
        # for op_func to report rather than raising an IndexError here.
        for _idx, (_arr, _sch) in enumerate(zip(np_args, op_func.inputs)):
            elem_type = _sch.type.tensor_type.elem_type
            if elem_type == onnx_pb.TensorProto.FLOAT:
                np_args[_idx] = _arr.astype(numpy.float32)
            elif elem_type in widen_to:
                target = widen_to[elem_type]
                if numpy.can_cast(_arr.dtype, target, casting='safe'):
                    np_args[_idx] = _arr.astype(target)

        print(op_func(*np_args))
        if self._testdata_dir:
            testdir = os.path.expanduser(self._testdata_dir)
            target_dir = os.path.join(testdir, 'test_data_set_0')
            os.makedirs(target_dir, exist_ok=True)
            for _idx, _x in enumerate(np_args):
                fn = os.path.join(target_dir, "input_{}.pb".format(_idx))
                # Each tensor is named after the model input at the same position.
                save_tensor(numpy_helper.from_array(_x, op_func.inputs[_idx].name), fn)
            onnx.save_model(op_func.onnx_model, os.path.join(testdir, 'model.onnx'))

    def selfcheck(self, *args):
        """Print a status line; any arguments are accepted and ignored."""
        print("The extensions loaded, status: OK.")
38
if __name__ == '__main__':
    # Expose the public methods of ORTExtCommands as CLI sub-commands;
    # its constructor arguments become command-line flags.
    fire.Fire(component=ORTExtCommands)
41