microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code Commits Issues Pull requests Actions Insights Security
9baed694a34dcb074db66e9f74c93a4e5bb6f9a2

Branches

Tags

  • No tags available.
0 Branches 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_ocos.py

167 lines · mode code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import sys
7import copy
8import onnx
9from onnx import helper
10from ._extensions_pydll import ( # noqa
11 PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
12
13
def get_library_path():
    """
    Locate the compiled custom-operator library on disk.

    :return: The ``__file__`` path of the already-imported
        ``onnxruntime_extensions._extensions_pydll`` extension module.
    :raises KeyError: if that module is not present in ``sys.modules``.
    """
    pydll_name = 'onnxruntime_extensions._extensions_pydll'
    pydll_module = sys.modules[pydll_name]
    return pydll_module.__file__
21
22
class Opdef:
    """
    A Python function registered as an ONNX custom operator.

    Instances are normally produced by :meth:`create` (or via the
    :meth:`declare` decorator factory). They stay callable and forward any
    call to the wrapped function.
    """

    # Registry mapping id(Opdef) -> Opdef for every registered operator.
    # The native definition only records the integer id (``obj_id``); keeping
    # a strong reference here keeps the object alive, so that id stays unique
    # and can later be resolved back to this Opdef.
    _odlist = {}

    def __init__(self, op_type, func):
        """
        :param op_type: The operator type name.
        :param func: The Python callable implementing the operator.
        """
        self.op_type = op_type
        self.body = func
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """
        Decorator factory that registers the decorated function via :meth:`create`.

        Must be used with parentheses, e.g. ``@Opdef.declare(op_type='MyOp')``.
        ``args`` and ``kwargs`` are forwarded to :meth:`create`.

        :raises RuntimeError: if the first positional argument is callable,
            i.e. the decorator was applied without parentheses.
        """
        if args and callable(args[0]):
            raise RuntimeError("Unexpected arguments {}.".format(args))
        return lambda f: Opdef.create(f, *args, **kwargs)

    @staticmethod
    def create(func, *args, **kwargs):
        """
        Wrap ``func`` in an Opdef and register it with the native library.

        Recognized keyword arguments (extra positional arguments are ignored):

        * ``op_type`` -- operator type name; ``func.__name__`` is used when it
          is missing or falsy.
        * ``inputs`` / ``outputs`` -- lists of ``PyCustomOpDef`` type codes;
          each defaults to ``[PyCustomOpDef.dt_float]`` when missing or None.
        * ``attrs`` -- attribute list stored on the native definition;
          defaults to ``[]`` when missing or None.

        If building or registering the native definition raises, the Opdef is
        removed from the registry again and the exception propagates.

        :param func: The Python callable implementing the operator.
        :return: The registered Opdef.
        """
        op_type = kwargs.get('op_type', None) or func.__name__
        opdef = Opdef(op_type, func)
        od_id = opdef._id

        # Register before handing the id to the native side (see _odlist).
        Opdef._odlist[od_id] = opdef
        try:
            nativedef = PyCustomOpDef()
            opdef._nativedef = nativedef
            nativedef.op_type = op_type
            nativedef.obj_id = od_id

            inputs = kwargs.get('inputs', None)
            nativedef.input_types = (
                [PyCustomOpDef.dt_float] if inputs is None else inputs)
            outputs = kwargs.get('outputs', None)
            nativedef.output_types = (
                [PyCustomOpDef.dt_float] if outputs is None else outputs)
            attrs = kwargs.get('attrs', None)
            nativedef.attrs = [] if attrs is None else attrs

            add_custom_op(nativedef)
        except Exception:
            # Registration failed: drop the strong reference so the
            # half-built Opdef (and the function it wraps) is not pinned
            # in the registry for the rest of the process.
            Opdef._odlist.pop(od_id, None)
            raise
        return opdef

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function directly."""
        return self.body(*args, **kwargs)
71
def _on_pyop_invocation(k_id, feed, attributes):
    """
    Run the registered Python operator with id ``k_id`` and pack its results.

    :param k_id: Id under which the Opdef is stored in ``Opdef._odlist``.
    :param feed: Sequence of positional inputs for the operator function.
    :param attributes: Mapping of keyword arguments for the operator function.
    :return: ``(k_id, shape_0, values_0, shape_1, values_1, ...)``. Each output
        contributes its ``.shape`` followed by ``.flatten().tolist()``; a
        non-tuple result is treated as a single output.
        NOTE(review): outputs are assumed to be array-like (e.g. numpy arrays).
    :raises RuntimeError: if no operator is registered under ``k_id``.
    """
    registry = Opdef._odlist
    if k_id not in registry:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id))
    result = registry[k_id].body(*feed, **attributes)

    # Normalize to a sequence of outputs so single and multiple results
    # share one packing path.
    outputs = result if isinstance(result, tuple) else (result,)
    packed = [k_id]
    for out in outputs:
        packed.append(out.shape)
        packed.append(out.flatten().tolist())
    return tuple(packed)
89
90
def _ensure_opset_domain(model):
    """
    Make sure ``model`` imports the default custom-op domain.

    When no entry of ``model.opset_import`` has the domain returned by
    ``default_opset_domain()``, an operator-set id for that domain with
    version 1 is appended.

    :param model: The ONNX ModelProto; modified in place.
    :return: The same ``model`` object.
    """
    domain = default_opset_domain()
    already_imported = any(
        entry.domain == domain for entry in model.opset_import)
    if not already_imported:
        model.opset_import.extend([helper.make_operatorsetid(domain, 1)])
    return model
102
103
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """
    Replace one input of a model with new inputs, plus some extra nodes.

    ``model`` itself is not modified; a deep copy is edited and returned.
    Every graph input named ``target_input`` is removed, ``new_inputs`` are
    appended after the remaining inputs, and ``extra_nodes`` are appended
    after the existing nodes. All other graph content (name, outputs,
    initializers, value_info, doc_string, ...) is kept unchanged.

    :param model: The ONNX model loaded as ModelProto
    :param target_input: Name of the graph input to be removed
    :param extra_nodes: Iterable of NodeProto to be added
    :param new_inputs: Iterable of new inputs (type: ValueInfoProto)
    :return: The modified copy of the model, with the default custom-op
        domain present in its opset imports
    """
    new_model = copy.deepcopy(model)
    graph = new_model.graph

    # Edit the copied graph in place instead of rebuilding it with
    # helper.make_graph(): a rebuilt graph only carries nodes, name, inputs,
    # outputs and initializers, silently dropping value_info, doc_string,
    # sparse_initializer, etc.
    # `kept` references the ORIGINAL model's inputs, so clearing the copy's
    # input list below does not invalidate them; extend() copies them back.
    kept = [vi for vi in model.graph.input if vi.name != target_input]
    del graph.input[:]
    graph.input.extend(kept)
    graph.input.extend(new_inputs)
    graph.node.extend(extra_nodes)

    return _ensure_opset_domain(new_model)
123
124
def hook_model_op(model, node_name, hook_func, input_types):
    """
    Insert a Python hook operator in front of a named node, e.g. for model diagnosis.

    A node of type ``op_<hook_func.__name__>_<node_name>`` named
    ``<node_name>_hkd`` (in the default custom-op domain) is placed right
    before the first node called ``node_name``. The hook node consumes that
    node's original inputs and produces one ``<input>_hkd`` output per input.
    The target node is rewired to read those outputs instead. ``hook_func``
    is then registered as the hook's op type, using ``input_types`` for both
    its inputs and outputs.

    Note: ``model`` is modified in place and the same object is returned.

    :param model: The ONNX model loaded as ModelProto
    :param node_name: Name of the node in front of which the hook is installed
    :param hook_func: The hook function, called back during model inference
    :param input_types: The input types as a list (also used as output types)
    :return: ``model`` with the hook installed
    :raises ValueError: if no node in the graph is named ``node_name``
    """
    # onnx.shape_inference was found too unstable to be useful here, so the
    # model is hooked as-is, without running shape inference first.
    hkd_model = model

    graph_nodes = list(hkd_model.graph.node)
    hook_node_name = node_name + '_hkd'
    hook_op_type = "op_{}_{}".format(hook_func.__name__, node_name)

    for position, target in enumerate(graph_nodes):
        if target.name == node_name:
            break
    else:
        raise ValueError("{} is not an operator node name".format(node_name))

    hooked_outputs = [name + '_hkd' for name in target.input]
    # Build the hook from the target's inputs before they are overwritten.
    hook_node = onnx.helper.make_node(
        hook_op_type, target.input, hooked_outputs,
        name=hook_node_name, domain=default_opset_domain())
    # Rewire the target node to consume the hook's outputs.
    del target.input[:]
    target.input.extend(hooked_outputs)

    reordered = list(graph_nodes)
    reordered.insert(position, hook_node)
    del hkd_model.graph.node[:]
    hkd_model.graph.node.extend(reordered)

    Opdef.create(hook_func, op_type=hook_op_type, inputs=input_types, outputs=input_types)
    return _ensure_opset_domain(hkd_model)
165
166
# Register _on_pyop_invocation with the native extension at import time.
# NOTE(review): presumably the native library calls this hook whenever a
# Python-implemented custom op runs. The hook looks the op up by id in
# Opdef._odlist and returns the packed outputs; confirm on the C++ side.
PyCustomOpDef.install_hooker(_on_pyop_invocation)
168