microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py

388lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5
6from onnx import version_converter
7from typing import List, Tuple, Union
8
9from .utils import (
10 IoMapEntry,
11 IOEntryValuePreserver,
12 create_custom_op_checker_context,
13 sanitize_output_names,
14 TENSOR_TYPE_TO_ONNX_TYPE,
15)
16from .step import Step
17
18
class PrePostProcessor:
    """
    Class to handle running all the pre/post processing steps and updating the model.

    Steps are registered with `add_pre_processing` and/or `add_post_processing`, and `run` merges the graph
    produced by each registered Step with the graph of the model passed in.
    """

    def __init__(self, inputs: Union[None, List[onnx.ValueInfoProto]] = None, onnx_opset: int = 16):
        """
        Create a PrePostProcessor instance.

        Args:
            inputs: The inputs the model will use if pre-processing is added.
                    None (or an empty list) means no inputs are declared.
            onnx_opset: The ONNX opset to use.
                        Minimum is 16. 18 or higher is strongly preferred if image resizing is involved due to its
                        anti-aliasing ability.

        Raises:
            ValueError: If onnx_opset is less than 16.
        """

        if onnx_opset < 16:
            raise ValueError("ONNX opset must be 16 or later.")

        self._onnx_opset = onnx_opset
        self._custom_op_checker_context = create_custom_op_checker_context(onnx_opset)

        self.pre_processors = []
        self.post_processors = []

        # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors
        self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
        self._post_processor_connections = []  # type: List[List[IoMapEntry]]

        # explicitly join outputs from Steps in pre_processors to inputs of the original model
        # format is Step or step name, step_idx, name of graph input/output
        #
        # Pre-processing we connect Step output to original model:
        #   - step_idx is for Step.output_names, and name is in graph.input
        #
        # Post-processing we connect the original model output to the Step input
        #   - step_idx is for Step.input_names, and name is in graph.output
        #
        # When left as None, `run` derives default joins from the model it is given.
        self._pre_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]
        self._post_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]

        self._inputs = inputs if inputs else []

        # Preserve outputs referenced by an explicit IoMapEntry so they are not consumed by follow-up steps.
        # This allows an output value to have more than one consumer: IOEntryValuePreserver keeps the output
        # value as a graph output until the consumer step is done.
        self.outputs_preserver = []  # type: List[IOEntryValuePreserver]

    def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type or a producer cannot be resolved.
        """
        self.__add_processing(self.pre_processors, self._pre_processor_connections, items)

    def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

        Raises:
            ValueError: If an item has an unexpected type or a producer cannot be resolved.
        """
        self.__add_processing(self.post_processors, self._post_processor_connections, items)

    def _add_connection(self, consumer: Step, entry: IoMapEntry):
        """
        Validate the producer/consumer ordering for `entry` and connect it to `consumer`.

        Args:
            consumer: Step that consumes the value.
            entry: Connection details. entry.producer may be a Step or the name of a registered Step.

        Raises:
            ValueError: If the producer or consumer is not registered, if the producer was registered after the
                        consumer, or if a post-processing producer is paired with a pre-processing consumer.
        """
        producer = self.__producer_from_step_or_str(entry.producer)

        producer_is_pre = producer in self.pre_processors
        producer_is_post = producer in self.post_processors
        consumer_is_pre = consumer in self.pre_processors
        consumer_is_post = consumer in self.post_processors

        if not ((producer_is_pre or producer_is_post) and (consumer_is_pre or consumer_is_post)):
            raise ValueError("Producer and Consumer processors must both be registered")

        if producer_is_pre:
            # a pre-processing producer may feed a post-processing consumer, so ordering only matters when both
            # steps are in the pre-processing list.
            if consumer_is_pre and self.pre_processors.index(producer) > self.pre_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        else:
            # producer is a post-processor (guaranteed by the registration check above)
            if not consumer_is_post:
                raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
            if self.post_processors.index(producer) > self.post_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")

        assert isinstance(producer, Step)
        consumer.connect(entry)

    def run(self, model: onnx.ModelProto):
        """
        Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.
                   NOTE: if post-processing steps are registered, sanitize_output_names is called on model.graph
                   before merging, so the model passed in may be modified.

        Returns:
            model with pre/post processing in it.

        Raises:
            ValueError: If the model has no opset import for the default ONNX domain, or its opset is newer than
                        the opset this instance was created with.
        """

        # update the input model to the ONNX opset we're using. this is required as we implement the steps based on
        # the operator specs for this opset.
        default_domain_opsets = [
            entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"
        ]
        if not default_domain_opsets:
            raise ValueError("Model does not have an opset import for the default ONNX domain.")

        model_opset = default_domain_opsets[0]

        if model_opset > self._onnx_opset:
            # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model
            # but there are no guarantees.
            # Would only break if ONNX operators used in the pre/post processing graphs have had spec changes.
            raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
        elif model_opset < self._onnx_opset:
            model = onnx.version_converter.convert_version(model, self._onnx_opset)

        def name_nodes(new_graph: onnx.GraphProto, prefix: str):
            # simple helper so all nodes are named. this makes it far easier to debug any issues.
            # the counter only advances for nodes we name, so generated names are contiguous.
            idx = 0
            for n in new_graph.node:
                if not n.name:
                    n.name = prefix + str(idx)
                    idx += 1

        def preserved_apply(processor: Step, *args):
            # Apply `processor`, keeping any preserved outputs alive as graph outputs.

            # values consumed by this step no longer need to be preserved
            for preserver in self.outputs_preserver:
                if preserver.consumer == processor:
                    preserver.is_active = False

            # outputs that still have a pending consumer are explicitly maintained as graph outputs
            graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
            graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)

            # start preserving values produced by this step. the output name is read after apply() has run.
            for preserver in self.outputs_preserver:
                if preserver.producer == processor:
                    preserver.is_active = True
                    preserver.output = processor.output_names[preserver.producer_idx]

            return graph_for_step

        def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
            for connection in connections:
                assert connection.producer
                self._add_connection(processor, connection)

            return preserved_apply(processor, graph, self._custom_op_checker_context)

        # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
        if self.post_processors:
            sanitize_output_names(model.graph)

        graph = model.graph

        # add pre-processing
        if self.pre_processors:
            # create empty graph with pass through of the requested input name
            pre_process_graph = onnx.GraphProto()
            for i in self._inputs:
                pre_process_graph.input.append(i)
                pre_process_graph.output.append(i)

            for idx, step in enumerate(self.pre_processors):
                pre_process_graph = connect_and_run(pre_process_graph, step, self._pre_processor_connections[idx])

            # name all the nodes for easier debugging
            name_nodes(pre_process_graph, "pre_process_")

            # Default joins depend on the input names of this specific model, so they are kept local rather than
            # stored on the instance where they would be stale for a later call with a different model.
            pre_joins = self._pre_processing_joins
            if not pre_joins:
                # default to 1:1 between outputs of last step with inputs of original model
                last_step = self.pre_processors[-1]
                num_entries = min(len(last_step.output_names), len(graph.input))
                pre_joins = [(last_step, i, graph.input[i].name) for i in range(num_entries)]

            # map the pre-processing outputs to graph inputs
            # we may need a natty way to get possible outputs after merge_graphs
            step_graph_outputs = [o.name for o in pre_process_graph.output]
            io_map = []  # type: List[Tuple[str, str]]
            for step, step_idx, graph_input in pre_joins:
                # a join may reference the step by name
                producer = self.__producer_from_step_or_str(step)
                step_output = producer.output_names[step_idx]
                io_map.append((step_output, graph_input))
                # the same step output may be joined to more than one model input, so only remove it once
                if step_output in step_graph_outputs:
                    step_graph_outputs.remove(step_output)

            # add outputs from previous IoMapEntry producers to maintain them as graph outputs
            # until consumed by the final Step that requires them.
            step_graph_outputs += [o.name for o in graph.output if o.name not in step_graph_outputs]
            external_outputs = [
                i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs
            ]
            if external_outputs:
                step_graph_outputs.extend(external_outputs)

            graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)

        # add post-processing
        if self.post_processors:
            orig_model_outputs = [o.name for o in model.graph.output]
            graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing

            # create default joins if needed. kept local for the same reason as the pre-processing joins.
            post_joins = self._post_processing_joins
            if not post_joins:
                # default to 1:1 between outputs of original model with inputs of first post-processing step
                first_step = self.post_processors[0]
                num_entries = min(len(first_step.input_names), len(orig_model_outputs))
                post_joins = [(first_step, i, orig_model_outputs[i]) for i in range(num_entries)]

            # update the input names for the steps to match the values produced by the model
            for step, step_idx, graph_output in post_joins:
                assert graph_output in graph_outputs
                # a join may reference the step by name
                consumer = self.__producer_from_step_or_str(step)
                consumer.input_names[step_idx] = graph_output

            # create empty graph with the values that will be available to the post-processing
            post_process_graph = onnx.GraphProto()
            for o in graph.output:
                post_process_graph.input.append(o)
                post_process_graph.output.append(o)

            for idx, step in enumerate(self.post_processors):
                post_process_graph = connect_and_run(post_process_graph, step, self._post_processor_connections[idx])

            name_nodes(post_process_graph, "post_process_")

            # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
            io_map = [(o, o) for o in graph_outputs]
            graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)

        # Make the output names nicer by removing prefixing from naming that occurred when applying the steps
        graph = PrePostProcessor.__cleanup_graph_output_names(graph)

        opset_imports = [
            onnx.helper.make_operatorsetid(domain, opset)
            for domain, opset in self._custom_op_checker_context.opset_imports.items()
        ]
        new_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

        onnx.checker.check_model(new_model)

        return new_model

    def __add_processing(
        self,
        processors: List[Step],
        processor_connections: List[List[IoMapEntry]],
        items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
    ):
        """
        Add the pre/post processing steps and join with existing steps.

        Args:
            processors: List of processors to add items to.
            processor_connections: Populated with connections between each step. 1:1 with entries in processors.
            items: Items to add to processors.
                   Can be:
                     A Step instance. This will be implicitly joined to the immediately previous Step if one exists.
                     A tuple of (Step instance, list of IoMapEntry)
                      The IoMapEntry values are used to manually join an output from a producer Step to an input
                      of the current Step.
                        In each IoMapEntry, if a step name is provided the producer Step will be searched for in all
                        predecessor steps. It is valid for a post-processor step to consume output from a
                        pre-processor step.

        Raises:
            ValueError: If an item is neither a Step nor a tuple, if a named producer cannot be found, or if an
                        IoMapEntry has no producer and there is no previous Step to infer it from.
        """

        for item in items:
            explicit_io_map_entries = None

            if isinstance(item, Step):
                step = item
            elif isinstance(item, tuple):
                step, explicit_io_map_entries = item
            else:
                raise ValueError("Unexpected type " + str(type(item)))

            # start with implicit joins and replace with explicitly provided ones
            # this allows the user to specify the minimum number of manual joins.
            io_map_entries = [None] * len(step.input_names)  # type: List[Union[None,IoMapEntry]]
            prev_step = processors[-1] if processors else None
            if prev_step:
                # default is connecting as many outputs from the previous step as possible
                for i in range(min(len(prev_step.output_names), len(step.input_names))):
                    io_map_entries[i] = IoMapEntry(prev_step, i, i)

            # add explicit connections
            if explicit_io_map_entries:
                for entry in explicit_io_map_entries:
                    if entry.producer:
                        producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found
                    elif prev_step:
                        producer = prev_step
                    else:
                        # fail at registration instead of creating a connection with no producer, which would
                        # only be detected when the model is processed.
                        raise ValueError(
                            "IoMapEntry has no producer and there is no previous Step to infer the producer from"
                        )

                    io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
                    self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))

            processors.append(step)
            processor_connections.append([entry for entry in io_map_entries if entry is not None])

    def __producer_from_step_or_str(self, entry: Union[Step, str]):
        """
        Resolve `entry` to a Step.

        Args:
            entry: A Step instance, which is returned as-is, or the name of a registered Step. Names are searched
                   for in the pre-processors first and then the post-processors.

        Returns:
            The matching Step.

        Raises:
            ValueError: If no registered Step has the given name, or entry is neither a Step nor a str.
        """
        if isinstance(entry, Step):
            return entry

        if isinstance(entry, str):
            for candidate in self.pre_processors:
                if candidate.name == entry:
                    return candidate

            for candidate in self.post_processors:
                if candidate.name == entry:
                    return candidate

            raise ValueError(f"Step named {entry} was not found")

        raise ValueError("Producer must be a Step or a step name. Unexpected type " + str(type(entry)))

    @staticmethod
    def __cleanup_graph_output_names(graph: onnx.GraphProto):
        """
        Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
        Not essential but makes the graph outputs look far nicer.

        Args:
            graph: Graph whose outputs may have names starting with Step.prefix.

        Returns:
            `graph` unchanged if no output name starts with Step.prefix, otherwise a new graph where each such
            output is passed through an Identity node that produces the cleaned name.
        """

        # for each output create identity node to remove prefixing
        io_map = []
        fixes = onnx.GraphProto()

        # manually handle naming clashes
        input_names = {i.name for i in graph.input}
        used_names = set(input_names)
        conflicts = 0

        for o in graph.output:
            if not o.name.startswith(Step.prefix):
                continue

            # we will create a small graph to do the renames so the output of the original graph will be an input
            # to that 'fixer' graph
            io_map.append((o.name, o.name))
            clean_name = o.name
            while clean_name.startswith(Step.prefix):
                # output from last step will have one prefixing stage that adds Step._prefix + '_'
                # e.g. '_ppp8_<orig_name>'
                next_underscore = clean_name.find("_", 1)
                if next_underscore <= 0:
                    # no separator to strip up to. the name would never change so stop to avoid looping forever.
                    break

                # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
                if len(clean_name) > next_underscore + 1:
                    next_underscore += 1

                clean_name = clean_name[next_underscore:]

            # handle things like super resolution where there's an 'image' input and 'image' output
            if clean_name in input_names:
                clean_name += "_out"

            # the conflict counter is shared across all outputs, so suffixes are unique but not per-name sequential
            orig_clean_name = clean_name
            while clean_name in used_names:
                conflicts += 1
                clean_name = f"{orig_clean_name}{conflicts}"

            used_names.add(clean_name)

            renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
            fixes.node.append(renamer)
            fixes.input.append(o)

            new_output = fixes.output.add()
            new_output.name = clean_name
            new_output.type.CopyFrom(o.type)

        # merge if we have any renaming to do
        if io_map:
            graph = onnx.compose.merge_graphs(graph, fixes, io_map)

        return graph
389