microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/tools/pre_post_processing/steps/general.py
254 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import onnx |
| 5 | from typing import List, Optional |
| 6 | from ..step import Step |
| 7 | |
| 8 | |
class ReverseAxis(Step):
    """
    Reverses the data in an axis by splitting and concatenating in reverse order.
      e.g. convert RGB ordered data to BGR.
    Output data type and shape is the same as the input.
    """

    def __init__(self, axis: int = -1, dim_value: int = -1, name: Optional[str] = None):
        """
        Args:
            axis: Axis to reverse. Default is last axis.
            dim_value: Explicit value for size of dimension being reversed.
                       This can be provided if the axis being reversed currently has a symbolic value.
                       Note that this will fail during graph execution if the actual value at runtime does not match.
                       If not provided, the size of the dimension to reverse is inferred from the input shape.
            name: Optional Step name. Defaults to 'ReverseAxis'
        """
        super().__init__(["data"], ["data_with_reversed_axis"], name)
        self._axis = axis
        self._dim_value = dim_value

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph that Splits the input into one piece per element of the axis and
        Concats the pieces back together in reverse order.

        Args:
            graph: Graph whose input type and shape strings are looked up for input 0 of this step.
            onnx_opset: ONNX opset in use. From opset 18 the Split node is given an explicit `num_outputs`.

        Returns:
            The graph parsed from the generated text.

        Raises:
            ValueError: If the explicit `dim_value` disagrees with a concrete size in the input shape, or
                        if the size of the axis can be determined neither from the shape nor from `dim_value`.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        input_dims = input_shape_str.split(",")
        split_dim = input_dims[self._axis]

        # Resolve the size into a local so that building a graph does not alter the configuration of the step.
        dim_value = self._dim_value
        if split_dim.isdigit():
            inferred_value = int(split_dim)
            if dim_value == -1:
                dim_value = inferred_value
            elif dim_value != inferred_value:
                # TODO: Technically we don't require a match here. For now expect it to match.
                raise ValueError(
                    f"dim_value of {dim_value} does not match the size of {inferred_value} "
                    f"for axis {self._axis} in the input shape."
                )

        if dim_value < 1:
            # Without a positive size there would be no Split outputs and the generated graph would be malformed.
            raise ValueError(
                f"Size of axis {self._axis} could not be determined from input dimension '{split_dim}'. "
                "Provide a positive dim_value."
            )

        split_outs = [f"split_out_{i}" for i in range(dim_value)]

        split_attr = f"axis = {self._axis}"
        if onnx_opset >= 18:
            # Split now requires the number of outputs to be specified even though that can be easily inferred...
            split_attr += f", num_outputs = {len(split_outs)}"

        reverse_graph = onnx.parser.parse_graph(
            f"""\
            reverse_axis ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
            {{
                {','.join(split_outs)} = Split <{split_attr}> ({self.input_names[0]})
                {self.output_names[0]} = Concat <axis = {self._axis}> ({','.join(reversed(split_outs))})
            }}
            """
        )

        return reverse_graph
