microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a98c29f6d258c378600c1632385f620737b297a8

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_customops/_ocos.py

104 lines · mode: preview

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

import sys
import copy
from onnx import helper
from ._ortcustomops import (  # noqa
    PyCustomOpDef, enable_custom_op, add_custom_op, hash_64)


def get_library_path():
    """Return the file path of the loaded ``_ortcustomops`` native extension.

    The extension is looked up in ``sys.modules`` under its fully qualified
    name, so it must already have been imported (this module imports it at
    load time via ``from ._ortcustomops import ...``).

    Raises:
        KeyError: if ``onnxruntime_customops._ortcustomops`` is not loaded.
    """
    ext_name = 'onnxruntime_customops._ortcustomops'
    return sys.modules[ext_name].__file__


class Opdef:
    """Python-side record of a custom ONNX operator implemented in Python.

    Instances are normally created through :meth:`declare`, used as a
    decorator factory. Each instance is registered with the native extension
    via ``add_custom_op`` and kept in ``_odlist`` under its ``id()``, which is
    the key the native definition carries in ``obj_id``.
    """

    # Maps id(opdef) -> Opdef. Holding a strong reference here keeps every
    # registered Opdef alive (and therefore its id unique), because the C++
    # container only stores the integer id.
    _odlist = {}

    def __init__(self, op_type, func):
        """Bind ``func`` as the Python implementation of ``op_type``."""
        self.op_type = op_type
        self.body = func
        # Identity key; also used as the native PyCustomOpDef.obj_id.
        self._id = id(self)

    @staticmethod
    def declare(*args, **kwargs):
        """Return a decorator that registers the decorated function as an op.

        Must be called with parentheses, e.g. ``@declare(op_type='Foo')``.
        Keyword arguments understood by the registration:

        * ``op_type``: operator name; falls back to the function's
          ``__name__`` when missing or empty.
        * ``inputs`` / ``outputs``: lists of ``PyCustomOpDef.dt_*`` type
          codes; each defaults to a single ``PyCustomOpDef.dt_float``.

        Positional arguments are forwarded but not used.

        Raises:
            RuntimeError: if used bare (``@declare`` without parentheses),
                which shows up as a callable first positional argument.
        """
        if args and callable(args[0]):
            raise RuntimeError("Unexpected arguments {}.".format(args))
        return lambda f: Opdef._create(f, *args, **kwargs)

    @staticmethod
    def _create(func, *args, **kwargs):
        """Build an Opdef for ``func``, register it natively and return it.

        ``args`` is unused; it only exists so :meth:`declare` can forward its
        positional arguments unchanged. See :meth:`declare` for ``kwargs``.
        """
        op_type = kwargs.get('op_type') or func.__name__
        opdef = Opdef(op_type, func)

        # Tells python this object cannot be destroyed
        # because it is also stored in C++ container.
        Opdef._odlist[opdef._id] = opdef

        opdef._nativedef = nativedef = PyCustomOpDef()
        nativedef.op_type = op_type
        nativedef.obj_id = opdef._id

        # TODO: handle more types and multiple inputs/outputs.
        # By default the op has a single float input and a single float output.
        inputs = kwargs.get('inputs')
        nativedef.input_types = (
            [PyCustomOpDef.dt_float] if inputs is None else inputs)
        outputs = kwargs.get('outputs')
        nativedef.output_types = (
            [PyCustomOpDef.dt_float] if outputs is None else outputs)

        add_custom_op(nativedef)
        return opdef

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped Python implementation directly."""
        return self.body(*args, **kwargs)


def _on_pyop_invocation(k_id, feed):
    """Run the Python body of a registered op and pack its outputs.

    This function is the hook passed to ``PyCustomOpDef.install_hooker``.

    Args:
        k_id: the ``id()`` of a registered :class:`Opdef`, i.e. its key in
            ``Opdef._odlist``.
        feed: sequence of input values, unpacked positionally into the body.

    Returns:
        ``(k_id, shape_0, data_0, shape_1, data_1, ...)``: for each output,
        its ``.shape`` followed by its values flattened into a Python list.
        A tuple or list return value is treated as multiple outputs; any
        other value as a single output. Each output must provide ``.shape``
        and ``.flatten()`` (presumably numpy arrays; not checked here).

    Raises:
        RuntimeError: if ``k_id`` does not name a registered op.
    """
    try:
        op_ = Opdef._odlist[k_id]
    except KeyError:
        raise RuntimeError(
            "Unable to find function id={}. "
            "Did you decorate the operator with @onnx_op?.".format(k_id)) from None

    rv = op_.body(*feed)
    # Normalise to a sequence of outputs so single and multi-output ops share
    # one packing path. A list previously crashed on ``list.shape``.
    outputs = rv if isinstance(rv, (tuple, list)) else (rv,)
    res = []
    for out in outputs:
        res.append(out.shape)
        res.append(out.flatten().tolist())
    return (k_id,) + tuple(res)


def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """Return a copy of ``model`` with one graph input swapped for new ones.

    The graph input named ``target_input`` is removed. ``new_inputs`` are
    appended after the remaining inputs, and ``extra_nodes`` after the
    existing nodes. The extra nodes presumably compute ``target_input`` from
    the new inputs; the caller is responsible for that wiring. If the model
    does not already import the ``ai.onnx.contrib`` domain, version 1 of it is
    added. ``model`` itself is not modified.

    The graph is edited in place on a deep copy rather than rebuilt with
    ``helper.make_graph``. That way every other graph field (initializers,
    value_info, doc_string, sparse initializers, ...) is preserved instead
    of silently dropped.

    Args:
        model: source model (an ``onnx.ModelProto``, judging by the
            ``graph`` / ``opset_import`` fields used here).
        target_input: name of the graph input to drop.
        extra_nodes: iterable of node protos to append to the graph.
        new_inputs: iterable of value-info protos to append as graph inputs.

    Returns:
        The modified deep copy of ``model``.
    """
    new_model = copy.deepcopy(model)
    new_graph = new_model.graph

    # Filter from the original model (not the copy being cleared) so the kept
    # messages remain valid; repeated-field ``extend`` copies each message.
    kept_inputs = [vi for vi in model.graph.input if vi.name != target_input]
    new_graph.ClearField('input')
    new_graph.input.extend(kept_inputs)
    new_graph.input.extend(new_inputs)
    new_graph.node.extend(extra_nodes)

    if not any(oi_.domain == 'ai.onnx.contrib' for oi_ in model.opset_import):
        new_model.opset_import.extend(
            [helper.make_operatorsetid('ai.onnx.contrib', 1)])

    return new_model


# Module-level side effect, executed once at import time: hand
# _on_pyop_invocation to the native extension. Presumably the native kernels
# call it with (obj_id, inputs) to run Python-implemented ops; confirm against
# the install_hooker C++ binding.
PyCustomOpDef.install_hooker(_on_pyop_invocation)