microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
9ba649e134b9c44591359358cb2a64fb0431d492

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_ocos.py

186 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5"""
6_ocos.py: PythonOp implementation
7"""
8
9import sys
10import copy
11import onnx
12from onnx import helper
13from ._extensions_pydll import ( # noqa
14 PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain)
15
16
def get_library_path():
    """
    Locate the binary that provides the custom operators.
    :return: The ``__file__`` path of the already-imported
             ``onnxruntime_extensions._extensions_pydll`` module.
    """
    pydll_name = 'onnxruntime_extensions._extensions_pydll'
    return sys.modules[pydll_name].__file__
24
25
class Opdef:
    """
    Python-side definition of a custom operator.

    Instances produced by :meth:`create` are also registered with the native
    library via ``add_custom_op``; ``_odlist`` keeps each of them alive, keyed
    by ``id()``, because the native definition only stores that id.
    """

    # Class-wide registry: id(opdef) -> opdef, shared by every definition.
    _odlist = {}

    def __init__(self, op_type, func):
        self.op_type = op_type
        self.body = func
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """
        Build a decorator that registers the decorated function via :meth:`create`.
        Must be used with call syntax (``@Opdef.declare(...)``); a callable
        passed positionally raises RuntimeError.
        """
        if args and hasattr(args[0], '__call__'):
            raise RuntimeError("Unexpected arguments {}.".format(args))

        def _register(f):
            return Opdef.create(f, *args, **kwargs)
        return _register

    @staticmethod
    def create(func, *args, **kwargs):
        """
        Wrap ``func`` in an Opdef and register it with the native library.
        :param func: The Python implementation of the operator
        :param kwargs: Optional ``op_type`` (defaults to ``func.__name__``),
                       ``inputs`` / ``outputs`` (default ``[dt_float]``) and
                       ``attrs`` (a dict name -> type, or a list/tuple of names
                       which are then all typed as ``dt_string``)
        :return: The new Opdef
        """
        op_type = kwargs.get('op_type', None) or func.__name__
        definition = Opdef(op_type, func)
        key = id(definition)

        # Strong reference: the native container only knows the id, so the
        # Python object must never be garbage-collected.
        Opdef._odlist[key] = definition
        definition._nativedef = PyCustomOpDef()
        native = definition._nativedef
        native.op_type = op_type
        native.obj_id = key

        input_types = kwargs.get('inputs', None)
        native.input_types = [PyCustomOpDef.dt_float] if input_types is None else input_types
        output_types = kwargs.get('outputs', None)
        native.output_types = [PyCustomOpDef.dt_float] if output_types is None else output_types

        attr_spec = kwargs.get('attrs', None)
        if attr_spec is None:
            attr_spec = {}
        elif isinstance(attr_spec, (list, tuple)):
            attr_spec = dict.fromkeys(attr_spec, PyCustomOpDef.dt_string)
        native.attrs = attr_spec

        add_custom_op(native)
        return definition

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped Python function directly."""
        return self.body(*args, **kwargs)

    def cast_attributes(self, attributes):
        """
        Convert attribute values to the Python types declared in ``attrs``.
        :param attributes: Mapping of attribute name -> raw value
        :return: A new dict with int64 -> int, float -> float, string unchanged
        :raises RuntimeError: If an attribute has any other declared type
        """
        casted = {}
        for name, value in attributes.items():
            kind = self._nativedef.attrs[name]
            if kind == PyCustomOpDef.dt_int64:
                casted[name] = int(value)
            elif kind == PyCustomOpDef.dt_float:
                casted[name] = float(value)
            elif kind == PyCustomOpDef.dt_string:
                casted[name] = value
            else:
                raise RuntimeError("Unsupported attribute type {}.".format(kind))
        return casted
89
90
def _on_pyop_invocation(k_id, feed, attributes):
    """
    Run the Python function registered under ``k_id`` and pack its results.
    :param k_id: The id under which the Opdef was stored in ``Opdef._odlist``
    :param feed: The operator inputs, passed positionally to the function
    :param attributes: Raw node attributes, converted with ``cast_attributes``
    :return: ``(k_id, shape_0, values_0, shape_1, values_1, ...)`` where each
             ``values_i`` is the flattened output as a Python list. Each output
             is expected to expose ``shape``/``flatten()``/``tolist()``
             (presumably a numpy array).
    :raises RuntimeError: If no Opdef is registered under ``k_id``
    """
    op_ = Opdef._odlist.get(k_id)
    if op_ is None:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?".format(k_id))
    rv = op_.body(*feed, **op_.cast_attributes(attributes))
    # A tuple or list of results means multiple outputs; anything else is one output.
    outputs = rv if isinstance(rv, (tuple, list)) else (rv,)
    res = []
    for r in outputs:
        res.append(r.shape)
        res.append(r.flatten().tolist())
    return (k_id, ) + tuple(res)
108
109
def _ensure_opset_domain(model):
    """
    Add the custom-op domain (version 1) to the model's opset imports if absent.
    The passed model is modified in place and also returned.
    :param model: The ONNX ModelProto to update
    :return: The same model object
    """
    custom_domain = default_opset_domain()
    has_domain = any(entry.domain == custom_domain for entry in model.opset_import)
    if not has_domain:
        model.opset_import.extend([helper.make_operatorsetid(custom_domain, 1)])
    return model
121
122
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """
    Replace one input of a model with new inputs, plus some extra nodes.

    The model is deep-copied and the copy's graph is edited in place rather
    than rebuilt with ``helper.make_graph``. This keeps every other graph field
    (value_info, doc_string, sparse initializers, ...) intact.
    :param model: The ONNX model loaded as ModelProto; it is not modified
    :param target_input: The name of the graph input to be removed
    :param extra_nodes: The extra nodes (iterable of NodeProto), appended after
                        the existing nodes
    :param new_inputs: The new inputs (iterable of ValueInfoProto), appended
                       after the inputs that are kept
    :return: A new ONNX model with the modification, with the custom-op opset
             domain ensured
    """
    new_model = copy.deepcopy(model)
    new_graph = new_model.graph

    # Take the kept inputs from the untouched source model, so nothing from
    # new_graph.input is referenced after that field has been cleared.
    kept_inputs = [i for i in model.graph.input if i.name != target_input]
    del new_graph.input[:]
    new_graph.input.extend(kept_inputs)
    new_graph.input.extend(list(new_inputs))
    new_graph.node.extend(list(extra_nodes))

    return _ensure_opset_domain(new_model)
142
143
def hook_model_op(model, node_name, hook_func, input_types):
    """
    Add a hook function node in the ONNX Model, which could be used for the model diagnosis.

    The hook node ``<node_name>_hkd`` is inserted just before the named node.
    It consumes that node's original inputs and produces ``<input>_hkd``
    copies, which the named node is rewired to read. ``hook_func`` is then
    registered as a custom op whose inputs and outputs both use ``input_types``.
    The passed model is modified in place.
    :param model: The ONNX model loaded as ModelProto
    :param node_name: The node name where the hook will be installed
    :param hook_func: The hook function, callback on the model inference
    :param input_types: The input types as a list
    :return: The ONNX model with the hook installed
    :raises ValueError: If no node named ``node_name`` exists
    """
    # Shape inference is deliberately not run here (the original authors found
    # onnx.shape_inference unreliable for this).
    target_model = model

    graph_nodes = list(target_model.graph.node)
    hook_node_name = node_name + '_hkd'
    hook_op_type = "op_{}_{}".format(hook_func.__name__, node_name)

    position = None
    for idx, candidate in enumerate(graph_nodes):
        if candidate.name == node_name:
            position = idx
            break
    if position is None:
        raise ValueError("{} is not an operator node name".format(node_name))

    hooked = graph_nodes[position]
    hooked_inputs = [name + '_hkd' for name in hooked.input]
    hook_node = onnx.helper.make_node(
        hook_op_type, hooked.input, hooked_inputs,
        name=hook_node_name, domain=default_opset_domain())
    # Rewire the hooked node to read the hook's outputs instead.
    del hooked.input[:]
    hooked.input.extend(hooked_inputs)

    graph_nodes[position:position + 1] = [hook_node, hooked]
    del target_model.graph.node[:]
    target_model.graph.node.extend(graph_nodes)

    Opdef.create(hook_func, op_type=hook_op_type, inputs=input_types, outputs=input_types)
    return _ensure_opset_domain(target_model)
184
185
# Import-time side effect: hand _on_pyop_invocation to the native extension,
# presumably as the callback its Python-op kernels use to run the registered
# functions. It must stay after the definition of _on_pyop_invocation.
PyCustomOpDef.install_hooker(
    _on_pyop_invocation
)
187