| 64 | |
| 65 | |
class Squeeze(Step):
    """
    ONNX Squeeze
    """

    def __init__(self, axes: Optional[List[int]] = None, name: Optional[str] = None):
        """
        Args:
            axes: Axes to remove. Negative values count back from the last axis.
                  If None, remove all axes with size of 1. Requires all dimensions to have explicit values.
            name: Optional Step name. Defaults to 'Squeeze'
        """
        super().__init__(["data"], ["squeezed"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph containing a Constant holding the axes and a Squeeze node that consumes it.

        Args:
            graph: Graph whose input type and shape strings are looked up for input 0 of this step.
            onnx_opset: ONNX opset in use. Not consulted by this step.

        Returns:
            The graph parsed from the generated text.

        Raises:
            ValueError: If no axes were given and the input shape has symbolic dimensions, or
                        if an axis is outside [-rank, rank-1].
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")
        rank = len(dims)

        axes = self._axes
        if not axes:
            if not all(dim.isnumeric() for dim in dims):
                # we can't infer the output shape if there are symbolic dims
                raise ValueError("Axes must be specified if there are symbolic dimensions.")

            axes = [idx for idx, dim in enumerate(dims) if dim == "1"]

        # Map negative axes to their non-negative position so that the declared output shape
        # drops the same dimensions as the Squeeze node will.
        removed = set()
        for axis in axes:
            if not -rank <= axis < rank:
                raise ValueError(f"Axis {axis} is out of range for an input of rank {rank}.")

            removed.add(axis + rank if axis < 0 else axis)

        output_dims = [dim for idx, dim in enumerate(dims) if idx not in removed]
        output_shape_str = ",".join(output_dims)

        # The axes are passed to the node exactly as provided.
        axes_strs = [str(axis) for axis in axes]

        squeeze_graph = onnx.parser.parse_graph(
            f"""\
            squeeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Squeeze({self.input_names[0]}, axes)
            }}
            """
        )

        return squeeze_graph
| 113 | |
| 114 | |
class Transpose(Step):
    """
    Step that reorders the axes of its input using the ONNX Transpose operator.
    """

    def __init__(self, perms: List[int], name: Optional[str] = None):
        """
        Args:
            perms: Axis order for the output. Entry i is the input axis that becomes output axis i.
            name: Optional Step name. Defaults to 'Transpose'
        """
        super().__init__(["X"], ["transposed"], name)
        self.perms = perms

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a single-node Transpose graph. The declared output shape is the input shape
        with its dimensions picked in the order given by `perms`.
        """
        elem_type, in_shape = self._get_input_type_and_shape_strs(graph, 0)
        perm_attr = ",".join(map(str, self.perms))
        in_dims = in_shape.split(",")
        out_shape = ",".join(in_dims[axis] for axis in self.perms)

        src = self.input_names[0]
        dst = self.output_names[0]

        return onnx.parser.parse_graph(
            f"""\
            transpose ({elem_type}[{in_shape}] {src})
                => ({elem_type}[{out_shape}] {dst})
            {{
                {dst} = Transpose <perm = [{perm_attr}]> ({src})
            }}
            """
        )
| 147 | |
| 148 | |
class Softmax(Step):
    """
    Step that applies the ONNX Softmax operator to its input.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional Step name. Defaults to 'Softmax'
        """
        super().__init__(["data"], ["probabilities"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a single-node Softmax graph whose output has the same type and shape as the input.
        No axis attribute is set on the node.
        """
        elem_type, shape = self._get_input_type_and_shape_strs(graph, 0)
        src = self.input_names[0]
        dst = self.output_names[0]

        return onnx.parser.parse_graph(
            f"""\
            softmax ({elem_type}[{shape}] {src})
                => ({elem_type}[{shape}] {dst})
            {{
                {dst} = Softmax ({src})
            }}
            """
        )
| 175 | |
| 176 | |
class Unsqueeze(Step):
    """
    ONNX Unsqueeze
    """

    def __init__(self, axes: List[int], name: Optional[str] = None):
        """
        Args:
            axes: List of integers indicating the dimensions to be inserted.
                  Each value is a position in the output shape. Negative values count back from
                  the last axis of the output. Order does not matter; duplicates are rejected.
            name: Optional Step name. Defaults to 'Unsqueeze'
        """
        super().__init__(["data"], ["expanded"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph containing a Constant holding the axes and an Unsqueeze node that consumes it.

        Args:
            graph: Graph whose input type and shape strings are looked up for input 0 of this step.
            onnx_opset: ONNX opset in use. Not consulted by this step.

        Returns:
            The graph parsed from the generated text.

        Raises:
            ValueError: If an axis is outside [-output_rank, output_rank-1] or an output position is repeated.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        # An empty shape string would otherwise split into [''] and be counted as one dimension.
        dims = input_shape_str.split(",") if input_shape_str else []

        # Axes refer to positions in the output, so negative values are resolved against the output rank.
        output_rank = len(dims) + len(self._axes)
        positions = []
        for axis in self._axes:
            if not -output_rank <= axis < output_rank:
                raise ValueError(f"Axis {axis} is out of range for an output of rank {output_rank}.")

            positions.append(axis + output_rank if axis < 0 else axis)

        if len(set(positions)) != len(positions):
            raise ValueError("Axes must not contain duplicate values.")

        # Inserting in ascending order means later insertions never shift the ones already placed.
        for idx in sorted(positions):
            dims.insert(idx, "1")

        output_shape_str = ",".join(dims)
        # The axes are passed to the node exactly as provided.
        axes_strs = [str(axis) for axis in self._axes]

        unsqueeze_graph = onnx.parser.parse_graph(
            f"""\
            unsqueeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(self._axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Unsqueeze ({self.input_names[0]}, axes)
            }}
            """
        )

        return unsqueeze_graph
| 213 | |
| 214 | |
class ArgMax(Step):
    """
    ONNX ArgMax
    """

    def __init__(self, name: Optional[str] = None, axis: int = -1, keepdims: int = 0):
        """
        Brief:
            Same as ArgMax op.
        Args:
            name: Optional name of step. Defaults to 'ArgMax'
            axis: Axis to apply ArgMax to. Negative values count back from the last axis. Default is last axis.
            keepdims: If 0 the axis is removed from the declared output shape.
                      Any other value keeps the axis with a size of 1.
        """
        super().__init__(["data"], ["index"], name)
        self._axis = axis
        self._keepdims = keepdims

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a single-node ArgMax graph with an int64 output.

        Args:
            graph: Graph whose input type and shape strings are looked up for input 0 of this step.
            onnx_opset: ONNX opset in use. Not consulted by this step.

        Returns:
            The graph parsed from the generated text.

        Raises:
            ValueError: If axis is outside [-rank, rank-1].
        """
        input_type_str_0, input_shape_str_0 = self._get_input_type_and_shape_strs(graph, 0)
        input_shape_0 = input_shape_str_0.split(",")
        rank = len(input_shape_0)

        # Check both bounds. A value below -rank would still be negative after adding rank
        # and would silently index the wrong dimension.
        if not -rank <= self._axis < rank:
            raise ValueError("axis should be in range [-rank, rank-1].")

        axis = self._axis + rank if self._axis < 0 else self._axis

        output_dims = list(input_shape_0)
        if self._keepdims == 0:
            output_dims.pop(axis)
        else:
            output_dims[axis] = "1"

        converter_graph = onnx.parser.parse_graph(
            f"""\
            classify ({input_type_str_0}[{input_shape_str_0}] {self.input_names[0]})
                => (int64[{','.join(output_dims)}] {self.output_names[0]})
            {{
                {self.output_names[0]} = ArgMax<axis = {self._axis}, keepdims={self._keepdims}>({self.input_names[0]})
            }}
            """
        )

        return converter_graph
| 255 | |