microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/_ocos.py
186 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. See License.txt in the project root for |
| 3 | # license information. |
| 4 | ############################################################################### |
| 5 | """ |
| 6 | _ocos.py: PythonOp implementation |
| 7 | """ |
| 8 | |
| 9 | import sys |
| 10 | import copy |
| 11 | import onnx |
| 12 | from onnx import helper |
| 13 | from ._extensions_pydll import ( # noqa |
| 14 | PyCustomOpDef, enable_py_op, add_custom_op, hash_64, default_opset_domain) |
| 15 | |
| 16 | |
def get_library_path():
    """
    Return the file path of the native custom-operator library.

    The path is the ``__file__`` attribute of the already-imported
    ``onnxruntime_extensions._extensions_pydll`` extension module.

    :return: A string of this library path.
    :raises KeyError: if that module is not present in ``sys.modules``.
    """
    return sys.modules['onnxruntime_extensions._extensions_pydll'].__file__
| 24 | |
| 25 | |
class Opdef:
    """
    A Python-implemented custom operator.

    Each instance wraps the callable that implements the operator. ``create``
    registers the instance with the native extension.
    """

    # Registry of every created Opdef, keyed by id(opdef). The native op
    # definition only stores that integer id (``obj_id``), so this strong
    # reference is what keeps the Python object alive.
    _odlist = {}

    def __init__(self, op_type, func):
        self.op_type = op_type
        self.body = func
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """
        Build a decorator that registers the decorated function as an operator.

        It must be used with parentheses, e.g. ``@Opdef.declare(op_type=...)``.
        The arguments are forwarded to ``Opdef.create``.

        :raises RuntimeError: if the first positional argument is callable,
            i.e. the decorator was applied without being called first.
        """
        if args and callable(args[0]):
            raise RuntimeError("Unexpected arguments {}.".format(args))
        return lambda f: Opdef.create(f, *args, **kwargs)

    @staticmethod
    def create(func, *args, **kwargs):
        """
        Create an Opdef for ``func`` and register it with the native extension.

        :param func: the Python callable implementing the operator.
        :param args: accepted for signature compatibility; unused.
        :param kwargs: supported keys:
            ``op_type`` (defaults to ``func.__name__``);
            ``inputs`` / ``outputs``, each a list of ``PyCustomOpDef.dt_*``
            type codes (default ``[PyCustomOpDef.dt_float]``);
            ``attrs``, either a dict of name -> type code, or a list/tuple of
            names that are all declared as ``PyCustomOpDef.dt_string``.
        :return: the registered Opdef instance.
        """
        op_type = kwargs.get('op_type', None) or func.__name__
        opdef = Opdef(op_type, func)
        od_id = opdef._id

        # Tells python this object cannot be destroyed
        # because it is also stored in C++ container.
        Opdef._odlist[od_id] = opdef

        native = PyCustomOpDef()
        opdef._nativedef = native
        native.op_type = op_type
        native.obj_id = od_id

        inputs = kwargs.get('inputs', None)
        if inputs is None:
            inputs = [PyCustomOpDef.dt_float]
        native.input_types = inputs

        outputs = kwargs.get('outputs', None)
        if outputs is None:
            outputs = [PyCustomOpDef.dt_float]
        native.output_types = outputs

        attrs = kwargs.get('attrs', None)
        if attrs is None:
            attrs = {}
        elif isinstance(attrs, (list, tuple)):
            # A bare list of names means "all string attributes".
            attrs = {k: PyCustomOpDef.dt_string for k in attrs}
        native.attrs = attrs

        add_custom_op(native)
        return opdef

    def __call__(self, *args, **kwargs):
        return self.body(*args, **kwargs)

    def cast_attributes(self, attributes):
        """
        Convert raw attribute values to the Python types declared for this op.

        :param attributes: mapping of attribute name -> raw value (presumably
            as delivered by the native runtime; verify against the caller).
        :return: a new dict in which:
            int64 attributes are passed through ``int``;
            float attributes are passed through ``float``;
            string attributes are left unchanged.
        :raises KeyError: if an attribute name was not declared in ``attrs``.
        :raises RuntimeError: if an attribute's declared type is unsupported.
        """
        # Read the declared types once rather than on every comparison. A
        # native binding may build a fresh mapping on each attribute access.
        declared = self._nativedef.attrs
        res = {}
        for k, v in attributes.items():
            attr_type = declared[k]
            if attr_type == PyCustomOpDef.dt_int64:
                res[k] = int(v)
            elif attr_type == PyCustomOpDef.dt_float:
                res[k] = float(v)
            elif attr_type == PyCustomOpDef.dt_string:
                res[k] = v
            else:
                raise RuntimeError("Unsupported attribute type {}.".format(
                    attr_type))
        return res
