microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_customops/_ocos.py
104 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. See License.txt in the project root for |
| 3 | # license information. |
| 4 | ############################################################################### |
| 5 | |
| 6 | import sys |
| 7 | import copy |
| 8 | from onnx import helper |
| 9 | from ._ortcustomops import ( # noqa |
| 10 | PyCustomOpDef, enable_custom_op, add_custom_op, hash_64) |
| 11 | |
| 12 | |
def get_library_path():
    """Return the ``__file__`` of the ``_ortcustomops`` extension module.

    The module is fetched from ``sys.modules`` under its fully qualified
    name ``onnxruntime_customops._ortcustomops``; a ``KeyError`` is raised
    if no module is registered under that name.
    """
    return sys.modules['onnxruntime_customops._ortcustomops'].__file__
| 16 | |
| 17 | |
class Opdef:
    """Python-side record of a custom operator and its implementing function.

    Operators built through :meth:`declare` are kept in the class-level
    ``_odlist`` registry, keyed by ``id(instance)``. The same id is written
    to the native ``PyCustomOpDef`` as ``obj_id``.
    """

    # Registry of created operators: ``id(opdef)`` -> ``opdef``.
    _odlist = {}

    def __init__(self, op_type, func):
        """Store the operator type name and the callable implementing it."""
        self.body = func
        self.op_type = op_type
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """Return a decorator that registers the decorated function as an op.

        The positional and keyword arguments are forwarded to
        :meth:`_create` together with the decorated function.

        Raises:
            RuntimeError: if the first positional argument is callable,
                which happens when the decorator is applied without
                parentheses.
        """
        if args and hasattr(args[0], '__call__'):
            raise RuntimeError("Unexpected arguments {}.".format(args))

        def decorator(f):
            return Opdef._create(f, *args, **kwargs)

        return decorator

    @staticmethod
    def _create(func, *args, **kwargs):
        """Build an ``Opdef`` for *func* and register it with the native side.

        Recognised keyword arguments:

        * ``op_type``: operator name; ``func.__name__`` is used when it is
          missing or falsy.
        * ``inputs`` / ``outputs``: type lists; each falls back to
          ``[PyCustomOpDef.dt_float]`` when missing or ``None``.

        Extra positional arguments are accepted but not used.
        """
        op_type = kwargs.get('op_type') or func.__name__
        opdef = Opdef(op_type, func)
        registry_key = id(opdef)

        # Hold a Python reference so the object cannot be destroyed while
        # it is also stored in the C++ container.
        Opdef._odlist[registry_key] = opdef

        native = opdef._nativedef = PyCustomOpDef()
        native.op_type = op_type
        native.obj_id = registry_key

        # TODO: handle more types and multiple inputs/outputs.
        # By default the op has a single float input and a single float output.
        input_types = kwargs.get('inputs')
        native.input_types = (
            [PyCustomOpDef.dt_float] if input_types is None else input_types)

        output_types = kwargs.get('outputs')
        native.output_types = (
            [PyCustomOpDef.dt_float] if output_types is None else output_types)

        add_custom_op(native)
        return opdef

    def __call__(self, *args, **kwargs):
        """Call the wrapped function directly with the given arguments."""
        return self.body(*args, **kwargs)
| 63 | |
| 64 | |
def _on_pyop_invocation(k_id, feed):
    """Run the registered operator *k_id* on *feed* and pack its results.

    The operator is looked up in ``Opdef._odlist`` and its ``body`` is
    called with the items of *feed* as positional arguments. A tuple result
    is treated as multiple outputs; any other result is a single output.
    Each output must provide ``.shape`` and ``.flatten().tolist()``.

    Returns:
        A tuple ``(k_id, shape_0, values_0, shape_1, values_1, ...)`` with
        one shape / flattened-list pair per output.

    Raises:
        RuntimeError: if *k_id* is not present in ``Opdef._odlist``.
    """
    if k_id not in Opdef._odlist:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id))

    result = Opdef._odlist[k_id].body(*feed)

    # Treat a single output as a one-element tuple so both cases share
    # the same packing loop.
    outputs = result if isinstance(result, tuple) else (result,)

    packed = [k_id]
    for out in outputs:
        packed.extend((out.shape, out.flatten().tolist()))
    return tuple(packed)
| 82 | |
| 83 | |
def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """Return a copy of *model* with the input *target_input* replaced.

    The rebuilt graph contains:

    * the original inputs except the one named *target_input*, followed
      by *new_inputs*;
    * the original nodes followed by *extra_nodes*;
    * the original graph name, outputs and initializers.

    The graph is installed into a deep copy of *model*, and an opset import
    for the ``ai.onnx.contrib`` domain (version 1) is appended to the copy
    when *model* does not already declare that domain.
    """
    source = model.graph
    merged_inputs = [
        inp for inp in source.input if inp.name != target_input] + new_inputs
    merged_nodes = list(source.node) + extra_nodes
    rebuilt_graph = helper.make_graph(
        merged_nodes,
        source.name,
        merged_inputs,
        list(source.output),
        list(source.initializer))

    expanded = copy.deepcopy(model)
    expanded.graph.CopyFrom(rebuilt_graph)

    has_contrib_domain = any(
        opset.domain == 'ai.onnx.contrib' for opset in model.opset_import)
    if not has_contrib_domain:
        expanded.opset_import.extend(
            [helper.make_operatorsetid('ai.onnx.contrib', 1)])

    return expanded
| 102 | |
| 103 | |
# Hand the Python-side dispatcher to the native module at import time.
# NOTE(review): presumably the native code calls it whenever a
# Python-defined custom op is executed -- confirm against _ortcustomops.
PyCustomOpDef.install_hooker(_on_pyop_invocation)
| 105 | |