microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
natke-patch-1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/step.py

222 lines · mode: code

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
3
4import abc
5import onnx
6
7from onnx import parser
8from typing import List, Optional, Tuple
9
10from .utils import (
11 IoMapEntry,
12 create_custom_op_checker_context,
13 TENSOR_TYPE_TO_ONNX_TYPE,
14)
15
16
class Step(object):
    """
    Base class for a pre or post processing step.

    A Step builds a small ONNX graph for its own processing (``_create_graph_for_step``). ``apply`` validates that
    graph, prefixes every value name in it with a per-instance prefix (``_ppp<step_num>_``) so it cannot clash with
    names already in the existing graph, and merges it onto the existing graph.

    NOTE(review): ``_create_graph_for_step`` is decorated with ``abc.abstractmethod`` but this class does not use
    ``abc.ABCMeta``/``abc.ABC``, so the abstract requirement is not enforced when instantiating.
    """

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names. ``connect`` may overwrite entries of this list in place.
            outputs: List of default output names. ``apply`` replaces it with the prefixed names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        self.input_names = inputs
        self.output_names = outputs
        self.name = name if name else f"{self.__class__.__name__}"
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: Mapping that says output ``entry.producer_idx`` of ``entry.producer`` feeds input
                   ``entry.consumer_idx`` of this step.
        """
        # Check the producer type before reading its attributes.
        assert isinstance(entry.producer, Step)
        # An index must be strictly less than the list length. The previous `len >= idx` form let an index equal
        # to the length pass, which then failed below with an unhelpful IndexError.
        assert entry.producer_idx < len(entry.producer.output_names)
        assert entry.consumer_idx < len(self.input_names)

        self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]

    def apply(self, graph: onnx.GraphProto,
              checker_context: onnx.checker.C.CheckerContext,
              graph_outputs_to_maintain: List[str]):
        """
        Create a graph for this step and merge it onto the provided graph.

        Args:
            graph: Existing graph this step is appended to. Its outputs are the candidate inputs for this step.
            checker_context: ONNX checker context used to validate the graph created for this step. Its
                             default-domain ("") opset is passed to ``_create_graph_for_step`` as the target opset.
            graph_outputs_to_maintain: Output names of ``graph`` that must stay graph outputs after the merge.
                Without this, an output connected to this step is consumed by the merge and can no longer be
                connected to subsequent steps (e.g. an output with multiple consumers). The list is generated by
                IOEntryValuePreserver.

        Returns:
            The merged onnx.GraphProto. ``self.output_names`` is updated to the prefixed names so that later Steps
            connect to the matching values.
        """

        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = [o.name for o in result.output]

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.

        Returns:
            onnx.GraphProto for this step (un-prefixed; ``apply`` adds the prefix).
        """
        pass

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
                graph_outputs_to_maintain: Optional[List[str]] = None):
        """
        Merge ``second`` (this step's prefixed graph) onto ``first`` (the existing graph).

        Every output of ``first`` whose prefixed name matches an input of ``second`` is wired to that input and is
        dropped from the merged outputs, unless it is listed in ``graph_outputs_to_maintain``. The remaining outputs
        of ``first`` and all outputs of ``second`` become the merged graph's outputs.
        """
        if graph_outputs_to_maintain is None:
            # Parameter is Optional: treat None as "nothing extra to keep" rather than failing when iterating it.
            graph_outputs_to_maintain = []

        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
        second_input_names = {i.name for i in second.input}
        io_map = []
        first_output = []  # outputs of `first` that are not consumed by `second`
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            if prefixed_output in second_input_names:
                io_map.append((o.name, prefixed_output))
            else:
                first_output.append(o.name)

        graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output]
        graph_outputs += [o for o in graph_outputs_to_maintain if o not in graph_outputs]

        # merge with existing graph
        merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)

        return merged_graph

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Return the ONNX text-format type name for a tensor element type (lookup in TENSOR_TYPE_TO_ONNX_TYPE)."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """
        Returns the values from the shape as a comma separated string, suitable for the ONNX text format.

        Fixed dims are written as their value, symbolic dims as their name, and dims with neither as ``?``.
        ``?`` is the text-format token for an unknown dim; the previous empty string is rejected by the ONNX parser.
        """

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            elif dim.HasField("dim_param"):
                return dim.dim_param
            else:
                return "?"

        return ",".join(dim_to_str(dim) for dim in shape.dim)

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TypeProto.Tensor:
        """
        Get the tensor type (``TypeProto.Tensor``) of input ``input_num`` from the outputs of the graph we're
        appending to.

        Raises:
            ValueError: if ``self.input_names[input_num]`` is not an output of ``graph``.
        """

        input_type = None
        for o in graph.output:
            if o.name == self.input_names[input_num]:
                input_type = o.type.tensor_type
                break

        if input_type is None:
            raise ValueError(f"Input {self.input_names[input_num]} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        """Return (element type string, comma separated shape string) for input ``input_num`` of this step."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
155
class Debug(Step):
    """
    Step that can be dropped anywhere in the pre or post processing pipeline to expose intermediate values.

    Each of the first ``num_inputs`` outputs of the previous Step is fanned out through two Identity nodes.
    One produces ``<name>_next``, which feeds the next Step. The other produces ``<name>_debug``, which stays
    a graph output so the value can be inspected.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Create a Debug step.

        Args:
            num_inputs: How many outputs of the previous Step to expose as graph outputs.
            name: Optional name for the Step. Defaults to 'Debug' (the class name).
        """
        self._num_inputs = num_inputs
        default_inputs = [f"input{idx}" for idx in range(num_inputs)]
        default_outputs = [f"debug{idx}" for idx in range(num_inputs)]
        super().__init__(default_inputs, default_outputs, name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        available = len(graph.output)
        if self._num_inputs > available:
            raise ValueError(
                f"Debug step requested {self._num_inputs} inputs, but graph only has {available} outputs.")

        # Derive output names from the current (possibly connected) input names.
        # Layout of self.output_names: all "_next" names first, then all "_debug" names.
        source_names = self.input_names
        next_names = [f"{src}_next" for src in source_names]
        debug_names = [f"{src}_debug" for src in source_names]
        self.output_names = next_names + debug_names

        graph_inputs = []
        graph_outputs = []
        node_lines = []
        for idx in range(self._num_inputs):
            type_str, shape_str = self._get_input_type_and_shape_strs(graph, idx)
            tensor_decl = f"{type_str}[{shape_str}]"
            src = source_names[idx]

            graph_inputs.append(f"{tensor_decl} {src}")
            # graph outputs are interleaved per input: <name>_next, <name>_debug
            graph_outputs.extend([f"{tensor_decl} {next_names[idx]}",
                                  f"{tensor_decl} {debug_names[idx]}"])
            node_lines.extend([f"{next_names[idx]} = Identity({src})\n",
                               f"{debug_names[idx]} = Identity({src})\n"])

        # f-string can't have back-slash
        node_str = '\n'.join(node_lines)
        debug_graph = onnx.parser.parse_graph(
            f"""\
            debug ({','.join(graph_inputs)})
                => ({','.join(graph_outputs)})
            {{
                {node_str}
            }}
            """
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph
223