| 89 | |
| 90 | |
def _on_pyop_invocation(k_id, feed, attributes):
    """
    Run the registered Python operator with id ``k_id`` and pack its results.

    This is the callback handed to ``PyCustomOpDef.install_hooker`` at
    module import.

    :param k_id: key of the operator in ``Opdef._odlist``.
    :param feed: positional input values for the operator body.
    :param attributes: raw attribute values, converted via
        ``Opdef.cast_attributes`` and passed as keyword arguments.
    :return: a flat tuple ``(k_id, shape_0, values_0, shape_1, values_1, ...)``.
        There is one ``(shape, flattened list)`` pair per output. A tuple
        returned by the body is treated as multiple outputs.
    :raises RuntimeError: if no operator is registered under ``k_id``.
    """
    registry = Opdef._odlist
    if k_id not in registry:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id))
    opdef = registry[k_id]
    result = opdef.body(*feed, **opdef.cast_attributes(attributes))

    # Each output must expose ``.shape`` and ``.flatten().tolist()``
    # (presumably numpy arrays -- confirm against operator authors).
    outputs = result if isinstance(result, tuple) else (result,)
    packed = [k_id]
    for out in outputs:
        packed.append(out.shape)
        packed.append(out.flatten().tolist())
    return tuple(packed)
| 108 | |
| 109 | |
def _ensure_opset_domain(model):
    """
    Ensure ``model`` imports the default custom-operator domain.

    If no entry of ``model.opset_import`` has the domain returned by
    ``default_opset_domain()``, an entry for that domain with version 1 is
    appended. The model is modified in place and returned.
    """
    custom_domain = default_opset_domain()
    has_domain = any(entry.domain == custom_domain
                     for entry in model.opset_import)
    if not has_domain:
        model.opset_import.extend(
            [helper.make_operatorsetid(custom_domain, 1)])
    return model
| 121 | |
| 122 | |
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """
    Replace an input of a model with new inputs, plus some extra nodes.

    The model is deep-copied first, so ``model`` itself is not modified.
    Every other graph field is carried over unchanged, including outputs,
    initializers, value_info and doc_string.

    :param model: The ONNX model loaded as ModelProto
    :param target_input: The name of the graph input to be removed
    :param extra_nodes: The extra nodes (NodeProto sequence). They are
        appended after the existing nodes.
    :param new_inputs: The new inputs (ValueInfoProto sequence). They are
        appended after the remaining original inputs.
    :return: The modified copy of the ONNX model. The default custom-op
        domain is ensured in its opset_import.
    """
    new_model = copy.deepcopy(model)
    graph = new_model.graph

    # Edit the copied graph in place. Rebuilding it with helper.make_graph
    # from only nodes/name/inputs/outputs/initializers would silently drop
    # value_info, doc_string, sparse_initializer and other GraphProto fields.
    # Deleting back-to-front keeps the remaining indices valid.
    for idx in reversed(range(len(graph.input))):
        if graph.input[idx].name == target_input:
            del graph.input[idx]
    graph.input.extend(new_inputs)

    # NOTE(review): extra_nodes are appended after the existing nodes (same
    # order as before). If they produce target_input, the node list is not
    # topologically sorted, which the ONNX IR spec requires; confirm that
    # downstream consumers tolerate this.
    graph.node.extend(extra_nodes)

    return _ensure_opset_domain(new_model)
| 142 | |
| 143 | |
def hook_model_op(model, node_name, hook_func, input_types):
    """
    Add a hook function node in the ONNX Model, which could be used for the model diagnosis.

    A custom-op node named ``<node_name>_hkd`` is inserted immediately
    before the first node named ``node_name``. It consumes that node's
    original inputs and produces outputs named ``<input>_hkd``. The hooked
    node is rewired to consume those outputs instead.

    ``hook_func`` is registered as operator ``op_<hook_func.__name__>_<node_name>``.
    ``input_types`` is used as both its input and its output types.
    The model is modified in place and also returned.

    :param model: The ONNX model loaded as ModelProto
    :param node_name: The node name where the hook will be installed
    :param hook_func: The hook function, callback on the model inference
    :param input_types: The input types as a list
    :return: The ONNX model with the hook installed
    :raises ValueError: if no node in the graph is named ``node_name``
    """
    # Shape inference is deliberately not run first; the original author
    # found onnx.shape_inference too unreliable for this use.
    hkd_model = model

    graph_nodes = list(hkd_model.graph.node)
    hook_node_name = node_name + '_hkd'
    optype_name = "op_{}_{}".format(hook_func.__name__, node_name)

    found_at = next(
        (pos for pos, node in enumerate(graph_nodes) if node.name == node_name),
        None)
    if found_at is None:
        raise ValueError("{} is not an operator node name".format(node_name))

    target = graph_nodes[found_at]
    hooked_inputs = [name + '_hkd' for name in target.input]
    hook_node = onnx.helper.make_node(
        optype_name, target.input, hooked_inputs,
        name=hook_node_name, domain=default_opset_domain())
    # Rewire the target (a live element of the graph) to read the hook's outputs.
    del target.input[:]
    target.input.extend(hooked_inputs)

    # Put the hook node directly in front of the node it feeds.
    graph_nodes[found_at:found_at + 1] = [hook_node, target]
    del hkd_model.graph.node[:]
    hkd_model.graph.node.extend(graph_nodes)

    Opdef.create(hook_func, op_type=optype_name,
                 inputs=input_types, outputs=input_types)
    return _ensure_opset_domain(hkd_model)


# Route invocations of Python custom operators from the native extension to
# the dispatcher above, which looks the operator up by id in Opdef._odlist.
PyCustomOpDef.install_hooker(_on_pyop_invocation)
| 187 | |