microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a98c29f6d258c378600c1632385f620737b297a8

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_customops/_ocos.py

104 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import sys
7import copy
8from onnx import helper
9from ._ortcustomops import ( # noqa
10 PyCustomOpDef, enable_custom_op, add_custom_op, hash_64)
11
12
def get_library_path():
    """Return the file path of the ``_ortcustomops`` extension module.

    The path is taken from the ``__file__`` attribute of the
    ``onnxruntime_customops._ortcustomops`` entry in ``sys.modules``
    (presumably the compiled native custom-op library). A ``KeyError`` is
    raised if that module has not been imported.
    """
    native_name = 'onnxruntime_customops._ortcustomops'
    native_module = sys.modules[native_name]
    return native_module.__file__
16
17
class Opdef:
    """A Python callable registered as an ONNX custom operator.

    Instances are normally created through :meth:`declare`, which registers
    the function with the native module via ``add_custom_op``. Calling an
    instance simply forwards to the wrapped function.
    """

    # Registry of every Opdef built by _create, keyed by id(opdef).
    # Its id is handed to the native side (PyCustomOpDef.obj_id), and
    # _on_pyop_invocation resolves calls back to the Python object through
    # this dict, so the reference here also keeps the object alive.
    _odlist = {}

    def __init__(self, op_type, func):
        self.op_type = op_type
        self.body = func
        self._id = id(self)

    def __repr__(self):
        return "{}(op_type={!r})".format(type(self).__name__, self.op_type)

    @staticmethod
    def declare(*args, **kwargs):
        """Return a decorator that registers the decorated function as an op.

        Must be used with parentheses, e.g. ``@Opdef.declare(op_type='Foo')``.
        All arguments are forwarded to :meth:`_create`, which reads only
        keyword arguments (``op_type``, ``inputs``, ``outputs``); extra
        positional arguments are accepted but ignored.

        Raises:
            RuntimeError: if the first positional argument is callable, which
                means the decorator was applied without parentheses.
        """
        if args and callable(args[0]):
            raise RuntimeError("Unexpected arguments {}.".format(args))
        return lambda f: Opdef._create(f, *args, **kwargs)

    @staticmethod
    def _create(func, *args, **kwargs):
        """Build an Opdef for *func* and register it with the native module.

        Args:
            func: the Python callable implementing the operator.
            *args: accepted for compatibility with :meth:`declare`; unused.
            **kwargs:
                ``op_type``: operator name; falls back to ``func.__name__``
                when missing or falsy.
                ``inputs`` / ``outputs``: lists of ``PyCustomOpDef`` type
                codes; each defaults to ``[PyCustomOpDef.dt_float]``.

        Returns:
            The registered :class:`Opdef`.
        """
        op_type = kwargs.get('op_type') or func.__name__
        opdef = Opdef(op_type, func)
        od_id = opdef._id

        # Keep a strong reference: the id is stored on the native side and
        # must remain resolvable for as long as the op can be invoked.
        Opdef._odlist[od_id] = opdef
        opdef._nativedef = PyCustomOpDef()
        opdef._nativedef.op_type = op_type
        opdef._nativedef.obj_id = od_id

        # TODO: handle more types and multiple inputs/outputs.
        # By default the op has a single float input and a single float output.
        inputs = kwargs.get('inputs')
        if inputs is None:
            inputs = [PyCustomOpDef.dt_float]
        opdef._nativedef.input_types = inputs
        outputs = kwargs.get('outputs')
        if outputs is None:
            outputs = [PyCustomOpDef.dt_float]
        opdef._nativedef.output_types = outputs
        add_custom_op(opdef._nativedef)
        return opdef

    def __call__(self, *args, **kwargs):
        return self.body(*args, **kwargs)
63
64
def _on_pyop_invocation(k_id, feed):
    """Run the Python body of the custom op with id *k_id* on *feed*.

    The body is looked up in ``Opdef._odlist`` and called with the elements
    of *feed* as positional arguments. A tuple return value is treated as
    several outputs; anything else is treated as a single output.

    Returns:
        ``(k_id, shape_0, values_0, shape_1, values_1, ...)``, where each
        ``shape_i`` is the output's ``.shape`` and each ``values_i`` is
        ``output.flatten().tolist()``.

    Raises:
        RuntimeError: if no Opdef is registered under *k_id*.
    """
    registry = Opdef._odlist
    if k_id not in registry:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id))
    result = registry[k_id].body(*feed)

    # Normalise to a sequence of outputs so both cases share one code path.
    outputs = result if isinstance(result, tuple) else (result,)
    packed = tuple(
        piece
        for out in outputs
        for piece in (out.shape, out.flatten().tolist()))
    return (k_id,) + packed
82
83
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """Return a copy of *model* where one graph input is replaced.

    In the copy:
    - the graph input named *target_input* is removed;
    - *new_inputs* are appended after the remaining graph inputs;
    - *extra_nodes* are appended after the existing nodes (presumably they
      produce *target_input* from *new_inputs* -- the caller must ensure
      this);
    - the ``ai.onnx.contrib`` opset (version 1) is imported if the model
      does not import that domain already.

    The input *model* is not modified.

    Args:
        model: the ONNX ``ModelProto`` to expand.
        target_input: name of the graph input to remove.
        extra_nodes: iterable of ``NodeProto`` to append to the graph.
        new_inputs: iterable of ``ValueInfoProto`` to append to the inputs.

    Returns:
        A new ``ModelProto``.
    """
    new_model = copy.deepcopy(model)
    # Edit the copied graph in place rather than rebuilding it with
    # helper.make_graph. A rebuilt graph silently loses value_info,
    # sparse_initializer, doc_string and other GraphProto fields, and
    # rebuilding also copies every initializer a second time.
    graph = new_model.graph
    # Delete back to front so the indices still to be visited stay valid.
    for idx in reversed(range(len(graph.input))):
        if graph.input[idx].name == target_input:
            del graph.input[idx]
    graph.input.extend(new_inputs)
    graph.node.extend(extra_nodes)

    if not any(oi_.domain == 'ai.onnx.contrib' for oi_ in new_model.opset_import):
        new_model.opset_import.extend([helper.make_operatorsetid('ai.onnx.contrib', 1)])

    return new_model
102
103
# Import-time side effect: hand _on_pyop_invocation to the native module.
# Presumably the native side calls it with (obj_id, inputs) whenever a
# Python-implemented custom op executes -- TODO confirm in the C++ binding.
PyCustomOpDef.install_hooker(_on_pyop_invocation)
105