microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/steps/general.py

254 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5from typing import List, Optional
6from ..step import Step
7
8
class ReverseAxis(Step):
    """
    Reverses the data in an axis by splitting and concatenating in reverse order.
    e.g. convert RGB ordered data to BGR.
    Output data type and shape is the same as the input.
    """

    def __init__(self, axis: int = -1, dim_value: int = -1, name: Optional[str] = None):
        """
        Args:
            axis: Axis to reverse. Default is last axis.
            dim_value: Explicit value for size of dimension being reversed.
                       This can be provided if the axis being reversed currently has a symbolic value.
                       Note that this will fail during graph execution if the actual value at runtime does not match.
                       If not provided, the size of the dimension to reverse is inferred from the input shape.
            name: Optional Step name. Defaults to 'ReverseAxis'
        """
        super().__init__(["data"], ["data_with_reversed_axis"], name)
        self._axis = axis
        self._dim_value = dim_value

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph that splits the input along `self._axis` into single slices and concatenates them in
        reverse order.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset. Split requires `num_outputs` from opset 18.

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If the axis size cannot be determined (symbolic dim and no dim_value), does not match
                        an explicitly provided dim_value, or is less than 1.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        input_dims = input_shape_str.split(",")
        split_dim = input_dims[self._axis]

        # Resolve the axis size into a local so building the graph does not overwrite the step's configured
        # dim_value (which would otherwise leak into any subsequent call).
        dim_value = self._dim_value
        if split_dim.isdigit():
            inferred = int(split_dim)
            if dim_value != -1 and dim_value != inferred:
                # TODO: Technically we don't require a match here. For now expect it to match.
                raise ValueError(
                    f"ReverseAxis dim_value {dim_value} does not match size {inferred} of axis {self._axis}."
                )
            dim_value = inferred
        elif dim_value == -1:
            raise ValueError(
                f"Axis {self._axis} has symbolic size '{split_dim}'. "
                "Provide dim_value to ReverseAxis so the number of slices is known."
            )

        if dim_value < 1:
            raise ValueError(f"ReverseAxis requires a positive axis size but got {dim_value}.")

        split_outs = [f"split_out_{i}" for i in range(dim_value)]

        split_attr = f"axis = {self._axis}"
        if onnx_opset >= 18:
            # Split now requires the number of outputs to be specified even though that can be easily inferred...
            split_attr += f", num_outputs = {len(split_outs)}"

        reverse_graph = onnx.parser.parse_graph(
            f"""\
            reverse_axis ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
            {{
                {','.join(split_outs)} = Split <{split_attr}> ({self.input_names[0]})
                {self.output_names[0]} = Concat <axis = {self._axis}> ({','.join(reversed(split_outs))})
            }}
            """
        )

        return reverse_graph
