microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py
388 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import onnx |
| 5 | |
| 6 | from onnx import version_converter |
| 7 | from typing import List, Tuple, Union |
| 8 | |
| 9 | from .utils import ( |
| 10 | IoMapEntry, |
| 11 | IOEntryValuePreserver, |
| 12 | create_custom_op_checker_context, |
| 13 | sanitize_output_names, |
| 14 | TENSOR_TYPE_TO_ONNX_TYPE, |
| 15 | ) |
| 16 | from .step import Step |
| 17 | |
| 18 | |
class PrePostProcessor:
    """
    Class to handle running all the pre/post processing steps and updating the model.

    Typical usage is to create an instance, register Steps with `add_pre_processing` and/or
    `add_post_processing`, and then call `run` with the model to augment.
    """

    def __init__(self, inputs: Union[List[onnx.ValueInfoProto], None] = None, onnx_opset: int = 16):
        """
        Create a PrePostProcessor instance.

        Args:
            inputs: The inputs the model will use if pre-processing is added.
            onnx_opset: The ONNX opset to use.
                        Minimum is 16. 18 or higher is strongly preferred if image resizing is involved due to its
                        anti-aliasing ability.

        Raises:
            ValueError: If onnx_opset is less than 16.
        """

        if onnx_opset < 16:
            raise ValueError("ONNX opset must be 16 or later.")

        self._onnx_opset = onnx_opset
        self._custom_op_checker_context = create_custom_op_checker_context(onnx_opset)

        self.pre_processors = []
        self.post_processors = []

        # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors
        self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
        self._post_processor_connections = []  # type: List[List[IoMapEntry]]

        # explicitly join outputs from Steps in pre_processors to inputs of the original model
        # format is Step or step name, step_idx, name of graph input/output
        #
        # Pre-processing we connect Step output to original model:
        #   - step_idx is for Step.output_names, and name is in graph.input
        #
        # Post-processing we connect the original model output to the Step input
        #   - step_idx is for Step.input_names, and name is in graph.output
        #
        # When left as None, `run` computes default 1:1 joins for the model it is given.
        self._pre_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]
        self._post_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]

        self._inputs = inputs if inputs else []

        # Preserve outputs named in an IoMapEntry so they are not consumed by the steps that follow.
        # With IOEntryValuePreserver an output value can have more than one consumer: the preserver keeps the
        # value as a graph output until the consumer step has been applied.
        self.outputs_preserver = []  # type: List[IOEntryValuePreserver]

    def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type or a producer cannot be resolved.
        """
        self.__add_processing(self.pre_processors, self._pre_processor_connections, items)

    def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type or a producer cannot be resolved.
        """
        self.__add_processing(self.post_processors, self._post_processor_connections, items)

    def _add_connection(self, consumer: Step, entry: IoMapEntry):
        """
        Validate the producer/consumer ordering for `entry` and connect it to `consumer`.

        Args:
            consumer: Step that consumes the value described by `entry`.
            entry: Connection to add. `entry.producer` may be a Step or the name of a registered Step.

        Raises:
            ValueError: If either Step is not registered, if the producer was registered after the consumer, or if
                        a post-processing producer is connected to a pre-processing consumer.
        """
        producer = self.__producer_from_step_or_str(entry.producer)

        # Black does annoying things with the multi-line 'if' conditions making the code far less readable
        # fmt: off
        if not ((producer in self.pre_processors or producer in self.post_processors) and
                (consumer in self.pre_processors or consumer in self.post_processors)):
            raise ValueError("Producer and Consumer processors must both be registered")

        if producer in self.pre_processors:
            # a pre-processing producer may feed any post-processing consumer, so only the ordering within the
            # pre-processing pipeline needs to be checked.
            if (consumer in self.pre_processors and
                    self.pre_processors.index(producer) > self.pre_processors.index(consumer)):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        elif producer in self.post_processors:
            if consumer not in self.post_processors:
                raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
            elif self.post_processors.index(producer) > self.post_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        # fmt: on

        consumer.connect(entry)

    def run(self, model: onnx.ModelProto):
        """
        Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.

        Returns:
            model with pre/post processing in it.

        Raises:
            ValueError: If the model does not import the default ONNX domain, if the model opset is newer than the
                        opset this instance was created with, or if a connection or join is invalid.
        """

        # update the input model to the ONNX opset we're using. this is required as we implement the steps based on
        # the operator specs for this opset.
        model_opset = next(
            (entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"),
            None,
        )

        if model_opset is None:
            raise ValueError("Model does not have an opset import for the default ONNX domain.")

        if model_opset > self._onnx_opset:
            # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model
            # but there are no guarantees.
            # Would only break if ONNX operators used in the pre/post processing graphs have had spec changes.
            raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
        elif model_opset < self._onnx_opset:
            model = onnx.version_converter.convert_version(model, self._onnx_opset)

        def name_nodes(new_graph: onnx.GraphProto, prefix: str):
            # simple helper so all nodes are named. this makes it far easier to debug any issues.
            idx = 0
            for n in new_graph.node:
                if not n.name:
                    n.name = prefix + str(idx)
                    idx += 1

        def preserved_apply(processor: Step, *args):
            # Apply `processor` while managing the IOEntryValuePreserver entries: a preserved output is released
            # once its consumer is being applied, and becomes active once its producer has been applied.

            for preserver in self.outputs_preserver:
                if preserver.consumer == processor:
                    preserver.is_active = False

            # IOEntryValuePreserver keeps outputs that have multiple consumers.
            # we explicitly add those outputs to the graph output.
            graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
            graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)

            for preserver in self.outputs_preserver:
                if preserver.producer == processor:
                    preserver.is_active = True
                    preserver.output = processor.output_names[preserver.producer_idx]
            return graph_for_step

        def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
            for connection in connections:
                # explicit check rather than `assert` so it is not stripped when running with -O
                if not connection.producer:
                    raise ValueError("Connection does not have a producer")
                self._add_connection(processor, connection)

            return preserved_apply(processor, graph, self._custom_op_checker_context)

        # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
        if self.post_processors:
            sanitize_output_names(model.graph)

        graph = model.graph
        # add pre-processing
        if self.pre_processors:
            # create empty graph with pass through of the requested input name
            pre_process_graph = onnx.GraphProto()
            for i in self._inputs:
                pre_process_graph.input.append(i)
                pre_process_graph.output.append(i)

            for idx, step in enumerate(self.pre_processors):
                pre_process_graph = connect_and_run(pre_process_graph, step, self._pre_processor_connections[idx])

            # name all the nodes for easier debugging
            name_nodes(pre_process_graph, "pre_process_")

            # Default joins depend on the input names of this particular model, so they are kept local instead of
            # being stored on the instance where they would be reused for a different model.
            pre_processing_joins = self._pre_processing_joins
            if not pre_processing_joins:
                # default to 1:1 between outputs of last step with inputs of original model
                last_step = self.pre_processors[-1]
                num_entries = min(len(last_step.output_names), len(graph.input))
                pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]

            # map the pre-processing outputs to graph inputs
            # we may need a neater way to get the possible outputs after merge_graphs
            step_graph_outputs = [o.name for o in pre_process_graph.output]
            io_map = []  # type: List[Tuple[str, str]]
            for step, step_idx, graph_input in pre_processing_joins:
                # a join may refer to the Step by name
                producer_step = self.__producer_from_step_or_str(step)
                step_output = producer_step.output_names[step_idx]
                io_map.append((step_output, graph_input))
                step_graph_outputs.remove(step_output)

            # add outputs from previous IoMapEntry producers to maintain them as graph outputs
            # until consumed by the final Step that requires them.
            step_graph_outputs += [o.name for o in graph.output if o.name not in step_graph_outputs]
            external_outputs = [
                i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs
            ]
            if external_outputs:
                step_graph_outputs.extend(external_outputs)
            graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)

        # add post-processing
        if self.post_processors:
            orig_model_outputs = [o.name for o in model.graph.output]
            graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing

            # create default joins if needed. kept local as they depend on the output names of this model.
            post_processing_joins = self._post_processing_joins
            if not post_processing_joins:
                # default to 1:1 between outputs of original model with inputs of first post-processing step
                first_step = self.post_processors[0]
                num_entries = min(len(first_step.input_names), len(orig_model_outputs))
                post_processing_joins = [(first_step, i, orig_model_outputs[i]) for i in range(0, num_entries)]

            # update the input names for the steps to match the values produced by the model
            for step, step_idx, graph_output in post_processing_joins:
                if graph_output not in graph_outputs:
                    raise ValueError(f"Post-processing join uses '{graph_output}' which is not a graph output")

                # a join may refer to the Step by name
                consumer_step = self.__producer_from_step_or_str(step)
                consumer_step.input_names[step_idx] = graph_output

            # create empty graph with the values that will be available to the post-processing
            post_process_graph = onnx.GraphProto()
            for o in graph.output:
                post_process_graph.input.append(o)
                post_process_graph.output.append(o)

            for idx, step in enumerate(self.post_processors):
                post_process_graph = connect_and_run(post_process_graph, step, self._post_processor_connections[idx])

            name_nodes(post_process_graph, "post_process_")

            # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
            io_map = [(o, o) for o in graph_outputs]
            graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)

        # Make the output names nicer by removing prefixing from naming that occurred when applying the steps
        graph = PrePostProcessor.__cleanup_graph_output_names(graph)

        opset_imports = [
            onnx.helper.make_operatorsetid(domain, opset)
            for domain, opset in self._custom_op_checker_context.opset_imports.items()
        ]
        new_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

        onnx.checker.check_model(new_model)

        return new_model

    def __add_processing(
        self,
        processors: List[Step],
        processor_connections: List[List[IoMapEntry]],
        items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
    ):
        """
        Add the pre/post processing steps and join with existing steps.

        Args:
            processors: List of processors to add items to.
            processor_connections: Populated with connections between each step. 1:1 with entries in processors.
            items: Items to add to processors.
                   Can be:
                     A Step instance. This will be implicitly joined to the immediately previous Step if one exists.
                     A tuple of (Step instance, list of IoMapEntry)
                      The IoMapEntry values are used to manually join an output from a producer Step to an input
                      of the current Step.
                        In each IoMapEntry, if a step name is provided the producer Step will be searched for in all
                        predecessor steps. It is valid for a post-processor step to consume output from a
                        pre-processor step.

        Raises:
            ValueError: If an item is neither a Step nor a tuple, if a named producer is not found, or if an
                        IoMapEntry has no producer and there is no previous Step to infer it from.
        """

        for item in items:
            explicit_io_map_entries = None

            if isinstance(item, Step):
                step = item
            elif isinstance(item, tuple):
                step, explicit_io_map_entries = item
            else:
                raise ValueError("Unexpected type " + str(type(item)))

            # start with implicit joins and replace with explicitly provided ones
            # this allows the user to specify the minimum number of manual joins.
            io_map_entries = [None] * len(step.input_names)  # type: List[Union[None,IoMapEntry]]
            prev_step = processors[-1] if processors else None
            if prev_step:
                # default is connecting as many outputs from the previous step as possible
                for i in range(0, min(len(prev_step.output_names), len(step.input_names))):
                    io_map_entries[i] = IoMapEntry(prev_step, i, i)

            # add explicit connections
            if explicit_io_map_entries:
                for entry in explicit_io_map_entries:
                    if not entry.producer:
                        producer = prev_step
                    else:
                        producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found

                    if not producer:
                        # fail here with a clear message instead of storing a connection that cannot be
                        # resolved when the pipeline is run.
                        raise ValueError("IoMapEntry.producer must be provided as there is no previous Step")

                    io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
                    self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))

            processors.append(step)
            processor_connections.append([entry for entry in io_map_entries if entry is not None])

    def __producer_from_step_or_str(self, entry: Union[Step, str]):
        """
        Resolve `entry` to a Step instance.

        Args:
            entry: A Step, which is returned as-is, or the name of a registered Step. Pre-processing steps are
                   searched before post-processing steps.

        Returns:
            The matching Step.

        Raises:
            ValueError: If no registered Step has the given name, or `entry` is neither a Step nor a str.
        """
        if isinstance(entry, Step):
            return entry

        if isinstance(entry, str):
            match = (next((s for s in self.pre_processors if s.name == entry), None) or
                     next((s for s in self.post_processors if s.name == entry), None))  # fmt: skip

            if not match:
                raise ValueError(f"Step named {entry} was not found")

            return match

        # previously fell through and returned None, deferring the failure to the caller
        raise ValueError("Producer must be a Step or step name. Unexpected type " + str(type(entry)))

    @staticmethod
    def __cleanup_graph_output_names(graph: onnx.GraphProto):
        """
        Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
        Not essential but makes the graph outputs look far nicer.

        Args:
            graph: Graph whose outputs may have names starting with Step.prefix.

        Returns:
            `graph` if no output needed renaming, otherwise a new graph where each prefixed output is passed
            through an Identity node that produces the cleaned name.
        """

        # for each output create identity node to remove prefixing
        io_map = []
        fixes = onnx.GraphProto()

        # manually handle naming clashes
        input_names = {i.name for i in graph.input}
        used_names = set(input_names)
        conflicts = 0

        for o in graph.output:
            if not o.name.startswith(Step.prefix):
                continue

            clean_name = o.name
            while clean_name.startswith(Step.prefix):
                # output from last step will have one prefixing stage that adds Step._prefix + '_'
                # e.g. '_ppp8_<orig_name>'
                next_underscore = clean_name.find("_", 1)
                if next_underscore < 0:
                    # no separator after the prefix so there is nothing that can be stripped.
                    break

                # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
                if len(clean_name) > next_underscore + 1:
                    next_underscore += 1

                clean_name = clean_name[next_underscore:]

            if clean_name == o.name:
                # nothing was stripped. an Identity node from a name to itself would not be valid.
                continue

            # we will create a small graph to do the renames so the output of the original graph will be an input
            # to that 'fixer' graph
            io_map.append((o.name, o.name))

            # handle things like super resolution where there's an 'image' input and 'image' output
            if clean_name in input_names:
                clean_name += "_out"

            orig_clean_name = clean_name
            while clean_name in used_names:
                conflicts += 1
                clean_name = f"{orig_clean_name}{conflicts}"

            used_names.add(clean_name)

            renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
            fixes.node.append(renamer)
            fixes.input.append(o)

            new_output = fixes.output.add()
            new_output.name = clean_name
            new_output.type.CopyFrom(o.type)

        # merge if we have any renaming to do
        if io_map:
            graph = onnx.compose.merge_graphs(graph, fixes, io_map)

        return graph
| 389 | |