microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/tools/pre_post_processing/step.py
214 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import abc |
| 5 | import onnx |
| 6 | |
| 7 | from onnx import parser |
| 8 | from typing import List, Optional, Tuple |
| 9 | |
| 10 | from .utils import ( |
| 11 | IoMapEntry, |
| 12 | create_custom_op_checker_context, |
| 13 | TENSOR_TYPE_TO_ONNX_TYPE, |
| 14 | ) |
| 15 | |
| 16 | |
class Step(object):
    """
    Base class for a pre or post processing step.

    Each instance takes the next value of the class-level `_step_num` counter and uses it to build a
    unique name prefix (`_ppp<N>_`). `apply` adds that prefix to every value in the graph created for the
    step so names cannot clash with the graph the step is merged into.

    NOTE: `_create_graph_for_step` is decorated with `abc.abstractmethod`, but as this class does not use
    `abc.ABCMeta` the decorator is informational only. Derived classes are still expected to override it.
    """

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names.
            outputs: List of default output names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        # copy the lists. `connect` and `apply` rewrite the names, and that must not leak into a list
        # the caller still holds (or shares between steps).
        self.input_names = list(inputs)
        self.output_names = list(outputs)
        self.name = name if name else self.__class__.__name__
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: Connection to make. `entry.producer` is the Step providing the value, `entry.producer_idx`
                   selects one of its outputs and `entry.consumer_idx` selects the input of this step to rename.
        """
        producer = entry.producer

        # check the type first so a bad producer is reported as such and not as a missing attribute.
        assert isinstance(producer, Step), "IoMapEntry.producer must be a Step instance."

        # an index is only usable if it is strictly less than the number of names.
        assert len(producer.output_names) > entry.producer_idx, (
            f"producer_idx {entry.producer_idx} is out of range for the "
            f"{len(producer.output_names)} output(s) of step '{producer.name}'."
        )
        assert len(self.input_names) > entry.consumer_idx, (
            f"consumer_idx {entry.consumer_idx} is out of range for the "
            f"{len(self.input_names)} input(s) of step '{self.name}'."
        )

        self.input_names[entry.consumer_idx] = producer.output_names[entry.producer_idx]

    def apply(self, graph: onnx.GraphProto, checker_context: onnx.checker.C.CheckerContext):
        """
        Create the graph for this step, prefix its value names, and merge it with the provided graph.

        As a side effect `self.output_names` is updated to the prefixed names so that Steps connected later
        refer to the names used in the merged graph.

        Args:
            graph: Graph to append this step to.
            checker_context: Context used to validate the graph for the step. The ONNX opset is read from
                             the default domain ("") entry of its opset imports.

        Returns:
            The GraphProto resulting from merging `graph` with the graph for this step.
        """

        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = {o.name for o in result.output}

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs, f"Output {o} of step '{self.name}' is missing from the merged graph."

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.
        """

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto):
        """
        Merge `second` (the prefixed graph for this step) into `first` and return the merged GraphProto.

        An output of `first` is connected to an input of `second` when the input name equals the output name
        with this step's prefix applied.
        """
        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
        second_input_names = {i.name for i in second.input}
        io_map = []
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            if prefixed_output in second_input_names:
                io_map.append((o.name, prefixed_output))

        outputs_to_preserve = None

        # special handling of Debug class.
        if isinstance(self, Debug):
            # preserve outputs of the first graph so they're available downstream. otherwise they are consumed by
            # the Debug node and disappear during the ONNX graph_merge as it considers consumed values to be
            # internal - which is entirely reasonable when merging graphs.
            # the issue we have is that we don't know what future steps might want things to remain as outputs.
            # the current approach is to insert a Debug step which simply duplicates the values so that they are
            # guaranteed not to be consumed (only one of the two copies will be used).
            # doesn't change the number of outputs from the previous step, so it can be transparently inserted in the
            # pre/post processing pipeline.
            # need to also list the second graph's outputs when manually specifying outputs.
            outputs_to_preserve = [o.name for o in first.output] + [o.name for o in second.output]

        # merge with existing graph
        return onnx.compose.merge_graphs(first, second, io_map, outputs=outputs_to_preserve)

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Look up the type string for `elem_type` in TENSOR_TYPE_TO_ONNX_TYPE."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """
        Returns the values from the shape as a comma separated string.

        A dim with a `dim_value` contributes the value, a dim with a `dim_param` contributes the symbolic name,
        and a dim with neither contributes an empty string.
        """

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            if dim.HasField("dim_param"):
                return dim.dim_param
            return ""

        return ",".join(dim_to_str(dim) for dim in shape.dim)

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TypeProto.Tensor:
        """
        Get the tensor type info for the input from the outputs of the graph we're appending to.

        Args:
            graph: Graph whose outputs are searched for a value named `self.input_names[input_num]`.
            input_num: Index of the input of this step.

        Returns:
            The `type.tensor_type` of the first matching graph output.

        Raises:
            ValueError: No output of `graph` has the input's name.
        """
        input_name = self.input_names[input_num]

        input_type = None
        for o in graph.output:
            if o.name == input_name:
                input_type = o.type.tensor_type
                break

        # compare against None explicitly rather than relying on the truthiness of a protobuf message
        if input_type is None:
            raise ValueError(f"Input {input_name} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        """Return the (element type string, comma separated shape string) pair for the given input."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
| 156 | |
# special case. we include the helper Debug step here as logic in the base class is conditional on it.
class Debug(Step):
    """
    Step that can be arbitrarily inserted in the pre or post processing pipeline.
    It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.

    Each input is copied by an Identity node to an output named `<input name>_debug`.

    NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
          may or may not have '_debug' as a suffix.
    TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
          to rename so it's more consistent.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize Debug step
        Args:
            num_inputs: Number of inputs from previous Step to make graph outputs.
            name: Optional name for Step. Defaults to 'Debug'
        """
        self._num_inputs = num_inputs
        input_names = [f"input{i}" for i in range(num_inputs)]
        output_names = [f"debug{i}" for i in range(num_inputs)]

        super().__init__(input_names, output_names, name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph with one Identity node per input.

        Args:
            graph: Graph the step will be appended to. Provides the type and shape of each input.
            onnx_opset: The ONNX opset being targeted. Not used by this step.

        Returns:
            GraphProto named 'debug' whose outputs have the same types and shapes as its inputs.
        """
        # update output names so we preserve info from the latest input names
        self.output_names = [f"{name}_debug" for name in self.input_names]

        input_decls = []
        output_decls = []
        nodes = []

        for i in range(self._num_inputs):
            input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i)
            # the input and its debug copy share the same type and shape
            type_and_shape = f"{input_type_str}[{input_shape_str}]"

            input_decls.append(f"{type_and_shape} {self.input_names[i]}")
            output_decls.append(f"{type_and_shape} {self.output_names[i]}")
            nodes.append(f"{self.output_names[i]} = Identity({self.input_names[i]})")

        input_str = ", ".join(input_decls)
        output_str = ", ".join(output_decls)
        nodes_str = "\n".join(nodes)

        debug_graph = onnx.parser.parse_graph(
            f"""\
            debug ({input_str})
                => ({output_str})
            {{
                {nodes_str}
            }}
            """
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph
| 215 | |