64
65
class Squeeze(Step):
    """
    ONNX Squeeze
    """

    def __init__(self, axes: Optional[List[int]] = None, name: Optional[str] = None):
        """
        Args:
            axes: Axes to remove. Negative values count from the end of the input shape.
                  If None, remove all axes with size of 1. Requires all dimensions to have explicit values.
            name: Optional Step name. Defaults to 'Squeeze'
        """
        super().__init__(["data"], ["squeezed"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph applying Squeeze with `axes` supplied as a Constant input.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset (unused).

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If axes must be inferred but the input has symbolic dims, or an axis is out of range.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")
        rank = len(dims)

        axes = self._axes
        if not axes:
            axes = []
            for idx, dim in enumerate(dims):
                if not dim.isnumeric():
                    # we can't infer the output shape if there are symbolic dims
                    raise ValueError("Axes must be specified if there are symbolic dimensions.")

                if dim == "1":
                    axes.append(idx)

        # Negative axes count from the end. Normalize them so the declared output shape drops the same
        # dimensions that Squeeze removes at runtime. The Constant below keeps the caller's original values.
        removed = set()
        for axis in axes:
            resolved = axis + rank if axis < 0 else axis
            if not 0 <= resolved < rank:
                raise ValueError(f"Squeeze axis {axis} is out of range for input of rank {rank}.")
            removed.add(resolved)

        output_dims = [dim for idx, dim in enumerate(dims) if idx not in removed]
        output_shape_str = ",".join(output_dims)

        axes_strs = [str(axis) for axis in axes]

        squeeze_graph = onnx.parser.parse_graph(
            f"""\
            squeeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Squeeze({self.input_names[0]}, axes)
            }}
            """
        )

        return squeeze_graph
113
114
class Transpose(Step):
    """
    ONNX Transpose.
    """

    def __init__(self, perms: List[int], name: Optional[str] = None):
        """
        Args:
            perms: List of integers with permutations to apply.
                   Must be a permutation of range(rank) for the step's input.
            name: Optional Step name. Defaults to 'Transpose'
        """
        super().__init__(["X"], ["transposed"], name)
        self.perms = perms

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph applying Transpose with `perm = self.perms`.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset (unused).

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If `perms` is not a permutation of the input's axes.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        # Validate up front: a short list, a duplicate, or a negative/out-of-range entry would otherwise
        # produce a wrong declared output shape (negative values silently index from the end here).
        if sorted(self.perms) != list(range(len(dims))):
            raise ValueError(
                f"Transpose perms {list(self.perms)} must be a permutation of [0, {len(dims) - 1}] "
                f"for input of rank {len(dims)}."
            )

        perms_str = ",".join(str(idx) for idx in self.perms)
        output_shape_str = ",".join(dims[axis] for axis in self.perms)

        transpose_graph = onnx.parser.parse_graph(
            f"""\
            transpose ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = Transpose <perm = [{perms_str}]> ({self.input_names[0]})
            }}
            """
        )

        return transpose_graph
147
148
class Softmax(Step):
    """
    ONNX Softmax.

    The declared output uses the same element type and shape as the input.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional Step name. Defaults to 'Softmax'
        """
        super().__init__(["data"], ["probabilities"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a single-node graph applying Softmax with its default attributes.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset (unused).

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        elem_type, shape = self._get_input_type_and_shape_strs(graph, 0)
        # Input and output share the same declaration apart from the name.
        tensor_decl = f"{elem_type}[{shape}]"
        src = self.input_names[0]
        dst = self.output_names[0]

        return onnx.parser.parse_graph(
            f"""\
            softmax ({tensor_decl} {src})
                => ({tensor_decl} {dst})
            {{
                {dst} = Softmax ({src})
            }}
            """
        )
175
176
class Unsqueeze(Step):
    """
    ONNX Unsqueeze
    """

    def __init__(self, axes: List[int], name: Optional[str] = None):
        """
        Args:
            axes: List of integers indicating the dimensions to be inserted.
                  Values index into the *output* shape; negative values count from its end. Order does not matter.
            name: Optional Step name. Defaults to 'Unsqueeze'
        """
        super().__init__(["data"], ["expanded"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph applying Unsqueeze with `axes` supplied as a Constant input.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset (unused).

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If an axis is out of range for the output rank or axes contain duplicates.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        # Unsqueeze interprets axes relative to the expanded (output) rank, with negative values counting
        # from its end, and in any order. Normalize against the output rank and insert in ascending order
        # so the declared output shape matches the runtime result. A naive list.insert in caller order gets
        # negative or unsorted axes wrong.
        output_rank = len(dims) + len(self._axes)
        resolved_axes = []
        for axis in self._axes:
            resolved = axis + output_rank if axis < 0 else axis
            if not 0 <= resolved < output_rank:
                raise ValueError(f"Unsqueeze axis {axis} is out of range for output of rank {output_rank}.")
            resolved_axes.append(resolved)

        if len(set(resolved_axes)) != len(resolved_axes):
            raise ValueError(f"Unsqueeze axes {list(self._axes)} must not contain duplicates.")

        for axis in sorted(resolved_axes):
            dims.insert(axis, "1")

        output_shape_str = ",".join(dims)
        # The Constant keeps the caller's original axis values; ONNX handles negatives itself.
        axes_strs = [str(axis) for axis in self._axes]

        unsqueeze_graph = onnx.parser.parse_graph(
            f"""\
            unsqueeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(self._axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Unsqueeze ({self.input_names[0]}, axes)
            }}
            """
        )

        return unsqueeze_graph
213
214
class ArgMax(Step):
    """
    ONNX ArgMax. Produces int64 indices of the maximum values along an axis.
    """

    def __init__(self, name: Optional[str] = None, axis: int = -1, keepdims: int = 0):
        """
        Brief:
            Same as ArgMax op.
        Args:
            name: Optional name of step. Defaults to 'ArgMax'
            axis: Axis to reduce over. Negative values count from the end. Must be in [-rank, rank-1].
            keepdims: If 0 the reduced axis is removed from the output shape; otherwise it is kept with size 1.
        """
        super().__init__(["data"], ["index"], name)
        self._axis = axis
        self._keepdims = keepdims

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph applying ArgMax with the configured axis and keepdims.

        Args:
            graph: Graph whose input 0 is consumed by this step.
            onnx_opset: Target opset (unused).

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If the axis is outside [-rank, rank-1].
        """
        input_type_str_0, input_shape_str_0 = self._get_input_type_and_shape_strs(graph, 0)
        input_shape_0 = input_shape_str_0.split(",")
        rank = len(input_shape_0)

        axis = self._axis + rank if self._axis < 0 else self._axis
        # Check both bounds: an axis below -rank would otherwise become a negative list index below and
        # silently yield the wrong output shape.
        if not 0 <= axis < rank:
            raise ValueError("axis should be in range [-rank, rank-1].")

        output_dims = input_shape_0.copy()
        if self._keepdims == 0:
            output_dims.pop(axis)
        else:
            output_dims[axis] = "1"

        converter_graph = onnx.parser.parse_graph(
            f"""\
            classify ({input_type_str_0}[{input_shape_str_0}] {self.input_names[0]})
                => (int64[{','.join(output_dims)}] {self.output_names[0]})
            {{
                {self.output_names[0]} = ArgMax<axis = {self._axis}, keepdims={self._keepdims}>({self.input_names[0]})
            }}
            """
        )

        return converter_graph
255