microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/tools/pre_post_processing/step.py
222 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import abc |
| 5 | import onnx |
| 6 | |
| 7 | from onnx import parser |
| 8 | from typing import List, Optional, Tuple |
| 9 | |
| 10 | from .utils import ( |
| 11 | IoMapEntry, |
| 12 | create_custom_op_checker_context, |
| 13 | TENSOR_TYPE_TO_ONNX_TYPE, |
| 14 | ) |
| 15 | |
| 16 | |
class Step(object):
    """
    Base class for a pre or post processing step.

    A derived class builds a small ONNX graph for the step in `_create_graph_for_step`. `apply` prefixes every
    value name in that graph with a per-instance prefix (e.g. "_ppp0_") so it cannot clash with names in the graph
    being extended, and then merges the two graphs.
    """

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names.
            outputs: List of default output names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        self.input_names = inputs
        self.output_names = outputs
        self.name = name if name else f"{self.__class__.__name__}"
        # Prefix applied to every value name in the graph this step creates, e.g. "_ppp0_" for the first Step.
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: Describes which output of `entry.producer` (index `entry.producer_idx`) feeds which input of
                   this step (index `entry.consumer_idx`).
        """
        # Validate the producer type first so the following checks can safely read its output_names.
        assert isinstance(entry.producer, Step)
        # Indices must address an existing name: a 0-based index has to be strictly less than the list length.
        assert entry.producer_idx < len(entry.producer.output_names)
        assert entry.consumer_idx < len(self.input_names)

        self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]

    def apply(self, graph: onnx.GraphProto,
              checker_context: onnx.checker.C.CheckerContext,
              graph_outputs_to_maintain: List[str]):
        """
        Create a graph for this step that can be appended to the provided graph.
        The PrePostProcessor will handle merging the two.

        Args:
            graph: Graph to append this step to.
            checker_context: Checker context used to validate the step graph. Its opset import for the default
                             ("") domain is passed to `_create_graph_for_step`.
            graph_outputs_to_maintain: List of output names to maintain in the graph by additional effort.
                For outputs having multiple consumers, these outputs will be consumed by default and prevent
                connection from the subsequent steps.
                These outputs are generated by IOEntryValuePreserver.

        Returns:
            The merged GraphProto. As a side effect, self.output_names is updated to the prefixed names.
        """

        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = {o.name for o in result.output}

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.

        Raises:
            NotImplementedError: If a derived class does not override this method. Step does not derive from
                                 abc.ABC, so the abstractmethod decorator alone is not enforced.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement _create_graph_for_step")

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
                graph_outputs_to_maintain: Optional[List[str]] = None):
        """
        Merge the (already prefixed) step graph `second` onto `first`.

        Each output of `first` whose prefixed name matches an input of `second` is connected to that input and,
        by default, is no longer a graph output. Names in `graph_outputs_to_maintain` are kept as graph outputs
        regardless.
        """
        if graph_outputs_to_maintain is None:
            graph_outputs_to_maintain = []

        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
        second_input_names = {i.name for i in second.input}
        first_output = [o.name for o in first.output]
        io_map = []
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            if prefixed_output in second_input_names:
                io_map.append((o.name, prefixed_output))
                first_output.remove(o.name)

        graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output]
        for o in graph_outputs_to_maintain:
            # check against the growing list so duplicate names in graph_outputs_to_maintain are only added once
            if o not in graph_outputs:
                graph_outputs.append(o)

        # merge with existing graph
        merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)

        return merged_graph

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Map an ONNX tensor element type enum value to its type name via TENSOR_TYPE_TO_ONNX_TYPE."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """Returns the values from the shape as a comma separated string."""

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            elif dim.HasField("dim_param"):
                return dim.dim_param
            else:
                # NOTE(review): a dim with neither value nor param becomes an empty entry; confirm the ONNX text
                # parser used by derived steps accepts that (it may expect "?").
                return ""

        shape_str = ",".join([dim_to_str(dim) for dim in shape.dim])
        return shape_str

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TypeProto.Tensor:
        """
        Get the tensor type of this step's input `input_num` from the outputs of the graph we're appending to.

        Raises:
            ValueError: If no graph output has the input's name.
        """
        input_name = self.input_names[input_num]
        input_type = next((o.type.tensor_type for o in graph.output if o.name == input_name), None)

        # explicit None check: an existing-but-empty protobuf message must not be mistaken for "not found"
        if input_type is None:
            raise ValueError(f"Input {self.input_names[input_num]} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        """Return (element type name, comma separated shape) for input `input_num`, for use in ONNX text syntax."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
| 154 | |
| 155 | |
class Debug(Step):
    """
    Pass-through Step that exposes the outputs of the previous Step as extra graph outputs for debugging.

    Every one of the first `num_inputs` values is copied twice with Identity nodes. The copy with a "_next" suffix
    is what subsequent Steps consume. The copy with a "_debug" suffix becomes a graph output so the intermediate
    value can be inspected.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize the Debug step.

        Args:
            num_inputs: How many outputs of the previous Step to expose as graph outputs.
            name: Optional name for the Step. Defaults to 'Debug'.
        """
        self._num_inputs = num_inputs
        super().__init__([f"input{idx}" for idx in range(num_inputs)],
                         [f"debug{idx}" for idx in range(num_inputs)],
                         name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build the graph of Identity nodes that duplicates each input into "_next" and "_debug" outputs.

        Raises:
            ValueError: If more inputs are requested than the graph has outputs.
        """
        available = len(graph.output)
        if self._num_inputs > available:
            raise ValueError(
                f"Debug step requested {self._num_inputs} inputs, but graph only has {available} outputs.")

        sources = self.input_names
        next_names = [f"{src}_next" for src in sources]
        debug_names = [f"{src}_debug" for src in sources]
        # Derive the output names from the current (possibly connected) input names: all "_next" names first,
        # then all "_debug" names.
        self.output_names = next_names + debug_names

        declared_inputs = []
        declared_outputs = []
        node_lines = []
        for idx in range(self._num_inputs):
            elem_type, shape = self._get_input_type_and_shape_strs(graph, idx)
            tensor_decl = f"{elem_type}[{shape}]"
            src = sources[idx]

            declared_inputs.append(f"{tensor_decl} {src}")
            declared_outputs.extend([f"{tensor_decl} {next_names[idx]}",
                                     f"{tensor_decl} {debug_names[idx]}"])
            node_lines.extend([f"{next_names[idx]} = Identity({src})\n",
                               f"{debug_names[idx]} = Identity({src})\n"])

        # Before Python 3.12 an f-string expression cannot contain a backslash, so join the node lines here.
        node_block = '\n'.join(node_lines)
        graph_text = f"""\
debug ({','.join(declared_inputs)})
=> ({','.join(declared_outputs)})
{{
{node_block}
}}
"""
        debug_graph = onnx.parser.parse_graph(graph_text)

        onnx.checker.check_graph(debug_graph)
        return debug_graph
| 223 | |