microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py

392 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5
6from onnx import version_converter
7from typing import List, Tuple, Union
8
9from .utils import (
10 IoMapEntry,
11 IOEntryValuePreserver,
12 create_custom_op_checker_context,
13 sanitize_output_names,
14 TENSOR_TYPE_TO_ONNX_TYPE,
15)
16from .step import Step
17
18
class PrePostProcessor:
    """
    Class to handle running all the pre/post processing steps and updating the model.

    Steps are registered with add_pre_processing/add_post_processing. run() then applies each step, merges the
    resulting graphs with the model graph and returns a new model that has passed onnx.checker.check_model.
    """

    def __init__(self, inputs: Union[List[onnx.ValueInfoProto], None] = None, onnx_opset: int = 16):
        """
        Create a PrePostProcessor instance.

        Args:
            inputs: The inputs the model will use if pre-processing is added.
            onnx_opset: The ONNX opset to use.
                        Minimum is 16. 18 or higher is strongly preferred if image resizing is involved due to its
                        anti-aliasing ability.

        Raises:
            ValueError: If onnx_opset is lower than 16.
        """

        if onnx_opset < 16:
            raise ValueError("ONNX opset must be 16 or later.")

        self._onnx_opset = onnx_opset
        self._custom_op_checker_context = create_custom_op_checker_context(onnx_opset)

        self.pre_processors = []  # type: List[Step]
        self.post_processors = []  # type: List[Step]

        # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors
        self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
        self._post_processor_connections = []  # type: List[List[IoMapEntry]]

        # explicitly join outputs from Steps in pre_processors to inputs of the original model
        # format is Step or step name, step_idx, name of graph input/output
        #
        # Pre-processing we connect Step output to original model:
        #   - step_idx is for Step.output_names, and name is in graph.input
        #
        # Post-processing we connect the original model output to the Step input
        #   - step_idx is for Step.input_names, and name is in graph.output
        #
        # When left as None, run() fills in a default 1:1 join the first time it is called.
        self._pre_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]
        self._post_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]

        self._inputs = inputs if inputs else []

        # Preserve outputs referenced by explicit IoMapEntry connections so follow-up steps do not consume them.
        # This lets one output value have more than one consumer: each IOEntryValuePreserver keeps its value as a
        # graph output until the consumer step has been applied.
        self.outputs_preserver = []  # type: List[IOEntryValuePreserver]

    def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type, a named producer step does not exist, or an IoMapEntry
                        without a producer is given when there is no previous step.
        """
        self.__add_processing(self.pre_processors, self._pre_processor_connections, items)

    def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type, a named producer step does not exist, or an IoMapEntry
                        without a producer is given when there is no previous step.
        """
        self.__add_processing(self.post_processors, self._post_processor_connections, items)

    def _add_connection(self, consumer: Step, entry: IoMapEntry):
        """
        Validate and apply a single connection from entry.producer to consumer.

        A pre-processing producer may feed a later pre-processing step or any post-processing step. A post-processing
        producer may only feed a later post-processing step.

        Raises:
            ValueError: If either step is not registered or the producer/consumer order is invalid.
        """
        producer = self.__producer_from_step_or_str(entry.producer)

        # Black does annoying things with the multi-line 'if' conditions making the code far less readable
        # fmt: off
        if not ((producer in self.pre_processors or producer in self.post_processors) and
                (consumer in self.pre_processors or consumer in self.post_processors)):
            raise ValueError("Producer and Consumer processors must both be registered")

        if producer in self.pre_processors:
            if (consumer in self.pre_processors and
                    self.pre_processors.index(producer) > self.pre_processors.index(consumer)):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        elif producer in self.post_processors:
            if consumer not in self.post_processors:
                raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
            elif self.post_processors.index(producer) > self.post_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        # fmt: on

        assert isinstance(producer, Step)
        consumer.connect(entry)

    def run(self, model: onnx.ModelProto):
        """
        Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.

        Returns:
            model with pre/post processing in it.

        Raises:
            ValueError: If the model does not import the default ONNX domain, if its opset is newer than the opset
                        this instance was created with, if a join names an unknown step, or if a post-processing
                        join refers to a value that is not a model output.

        NOTE: the default pre/post processing joins are computed on the first call and stored on this instance, and
        post-processing steps have their input names updated in place. Later calls reuse that state.
        """

        # update the input model to the ONNX opset we're using. this is required as we implement the steps based on
        # the operator specs for this opset.
        model_opset = next(
            (entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"), None
        )
        if model_opset is None:
            raise ValueError("Model does not import the default ONNX domain so its opset cannot be determined.")

        if model_opset > self._onnx_opset:
            # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model
            # but there are no guarantees.
            # Would only break if ONNX operators used in the pre/post processing graphs have had spec changes.
            raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
        elif model_opset < self._onnx_opset:
            model = onnx.version_converter.convert_version(model, self._onnx_opset)

        def name_nodes(new_graph: onnx.GraphProto, prefix: str):
            # simple helper so all nodes are named. this makes it far easier to debug any issues.
            idx = 0
            for n in new_graph.node:
                if not n.name:
                    n.name = prefix + str(idx)
                    idx += 1

        def preserved_apply(processor: Step, *args):
            # Deactivate preservers whose consumer is this step (it consumes the preserved value), apply the step
            # while keeping all still-active preserved values as graph outputs, then activate preservers whose
            # producer is this step and record the value name it produced.
            for preserver in self.outputs_preserver:
                if preserver.consumer == processor:
                    preserver.is_active = False

            graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
            graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)

            for preserver in self.outputs_preserver:
                if preserver.producer == processor:
                    preserver.is_active = True
                    preserver.output = processor.output_names[preserver.producer_idx]
            return graph_for_step

        def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
            for connection in connections:
                assert connection.producer
                self._add_connection(processor, connection)

            return preserved_apply(processor, graph, self._custom_op_checker_context)

        # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
        if self.post_processors:
            sanitize_output_names(model.graph)

        graph = model.graph
        # add pre-processing
        if self.pre_processors:
            # create empty graph with pass through of the requested input name
            pre_process_graph = onnx.GraphProto()
            for i in self._inputs:
                pre_process_graph.input.append(i)
                pre_process_graph.output.append(i)

            for idx, step in enumerate(self.pre_processors):
                pre_process_graph = connect_and_run(pre_process_graph, step, self._pre_processor_connections[idx])

            # name all the nodes for easier debugging
            name_nodes(pre_process_graph, "pre_process_")

            if not self._pre_processing_joins:
                # default to 1:1 between outputs of last step with inputs of original model
                last_step = self.pre_processors[-1]
                num_entries = min(len(last_step.output_names), len(graph.input))
                self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]

            # map the pre-processing outputs to graph inputs
            # we may need a natty way to get possible outputs after merge_graphs
            step_graph_outputs = [o.name for o in pre_process_graph.output]
            io_map = []  # type: List[Tuple[str, str]]
            for step_ref, step_idx, graph_input in self._pre_processing_joins:
                # joins may reference the step by instance or by name
                join_step = self.__producer_from_step_or_str(step_ref)
                step_output = join_step.output_names[step_idx]
                io_map.append((step_output, graph_input))
                step_graph_outputs.remove(step_output)

            # add outputs from previous IoMapEntry producers to maintain them as graph outputs
            # until consumed by the final Step that requires them.
            step_graph_outputs += [o.name for o in graph.output if o.name not in step_graph_outputs]
            external_outputs = [
                i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs
            ]
            if external_outputs:
                step_graph_outputs.extend(external_outputs)
            graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)

        # add post-processing
        if self.post_processors:
            orig_model_outputs = [o.name for o in model.graph.output]
            graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing

            # create default joins if needed
            if not self._post_processing_joins:
                # default to 1:1 between outputs of original model with inputs of first post-processing step
                first_step = self.post_processors[0]
                num_entries = min(len(first_step.input_names), len(orig_model_outputs))
                self._post_processing_joins = [(first_step, i, orig_model_outputs[i]) for i in range(0, num_entries)]

            # update the input names for the steps to match the values produced by the model
            for step_ref, step_idx, graph_output in self._post_processing_joins:
                if graph_output not in graph_outputs:
                    raise ValueError(
                        f"'{graph_output}' is not an output of the model so it cannot be joined to post-processing."
                    )
                join_step = self.__producer_from_step_or_str(step_ref)
                join_step.input_names[step_idx] = graph_output

            # create empty graph with the values that will be available to the post-processing
            post_process_graph = onnx.GraphProto()
            for o in graph.output:
                post_process_graph.input.append(o)
                post_process_graph.output.append(o)

            for idx, step in enumerate(self.post_processors):
                post_process_graph = connect_and_run(post_process_graph, step, self._post_processor_connections[idx])

            name_nodes(post_process_graph, "post_process_")

            # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
            io_map = [(o, o) for o in graph_outputs]
            graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)

        # Make the output names nicer by removing prefixing from naming that occurred when applying the steps
        graph = PrePostProcessor.__cleanup_graph_output_names(graph)

        # Start from the opsets required by the pre/post processing steps, then add any other domain the original
        # model imports (e.g. contrib op domains) so its nodes remain resolvable. Versions from the checker context
        # take precedence. The default domain always comes from the context as the model was brought to that opset.
        opset_versions = dict(self._custom_op_checker_context.opset_imports.items())
        for entry in model.opset_import:
            if entry.domain == "" or entry.domain == "ai.onnx":
                continue
            opset_versions.setdefault(entry.domain, entry.version)

        opset_imports = [onnx.helper.make_operatorsetid(domain, opset) for domain, opset in opset_versions.items()]
        # find_min_ir_version_for doesn't support custom domains until ONNX 1.14 so extract the ONNX opset from the
        # imports and only pass that in.
        ir_version = onnx.helper.find_min_ir_version_for(
            [entry for entry in opset_imports if entry.domain == "" or entry.domain == "ai.onnx"]
        )
        new_model = onnx.helper.make_model(graph, opset_imports=opset_imports, ir_version=ir_version)

        onnx.checker.check_model(new_model)

        return new_model

    def __add_processing(
        self,
        processors: List[Step],
        processor_connections: List[List[IoMapEntry]],
        items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
    ):
        """
        Add the pre/post processing steps and join with existing steps.

        Args:
            processors: List of processors to add items to.
            processor_connections: Populated with connections between each step. 1:1 with entries in processors.
            items: Items to add to processors.
              Can be:
                A Step instance. This will be implicitly joined to the immediately previous Step if one exists.
                A tuple of (Step instance, list of IoMapEntry)
                  The IoMapEntry values are used to manually join an output from a producer Step to an input
                  of the current Step.
                  In each IoMapEntry, if a step name is provided the producer Step will be searched for in all
                  predecessor steps. It is valid for a post-processor step to consume output from a
                  pre-processor step.

        Raises:
            ValueError: If an item has an unexpected type, a named producer does not exist, or an IoMapEntry has no
                        producer and there is no previous step to infer it from.
        """

        for item in items:
            step = None
            explicit_io_map_entries = None

            if isinstance(item, Step):
                step = item
            elif isinstance(item, tuple):
                step, explicit_io_map_entries = item
            else:
                raise ValueError("Unexpected type " + str(type(item)))

            # start with implicit joins and replace with explicitly provided ones
            # this allows the user to specify the minimum number of manual joins.
            io_map_entries = [None] * len(step.input_names)  # type: List[Union[None,IoMapEntry]]
            prev_step = None if len(processors) == 0 else processors[-1]
            if prev_step:
                # default is connecting as many outputs from the previous step as possible
                for i in range(0, min(len(prev_step.output_names), len(step.input_names))):
                    io_map_entries[i] = IoMapEntry(prev_step, i, i)

            # add explicit connections
            if explicit_io_map_entries:
                for entry in explicit_io_map_entries:
                    if not entry.producer:
                        if prev_step is None:
                            # nothing to infer the producer from. fail now rather than storing a None producer
                            # that would only be detected much later inside run().
                            raise ValueError(
                                f"IoMapEntry for input {entry.consumer_idx} has no producer and there is no previous "
                                "step to infer it from."
                            )
                        producer = prev_step
                    else:
                        producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found

                    io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
                    self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))

            processors.append(step)
            processor_connections.append([entry for entry in io_map_entries if entry is not None])

    def __producer_from_step_or_str(self, entry: Union[Step, str]):
        """
        Resolve a step reference to a Step instance.

        Args:
            entry: A Step instance (returned as-is) or the name of a registered pre- or post-processing step.
                   Pre-processing steps are searched before post-processing steps.

        Raises:
            ValueError: If no registered step has the given name, or entry is neither a Step nor a str.
        """
        if isinstance(entry, Step):
            return entry

        if isinstance(entry, str):
            for candidate in self.pre_processors + self.post_processors:
                if candidate.name == entry:
                    return candidate

            raise ValueError(f"Step named {entry} was not found")

        raise ValueError(f"Unexpected step reference type {type(entry)}")

    @staticmethod
    def __cleanup_graph_output_names(graph: onnx.GraphProto):
        """
        Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
        Not essential but makes the graph outputs look far nicer.

        Each output whose name starts with Step.prefix is routed through an Identity node that produces the
        un-prefixed name. The new name is made unique against every value name already in the graph.

        Returns:
            The graph with renamed outputs, or the input graph unchanged if nothing needed renaming.
        """

        # for each output create identity node to remove prefixing
        io_map = []
        fixes = onnx.GraphProto()

        # manually handle naming clashes. any value name already in the graph must be avoided, otherwise the Identity
        # node would create a duplicate graph output or a second producer for an existing value.
        input_names = {i.name for i in graph.input}
        used_names = set(input_names)
        used_names.update(i.name for i in graph.initializer)
        used_names.update(name for node in graph.node for name in node.output if name)
        used_names.update(o.name for o in graph.output)
        conflicts = 0

        for o in graph.output:
            if not o.name.startswith(Step.prefix):
                continue

            clean_name = o.name
            while clean_name.startswith(Step.prefix):
                # output from last step will have one prefixing stage that adds Step.prefix + '_'
                # e.g. '_ppp8_<orig_name>'
                next_underscore = clean_name.find("_", 1)
                if next_underscore <= 0:
                    # no separator to strip at. stop instead of looping forever on names like '_ppp8'
                    break
                # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
                if len(clean_name) > next_underscore + 1:
                    next_underscore += 1
                clean_name = clean_name[next_underscore:]

            if clean_name == o.name:
                # nothing could be stripped so there is no nicer name to offer
                continue

            # handle things like super resolution where there's an 'image' input and 'image' output
            if clean_name in input_names:
                clean_name += "_out"

            orig_clean_name = clean_name
            while clean_name in used_names:
                conflicts += 1
                clean_name = f"{orig_clean_name}{conflicts}"

            used_names.add(clean_name)

            # we create a small graph to do the renames so the output of the original graph will be an input
            # to that 'fixer' graph
            io_map.append((o.name, o.name))
            renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
            fixes.node.append(renamer)
            fixes.input.append(o)

            new_output = fixes.output.add()
            new_output.name = clean_name
            new_output.type.CopyFrom(o.type)

        # merge if we have any renaming to do
        if io_map:
            graph = onnx.compose.merge_graphs(graph, fixes, io_map)

        return graph