microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
edgchen1/fix_ci

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/steps/vision.py

590 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5import numpy as np
6
7from typing import List, Optional, Tuple, Union
8from ..step import Step
9from .general import Transpose
10
11#
12# Image conversion
13#
class ConvertImageToBGR(Step):
    """
    Decode an encoded image into BGR ordered uint8 pixel values.
    Supported input formats: jpg, png
    Input shape: {num_encoded_bytes}
    Output shape: {input_image_height, input_image_width, 3}
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional name of step. Defaults to 'ConvertImageToBGR'

        NOTE: The input image format is inferred from the data and does not need to be specified.
        """
        super().__init__(["image"], ["bgr_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        elem_type, in_shape = self._get_input_type_and_shape_strs(graph, 0)
        assert elem_type == "uint8"

        # The decoded height/width aren't known until runtime, so use step-specific symbolic dim names.
        dim_prefix = f"to_bgr_ppp_{self.step_num}"
        out_shape = f"{dim_prefix}_h, {dim_prefix}_w, 3"

        src = self.input_names[0]
        dst = self.output_names[0]

        return onnx.parser.parse_graph(
            f"""\
            image_to_bgr (uint8[{in_shape}] {src})
                => (uint8[{out_shape}] {dst})
            {{
                {dst} = com.microsoft.extensions.DecodeImage({src})
            }}
            """
        )
47
48
class ConvertBGRToImage(Step):
    """
    Encode BGR ordered uint8 pixel data as an image.
    Supported output formats: jpg, png
    Input shape: {input_image_height, input_image_width, 3}
    Output shape: {num_encoded_bytes}
    """

    def __init__(self, image_format: str = "jpg", name: Optional[str] = None):
        """
        Args:
            image_format: Format to encode to. 'jpg' and 'png' are supported.
            name: Optional step name. Defaults to 'ConvertBGRToImage'
        """
        super().__init__(["bgr_data"], ["image"], name)
        assert image_format in ("jpg", "png")
        self._format = image_format

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        elem_type, in_shape = self._get_input_type_and_shape_strs(graph, 0)
        assert elem_type == "uint8"

        # The encoded size is only known at runtime, so use a step-specific symbolic dim name.
        out_shape = f"to_image_ppp_{self.step_num}_num_bytes"
        src = self.input_names[0]
        dst = self.output_names[0]

        encoder_graph = onnx.parser.parse_graph(
            f"""\
            bgr_to_image (uint8[{in_shape}] {src})
                => (uint8[{out_shape}] {dst})
            {{
                {dst} = com.microsoft.extensions.EncodeImage ({src})
            }}
            """
        )

        # EncodeImage is a custom op, so parse_graph has no schema for it and fails when validating a `format`
        # attribute written in the graph text. Attach the attribute to the parsed node by hand instead.
        fmt_attr = encoder_graph.node[0].attribute.add()
        fmt_attr.name = "format"
        fmt_attr.type = onnx.AttributeProto.AttributeType.STRING
        fmt_attr.s = self._format.encode("utf-8")

        return encoder_graph
90
91
class PixelsToYCbCr(Step):
    """
    Convert RGB or BGR pixel data to YCbCr format.
    Input shape: {height, width, 3} (uint8, HWC layout)
    Output: three separate outputs (Y, Cb, Cr), each with shape {height, width}.
    Output data is float, but rounded and clipped to the range 0..255 as per the spec for YCbCr conversion.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Input data layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'PixelsToYCbCr'
        """
        super().__init__(["pixels"], ["Y", "Cb", "Cr"], name)
        assert layout == "RGB" or layout == "BGR"
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        # input should be uint8 HWC data with 3 channels. strip() tolerates a shape string with spaces after commas.
        input_dims = input_shape_str.split(",")
        assert input_type_str == "uint8" and len(input_dims) == 3 and input_dims[2].strip() == "3"

        # https://en.wikipedia.org/wiki/YCbCr
        # exact weights from https://www.itu.int/rec/T-REC-T.871-201105-I/en
        # Rows produce Y, Cb, Cr; columns weight the R, G, B input channels.
        rgb_weights = np.array([[0.299, 0.587, 0.114],
                                [-0.299 / 1.772, -0.587 / 1.772, 0.500],
                                [0.500, -0.587 / 1.402, -0.114 / 1.402]],
                               dtype=np.float32)  # fmt: skip

        bias = [0.0, 128.0, 128.0]

        # BGR input has the channel order reversed, so reverse the weight columns to match.
        weights = rgb_weights if self._layout == "RGB" else rgb_weights[:, ::-1]

        # MatMul computes {h, w, 3} x {3, 3}, so the weights are transposed.
        weights_shape = "3, 3"
        weights_str = ",".join(str(w) for w in weights.T.flatten())

        bias_shape = "3"
        bias_str = ",".join(str(b) for b in bias)

        # Each output is {h, w}. Symbolic, step-specific dim names are used for the output shape.
        output_shape_str = f"YCbCr_ppp_{self.step_num}_h, YCbCr_ppp_{self.step_num}_w"

        # From ONNX opset 18, Split requires either the 'split' input or the 'num_outputs' attribute.
        # Earlier opsets infer an equal split from the number of outputs, and don't have 'num_outputs'.
        split_attrs = "axis = -1"
        if onnx_opset >= 18:
            split_attrs += ", num_outputs = 3"

        # convert to float for MatMul
        # apply weights and bias
        # round and clip so it's in the range 0..255
        # split into channels. shape will be {h, w, 1}
        # remove the trailing '1' so output is {h, w}
        converter_graph = onnx.parser.parse_graph(
            f"""\
            pixels_to_YCbCr (uint8[{input_shape_str}] {self.input_names[0]})
                => (float[{output_shape_str}] {self.output_names[0]},
                    float[{output_shape_str}] {self.output_names[1]},
                    float[{output_shape_str}] {self.output_names[2]})
            {{
                kWeights = Constant <value = float[{weights_shape}] {{{weights_str}}}> ()
                kBias = Constant <value = float[{bias_shape}] {{{bias_str}}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()

                f_pixels = Cast <to = 1> ({self.input_names[0]})
                f_weighted = MatMul(f_pixels, kWeights)
                f_biased = Add(f_weighted, kBias)
                f_rounded = Round(f_biased)
                f_clipped = Clip (f_rounded, f_0, f_255)
                split_Y, split_Cb, split_Cr = Split <{split_attrs}>(f_clipped)
                {self.output_names[0]} = Squeeze (split_Y, i64_neg1)
                {self.output_names[1]} = Squeeze (split_Cb, i64_neg1)
                {self.output_names[2]} = Squeeze (split_Cr, i64_neg1)
            }}
            """
        )

        return converter_graph
174
175
class YCbCrToPixels(Step):
    """
    Convert YCbCr input to RGB or BGR.

    Inputs: separate Y, Cb and Cr tensors, each with shape {height, width}.
    Input data can be uint8 or float but all inputs must use the same type.
    Output: uint8 data with shape {height, width, 3}.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Output layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'YCbCrToPixels'
        """
        super().__init__(["Y", "Cb", "Cr"], ["bgr_data"], name)
        assert layout in ("RGB", "BGR")
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        in_types = []
        in_shapes = []
        for idx in range(3):
            elem_type, shape = self._get_input_type_and_shape_strs(graph, idx)
            in_types.append(elem_type)
            in_shapes.append(shape)

        # All three inputs must share one element type: either all uint8 or all float.
        assert all(t == "uint8" for t in in_types) or all(t == "float" for t in in_types)
        # Each input must be a 2D {height, width} plane.
        assert all(len(s.split(",")) == 2 for s in in_shapes)

        out_shape = f"{in_shapes[0]}, 3"

        # fmt: off
        # https://en.wikipedia.org/wiki/YCbCr
        # exact weights from https://www.itu.int/rec/T-REC-T.871-201105-I/en
        # Rows produce R, G, B; columns weight the Y, Cb, Cr inputs.
        ycbcr_to_rgb = np.array([[1, 0, 1.402],
                                 [1, -0.114*1.772/0.587, -0.299*1.402/0.587],
                                 [1, 1.772, 0]],
                                dtype=np.float32)
        # fmt: on

        # Reversing the rows reverses the output channel order, giving BGR.
        if self._layout == "BGR":
            mat = ycbcr_to_rgb[::-1, :]
        else:
            mat = ycbcr_to_rgb

        # MatMul computes {h, w, 3} x {3, 3}, so the weights are transposed.
        weights_shape = "3, 3"
        weights_str = ",".join(str(v) for v in mat.T.flatten())

        bias_shape = "3"
        bias_str = ",".join(str(v) for v in [0.0, 128.0, 128.0])

        y_in, cb_in, cr_in = self.input_names[0], self.input_names[1], self.input_names[2]
        dst = self.output_names[0]

        # Steps in the graph:
        #   unsqueeze each {h, w} input to {h, w, 1}, then concat on that new channel axis
        #   cast to float, subtract the bias, apply the weights
        #   round, clip to 0..255 and cast to uint8
        return onnx.parser.parse_graph(
            f"""\
            YCbCr_to_RGB ({in_types[0]}[{in_shapes[0]}] {y_in},
                          {in_types[1]}[{in_shapes[1]}] {cb_in},
                          {in_types[2]}[{in_shapes[2]}] {cr_in})
                => (uint8[{out_shape}] {dst})
            {{
                kWeights = Constant <value = float[{weights_shape}] {{{weights_str}}}> ()
                kBias = Constant <value = float[{bias_shape}] {{{bias_str}}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()

                Y1 = Unsqueeze({y_in}, i64_neg1)
                Cb1 = Unsqueeze({cb_in}, i64_neg1)
                Cr1 = Unsqueeze({cr_in}, i64_neg1)
                YCbCr = Concat <axis = -1> (Y1, Cb1, Cr1)
                f_YCbCr = Cast <to = 1> (YCbCr)
                f_unbiased = Sub (f_YCbCr, kBias)
                f_pixels = MatMul (f_unbiased, kWeights)
                f_rounded = Round (f_pixels)
                clipped = Clip (f_rounded, f_0, f_255)
                {dst} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )
268
269
270#
271# Pre-processing
272#
class Resize(Step):
    """
    Resize input data. Aspect ratio is maintained.
    The larger of the height and width scaling ratios is applied to both dims, so neither dim of the result is
    smaller than the requested size.
    e.g. if image is 1200 x 600 and 300 x 300 is requested the result will be 600 x 300
    """

    def __init__(self, resize_to: Union[int, Tuple[int, int]], layout: str = "HWC", name: Optional[str] = None):
        """
        Args:
            resize_to: Target size. Can be a single value or a tuple with (target_height, target_width).
                       The aspect ratio will be maintained and neither height nor width in the result will be
                       smaller than the requested value.
            layout: Input layout. 'NCHW', 'NHWC', 'CHW', 'HWC' and 'HW' are supported.
            name: Optional name. Defaults to 'Resize'
        """
        super().__init__(["image"], ["resized_image"], name)
        if isinstance(resize_to, int):
            self._height = self._width = resize_to
        else:
            assert isinstance(resize_to, tuple)
            self._height, self._width = resize_to

        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        # adjust for layout
        # resize will use the largest ratio so both sides won't necessarily match the requested height and width.
        # use symbolic names for the output dims as we have to provide values. prefix the names to try and
        # avoid any clashes.
        scales_constant_str = "f_1 = Constant <value = float[1] {1.0}> ()"
        add_batch_dim = False

        if self._layout == "NHWC":
            assert len(dims) == 4
            split_str = "n, h, w, c"
            scales_str = "f_1, ratio_resize, ratio_resize, f_1"
            output_shape_str = f"{dims[0]}, resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w, {dims[-1]}"
        elif self._layout == "NCHW":
            assert len(dims) == 4
            split_str = "n, c, h, w"
            scales_str = "f_1, f_1, ratio_resize, ratio_resize"
            output_shape_str = f"{dims[0]}, {dims[1]}, resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w"
        elif self._layout == "HWC":
            assert len(dims) == 3
            add_batch_dim = True
            split_str = "h, w, c"
            scales_str = "ratio_resize, ratio_resize, f_1"
            output_shape_str = f"resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w, {dims[-1]}"
        elif self._layout == "CHW":
            assert len(dims) == 3
            add_batch_dim = True
            split_str = "c, h, w"
            scales_str = "f_1, ratio_resize, ratio_resize"
            output_shape_str = f"{dims[0]}, resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w"
        elif self._layout == "HW":
            assert len(dims) == 2
            split_str = "h, w"
            scales_str = "ratio_resize, ratio_resize"
            # every scale is ratio_resize, so the f_1 constant isn't needed
            scales_constant_str = ""
            output_shape_str = f"resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w"
        else:
            raise ValueError(f"Unsupported layout of {self._layout}")

        # TODO: Make this configurable. Matching PIL resize for now.
        resize_attributes = 'mode = "linear", nearest_mode = "floor"'
        if onnx_opset >= 18:
            # Resize matches PIL better if antialiasing is used, but the attribute only exists from ONNX opset 18.
            # It is only added for opset 18+ so this step can still be used with older opsets.
            resize_attributes += ", antialias = 1"

        # The input shape is split into one scalar per dim (names in split_str). From ONNX opset 18, Split requires
        # either the 'split' input or the 'num_outputs' attribute, so provide the output count explicitly there.
        split_attributes = "axis = 0"
        if onnx_opset >= 18:
            split_attributes += f", num_outputs = {len(dims)}"

        # Rank 3 input uses trilinear interpolation, so if input is HWC or CHW we need to add a temporary batch dim
        # to make it rank 4, which will result in Resize using the desired bilinear interpolation.
        if add_batch_dim:
            scales_str = "f_1, " + scales_str
            resize_str = f"""\
                axes = Constant <value = int64[1] {{0}}> ()
                unsqueezed = Unsqueeze ({self.input_names[0]}, axes)
                resized = Resize <{resize_attributes}> (unsqueezed, , scales_resize)
                {self.output_names[0]} = Squeeze (resized, axes)
                """
        else:
            resize_str = (
                f"{self.output_names[0]} = Resize <{resize_attributes}> ({self.input_names[0]}, , scales_resize)"
            )

        resize_graph = onnx.parser.parse_graph(
            f"""\
            resize ({input_type_str}[{input_shape_str}] {self.input_names[0]}) =>
                ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                target_size = Constant <value = float[2] {{{float(self._height)}, {float(self._width)}}}> ()
                image_shape = Shape ({self.input_names[0]})
                {split_str} = Split <{split_attributes}> (image_shape)
                hw = Concat <axis = 0> (h, w)
                f_hw = Cast <to = 1> (hw)
                ratios = Div (target_size, f_hw)
                ratio_resize = ReduceMax (ratios)

                {scales_constant_str}
                scales_resize = Concat <axis = 0> ({scales_str})
                {resize_str}
            }}
            """
        )

        return resize_graph
382
383
class CenterCrop(Step):
    """
    Crop the input to the requested dimensions, with the crop being centered.
    Input must be rank 3 in HWC layout; the crop is applied to the first two dims.
    Output shape: {height, width, channels}
    """

    def __init__(self, height: int, width: int, name: Optional[str] = None):
        """
        Args:
            height: Height of area to crop.
            width: Width of area to crop.
            name: Optional step name. Defaults to 'CenterCrop'
        """
        super().__init__(["image"], ["cropped_image"], name)
        self._height = height
        self._width = width

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")
        # The graph splits the shape into exactly (h, w, c) and slices axes 0 and 1, so only HWC input can work.
        # Fail here with a clear message rather than producing a graph that breaks later.
        assert len(dims) == 3, f"CenterCrop requires HWC input. Got input shape [{input_shape_str}]"
        output_shape_str = f"{self._height}, {self._width}, {dims[-1]}"

        # From ONNX opset 18, Split requires either the 'split' input or the 'num_outputs' attribute.
        split_attrs = "axis = 0"
        if onnx_opset >= 18:
            split_attrs += ", num_outputs = 3"

        # start = (hw - target) / 2 using integer division, so when the excess is odd the extra row/column is
        # removed from the bottom/right.
        crop_graph = onnx.parser.parse_graph(
            f"""\
            crop ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                target_crop = Constant <value = int64[2] {{{self._height}, {self._width}}}> ()
                i64_2 = Constant <value = int64[1] {{2}}> ()
                axes = Constant <value = int64[2] {{0, 1}}> ()
                x_shape = Shape ({self.input_names[0]})
                h, w, c = Split <{split_attrs}> (x_shape)
                hw = Concat <axis = 0> (h, w)
                hw_diff = Sub (hw, target_crop)
                start_xy = Div (hw_diff, i64_2)
                end_xy = Add (start_xy, target_crop)
                {self.output_names[0]} = Slice ({self.input_names[0]}, start_xy, end_xy, axes)
            }}
            """
        )

        return crop_graph
425
426
class Normalize(Step):
    """
    Normalize input data on a per-channel basis.
    `x -> (x - mean) / stddev`
    Output is float with same shape as input.
    """

    def __init__(self, normalization_values: List[Tuple[float, float]], layout: str = "CHW", name: Optional[str] = None):
        """
        Args:
            normalization_values: Tuple with (mean, stddev). One entry per channel.
                                  If single entry is provided it will be used for all channels.
                                  The caller's sequence is not modified.
            layout: Input layout. Can be 'CHW' or 'HWC'
            name: Optional step name. Defaults to 'Normalize'
        """
        super().__init__(["data"], ["normalized_data"], name)

        # Build a new list rather than extending in place (`*=`), so a list supplied by the caller is left
        # unmodified and later changes to it don't affect this step.
        values = list(normalization_values)

        # duplicate for each channel if needed
        if len(values) == 1:
            values = values * 3

        assert len(values) == 3
        self._normalization_values = values
        assert layout == "HWC" or layout == "CHW"
        self._hwc_layout = layout == "HWC"

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        # each entry is (mean, stddev) for one channel
        means = ", ".join(str(entry[0]) for entry in self._normalization_values)
        stddevs = ", ".join(str(entry[1]) for entry in self._normalization_values)

        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        # HWC data broadcasts against a {3} tensor on the trailing channel dim. CHW data needs {3, 1, 1} so the
        # per-channel values broadcast over H and W.
        values_shape = "3" if self._hwc_layout else "3, 1, 1"

        normalize_graph = onnx.parser.parse_graph(
            f"""\
            normalize ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => (float[{input_shape_str}] {self.output_names[0]})
            {{
                kMean = Constant <value = float[{values_shape}] {{{means}}}> ()
                kStddev = Constant <value = float[{values_shape}] {{{stddevs}}}> ()
                f_input = Cast <to = 1> ({self.input_names[0]})
                f_sub_mean = Sub (f_input, kMean)
                {self.output_names[0]} = Div (f_sub_mean, kStddev)
            }}
            """
        )

        onnx.checker.check_graph(normalize_graph)
        return normalize_graph
480
481
482#
483# Utilities
484#
class ImageBytesToFloat(Step):
    """
    Scale values in the range 0..255 (uint8 or float input) to floating point values in the range 0..1.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional step name. Defaults to 'ImageBytesToFloat'
        """
        super().__init__(["data"], ["float_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        elem_type, shape = self._get_input_type_and_shape_strs(graph, 0)
        src = self.input_names[0]
        dst = self.output_names[0]

        if elem_type != "uint8":
            # No Cast needed. The Identity keeps the name `input_f` available for the Div below; it is a no-op
            # that an optimizer will remove.
            to_float = f"input_f = Identity ({src})"
        else:
            to_float = f"input_f = Cast <to = 1> ({src})"

        graph_def = onnx.parser.parse_graph(
            f"""\
            byte_to_float ({elem_type}[{shape}] {src})
                => (float[{shape}] {dst})
            {{
                f_255 = Constant <value = float[1] {{255.0}}>()

                {to_float}
                {dst} = Div(input_f, f_255)
            }}
            """
        )

        onnx.checker.check_graph(graph_def)
        return graph_def
522
523
class FloatToImageBytes(Step):
    """
    Convert floating point values to uint8 values in the range 0..255.
    Typically this reverses ImageBytesToFloat by converting input data in the range 0..1. An optional multiplier
    can be given if the input data has a different range.
    Values are rounded, then clipped to 0..255, then cast to uint8.
    """

    def __init__(self, multiplier: float = 255.0, name: Optional[str] = None):
        """
        Args:
            multiplier: Optional multiplier. Currently, the expected values are 255 (input data is in range 0..1), or
                        1 (input data is in range 0..255).
            name: Optional step name. Defaults to 'FloatToImageBytes'
        """
        super().__init__(["float_data"], ["pixel_data"], name)
        self._multiplier = multiplier

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        elem_type, shape = self._get_input_type_and_shape_strs(graph, 0)
        assert elem_type == "float"

        src = self.input_names[0]
        dst = self.output_names[0]

        if self._multiplier != 1.0:
            scale_nodes = f"""\
                f_multiplier = Constant <value = float[1] {{{self._multiplier}}}> ()
                scaled_input = Mul ({src}, f_multiplier)
                """
            round_source = "scaled_input"
        else:
            # A multiplier of 1 needs no Mul node; round the input directly.
            scale_nodes = ""
            round_source = src

        graph_def = onnx.parser.parse_graph(
            f"""\
            float_to_type (float[{shape}] {src})
                => (uint8[{shape}] {dst})
            {{
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}>()

                {scale_nodes}
                rounded = Round ({round_source})
                clipped = Clip (rounded, f_0, f_255)
                {dst} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )

        onnx.checker.check_graph(graph_def)
        return graph_def
575
576
class ChannelsLastToChannelsFirst(Transpose):
    """
    Transpose channels-last data to channels-first.
    Accepts NHWC input (producing NCHW) or HWC input (producing CHW).
    """

    def __init__(self, has_batch_dim: bool = False, name: Optional[str] = None):
        """
        Args:
            has_batch_dim: Set to True if the input has a batch dimension (i.e. is NHWC)
            name: Optional step name. Defaults to 'ChannelsLastToChannelsFirst'
        """
        if has_batch_dim:
            perms = [0, 3, 1, 2]  # NHWC -> NCHW
        else:
            perms = [2, 0, 1]  # HWC -> CHW
        super().__init__(perms, name)
591