microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
edgchen1/fix_ci

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/step.py

214lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import abc
5import onnx
6
7from onnx import parser
8from typing import List, Optional, Tuple
9
10from .utils import (
11 IoMapEntry,
12 create_custom_op_checker_context,
13 TENSOR_TYPE_TO_ONNX_TYPE,
14)
15
16
class Step(object):
    """
    Base class for a pre or post processing step.

    Each instance receives a unique, monotonically increasing step number. That number is used to build a
    prefix which is applied to every value name in the graph created for the step, so that names from
    different steps cannot clash when the graphs are merged.

    Derived classes implement `_create_graph_for_step`.
    """

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names.
            outputs: List of default output names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        # Take copies: `connect` and `apply` rewrite these lists, and that must not leak into the
        # lists owned by the caller.
        self.input_names = list(inputs)
        self.output_names = list(outputs)
        self.name = name or type(self).__name__
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: "IoMapEntry"):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: Mapping whose `producer` is the Step producing the value, `producer_idx` the index into
                   the producer's outputs, and `consumer_idx` the index into this step's inputs.

        Raises:
            AssertionError: if the producer is not a Step or either index is past the end of its list.
        """
        producer = entry.producer
        # validate the producer type before touching its attributes
        assert isinstance(producer, Step), "IoMapEntry.producer must be a Step"
        # an index equal to the length is already out of range, so the comparison must be strict
        assert entry.producer_idx < len(producer.output_names), (
            f"producer_idx {entry.producer_idx} is out of range for the "
            f"{len(producer.output_names)} output(s) of {producer.name}"
        )
        assert entry.consumer_idx < len(self.input_names), (
            f"consumer_idx {entry.consumer_idx} is out of range for the "
            f"{len(self.input_names)} input(s) of {self.name}"
        )

        self.input_names[entry.consumer_idx] = producer.output_names[entry.producer_idx]

    def apply(self, graph: "onnx.GraphProto", checker_context: "onnx.checker.C.CheckerContext"):
        """
        Create a graph for this step that can be appended to the provided graph.
        The PrePostProcessor will handle merging the two.

        Args:
            graph: Graph built so far. Its outputs are the values this step may consume.
            checker_context: Checker context used to validate the step's graph. The entry for the default
                             domain ("") in its opset imports is the ONNX opset passed to the derived class.

        Returns:
            The graph resulting from merging `graph` with the graph created for this step.
        """

        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = {o.name for o in result.output}

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs, f"Output {o} of step {self.name} is missing from the merged graph"

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: "onnx.GraphProto", onnx_opset: int):
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.
        """
        pass

    def __merge(self, first: "onnx.GraphProto", second: "onnx.GraphProto"):
        """
        Merge the graph for this step (`second`) onto the existing graph (`first`).

        An output of `first` is connected to an input of `second` when the input's name equals the output's
        name with this step's prefix applied.
        """
        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs.
        # Applying the same prefix to an output of the previous step gives the name it would have as an
        # input of the prefixed graph from this step.
        second_inputs = {i.name for i in second.input}
        io_map = [
            (o.name, self._prefix + o.name)
            for o in first.output
            if self._prefix + o.name in second_inputs
        ]

        outputs_to_preserve = None

        # special handling of Debug class.
        if isinstance(self, Debug):
            # preserve outputs of the first graph so they're available downstream. otherwise they are consumed by
            # the Debug node and disappear during the ONNX graph_merge as it considers consumed values to be
            # internal - which is entirely reasonable when merging graphs.
            # the issue we have is that we don't know what future steps might want things to remain as outputs.
            # the current approach is to insert a Debug step which simply duplicates the values so that they are
            # guaranteed not be consumed (only one of the two copies will be used).
            # doesn't change the number of outputs from the previous step, so it can be transparently inserted in the
            # pre/post processing pipeline.
            # need to also list the second graph's outputs when manually specifying outputs.
            outputs_to_preserve = [o.name for o in first.output] + [o.name for o in second.output]

        # merge with existing graph
        return onnx.compose.merge_graphs(first, second, io_map, outputs=outputs_to_preserve)

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Look up `elem_type` in TENSOR_TYPE_TO_ONNX_TYPE and return the mapped value."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: "onnx.TensorShapeProto"):
        """
        Returns the values from the shape as a comma separated string.

        A dimension with neither `dim_value` nor `dim_param` set contributes an empty string.
        """

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            if dim.HasField("dim_param"):
                return dim.dim_param
            return ""

        return ",".join(dim_to_str(dim) for dim in shape.dim)

    def _input_tensor_type(self, graph: "onnx.GraphProto", input_num: int) -> "onnx.TypeProto.Tensor":
        """
        Get the tensor type info (`type.tensor_type`) for the input from the outputs of the graph we're
        appending to.

        Raises:
            ValueError: if no output of `graph` has the name of the requested input.
        """
        wanted = self.input_names[input_num]
        for o in graph.output:
            if o.name == wanted:
                # return on the name match itself instead of relying on the truthiness of the message
                return o.type.tensor_type

        raise ValueError(f"Input {wanted} was not found in outputs of graph.")

    def _get_input_type_and_shape_strs(self, graph: "onnx.GraphProto", input_num: int) -> Tuple[str, str]:
        """Return the element type string and the comma separated shape string for the given input."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
155
156
# Special case: the helper Debug step is defined here because logic in the Step base class is conditional on it.
class Debug(Step):
    """
    Step that can be arbitrarily inserted in the pre or post processing pipeline.
    It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.

    NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
          may or may not have '_debug' as a suffix.
    TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
          to rename so it's more consistent.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize Debug step
        Args:
            num_inputs: Number of inputs from previous Step to make graph outputs.
            name: Optional name for Step. Defaults to 'Debug'
        """
        self._num_inputs = num_inputs
        input_names = [f"input{i}" for i in range(num_inputs)]
        output_names = [f"debug{i}" for i in range(num_inputs)]

        super().__init__(input_names, output_names, name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph with one Identity node per input, copying each input to an output named
        '<input name>_debug' with the same element type and shape.

        Args:
            graph: Graph the step will be appended to. Provides the type and shape of each input.
            onnx_opset: The ONNX opset being targeted. Not used by this step.

        Returns:
            The checked GraphProto parsed from the generated graph text.
        """
        # update output names so we preserve info from the latest input names
        self.output_names = [f"{name}_debug" for name in self.input_names]

        input_decls = []
        output_decls = []
        node_decls = []
        for i in range(self._num_inputs):
            input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i)
            # the debug output is a copy of the input, so both share one type/shape declaration
            typed = f"{input_type_str}[{input_shape_str}]"
            input_decls.append(f"{typed} {self.input_names[i]}")
            output_decls.append(f"{typed} {self.output_names[i]}")
            node_decls.append(f"{self.output_names[i]} = Identity({self.input_names[i]})")

        input_str = ", ".join(input_decls)
        output_str = ", ".join(output_decls)
        nodes_str = "\n".join(node_decls)

        debug_graph = onnx.parser.parse_graph(
            f"""\
debug ({input_str})
=> ({output_str})
{{
{nodes_str}
}}
"""
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph
215