microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code Commits Issues Pull requests Actions Insights Security
1f0c76cefaa9359b9fc74ff4b9ed93214486205a

Branches

Tags

  • No tags available.
0 Branches 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/step.py

222 lines · mode preview

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import abc
import onnx

from onnx import parser
from typing import List, Optional, Tuple

from .utils import (
    IoMapEntry,
    create_custom_op_checker_context,
    TENSOR_TYPE_TO_ONNX_TYPE,
)


class Step(object):
    """
    Base class for a pre or post processing step.

    Each instance is assigned a step number from a counter shared by all Step subclasses. The number is used to
    build a prefix (``_ppp<step_num>_``) that is added to every value name in the graph created for the step, so
    that merging that graph into the existing graph cannot produce value name clashes.

    Derived classes must implement `_create_graph_for_step`.
    """

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names. A copy is stored because `connect` updates entries in place.
            outputs: List of default output names. A copy is stored; `apply` replaces them with prefixed names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        # Copy so that connect()/apply() never mutate a list owned by (or shared between) callers.
        self.input_names = list(inputs)
        self.output_names = list(outputs)
        self.name = name if name else self.__class__.__name__
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: Mapping from output `entry.producer_idx` of the Step `entry.producer` to input
                `entry.consumer_idx` of this step.
        """
        # Validate the producer type first so the checks below cannot fail with an unrelated AttributeError.
        assert isinstance(entry.producer, Step)
        # Indices are used to subscript the name lists, so they must be strictly less than the list lengths.
        assert len(entry.producer.output_names) > entry.producer_idx
        assert len(self.input_names) > entry.consumer_idx

        self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]

    def apply(self, graph: onnx.GraphProto,
              checker_context: onnx.checker.C.CheckerContext,
              graph_outputs_to_maintain: Optional[List[str]] = None) -> onnx.GraphProto:
        """
        Create a graph for this step and merge it onto the provided graph.

        Args:
            graph: Graph to append this step to. Its outputs are matched by name against this step's inputs.
            checker_context: ONNX checker context used to validate the graph created for this step. Its opset
                import for the default ("") domain is passed to `_create_graph_for_step` as the target opset.
            graph_outputs_to_maintain: Optional list of output names to keep as outputs of the merged graph.
                An output of `graph` that is connected to an input of this step is otherwise consumed by the
                merge and removed from the graph outputs, which would prevent connecting it to subsequent steps.
                This list is generated by IOEntryValuePreserver.

        Returns:
            The merged graph. On return `self.output_names` holds the prefixed output names, so later steps
            connecting to this one match the value names in the merged graph.
        """
        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = {o.name for o in result.output}

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs, f"Output {o} of step {self.name} is missing from the merged graph."

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int) -> onnx.GraphProto:
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.

        Raises:
            NotImplementedError: if a derived class does not override this method.
        """
        # Step does not use abc.ABCMeta, so @abc.abstractmethod is not enforced at instantiation time.
        # Fail clearly here instead of returning None and failing later inside onnx.checker.check_graph.
        raise NotImplementedError(f"{self.__class__.__name__} must implement _create_graph_for_step.")

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
                graph_outputs_to_maintain: Optional[List[str]] = None):
        """
        Merge `second` (this step's already-prefixed graph) onto `first`.

        An output of `first` is connected to an input of `second` when the prefixed output name equals the
        input name. Connected outputs are dropped from the merged graph's outputs unless they are listed in
        `graph_outputs_to_maintain`. All outputs of `second` become outputs of the merged graph.
        """
        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs.
        second_input_names = {i.name for i in second.input}

        io_map = []
        remaining_first_outputs = []
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            if prefixed_output in second_input_names:
                io_map.append((o.name, prefixed_output))
            else:
                remaining_first_outputs.append(o.name)

        graph_outputs = remaining_first_outputs + [
            o.name for o in second.output if o.name not in remaining_first_outputs
        ]

        # Re-add outputs that must stay visible for later steps; skip None and avoid duplicate entries.
        for name in graph_outputs_to_maintain or []:
            if name not in graph_outputs:
                graph_outputs.append(name)

        # merge with existing graph
        merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)

        return merged_graph

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Return the ONNX text-syntax type name (looked up in TENSOR_TYPE_TO_ONNX_TYPE) for `elem_type`."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """
        Returns the values from the shape as a comma separated string.

        Fixed dimensions are written as their integer value and symbolic dimensions as their name. A dimension
        with neither is written as '?', the ONNX textual syntax for an unknown dimension. An empty entry
        (e.g. 'float[,3]') cannot be parsed.
        """

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            if dim.HasField("dim_param"):
                return dim.dim_param
            return "?"

        return ",".join(dim_to_str(dim) for dim in shape.dim)

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TypeProto.Tensor:
        """
        Get the tensor type for this step's input `input_num` from the outputs of the graph we're appending to.

        Raises:
            ValueError: if no output of `graph` has the name of the requested input.
        """
        input_name = self.input_names[input_num]
        input_type = next((o.type.tensor_type for o in graph.output if o.name == input_name), None)

        if input_type is None:
            raise ValueError(f"Input {input_name} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        """Return (element type string, comma separated shape string) for this step's input `input_num`."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)


class Debug(Step):
    """
    Step that can be inserted anywhere in the pre or post processing pipeline to expose intermediate values.

    Each of the first `num_inputs` values produced by the previous Step is duplicated with Identity nodes into
    two outputs: one renamed with a "_next" suffix, which feeds the next step, and one renamed with a "_debug"
    suffix, which becomes a graph output so its value can be inspected more easily.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize Debug step
        Args:
            num_inputs: How many values from the previous Step to expose as graph outputs.
            name: Optional name for Step. Defaults to 'Debug'
        """
        self._num_inputs = num_inputs
        indices = range(num_inputs)
        super().__init__([f"input{idx}" for idx in indices], [f"debug{idx}" for idx in indices], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph with two Identity nodes per exposed input (a "_next" and a "_debug" copy).

        Raises:
            ValueError: if more inputs were requested than `graph` has outputs.
        """
        available = len(graph.output)
        if self._num_inputs > available:
            raise ValueError(
                f"Debug step requested {self._num_inputs} inputs, but graph only has {available} outputs.")

        # Derive output names from the current (possibly connected) input names:
        # all pass-through copies first, followed by all debug copies.
        passthrough_names = [f"{name}_next" for name in self.input_names]
        debug_names = [f"{name}_debug" for name in self.input_names]
        self.output_names = passthrough_names + debug_names

        input_decls = []
        output_decls = []
        node_lines = []
        for idx in range(self._num_inputs):
            elem_type, shape = self._get_input_type_and_shape_strs(graph, idx)
            type_str = f"{elem_type}[{shape}]"
            source = self.input_names[idx]

            input_decls.append(f"{type_str} {source}")
            # The pass-through copy is declared/created before the debug copy of the same input.
            for target in (passthrough_names[idx], debug_names[idx]):
                output_decls.append(f"{type_str} {target}")
                node_lines.append(f"{target} = Identity({source})\n")

        # Joined outside the f-string below: expressions inside an f-string can't contain a backslash.
        node_text = "\n".join(node_lines)
        debug_graph = onnx.parser.parse_graph(
            f"""\
            debug ({','.join(input_decls)}) 
                => ({','.join(output_decls)})
            {{
                {node_text}
            }}
            """
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph