microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
rel-0.7

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/steps/vision.py

606lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5import numpy as np
6
7from typing import List, Optional, Tuple, Union
8from ..step import Step
9from .general import Transpose
10
11#
12# Image conversion
13#
class ConvertImageToBGR(Step):
    """
    Decode the bytes of an encoded image into BGR ordered uint8 values.
    Supported input formats: jpg, png
    Input shape: {num_encoded_bytes}
    Output shape: {input_image_height, input_image_width, 3}
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional name of step. Defaults to 'ConvertImageToBGR'

        NOTE: Input image format is inferred and does not need to be specified.
        """
        super().__init__(["image"], ["bgr_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph for this step: a single DecodeImage custom op node applied to the step input.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        encoded_type, encoded_shape = self._get_input_type_and_shape_strs(graph, 0)
        assert encoded_type == "uint8"

        # symbolic names for the output height and width. the step number is included in the prefix.
        dim_prefix = f"to_bgr_ppp_{self.step_num}"
        decoded_shape = f"{dim_prefix}_h, {dim_prefix}_w, 3"

        src = self.input_names[0]
        dst = self.output_names[0]

        return onnx.parser.parse_graph(
            f"""\
            image_to_bgr (uint8[{encoded_shape}] {src})
                => (uint8[{decoded_shape}] {dst})
            {{
                {dst} = com.microsoft.extensions.DecodeImage({src})
            }}
            """
        )
47
48
class ConvertBGRToImage(Step):
    """
    Encode BGR ordered uint8 data as an image.
    Supported output formats: jpg, png
    Input shape: {input_image_height, input_image_width, 3}
    Output shape: {num_encoded_bytes}
    """

    def __init__(self, image_format: str = "jpg", name: Optional[str] = None):
        """
        Args:
            image_format: Format to encode to. jpg and png are supported.
            name: Optional step name. Defaults to 'ConvertBGRToImage'
        """
        super().__init__(["bgr_data"], ["image"], name)
        assert image_format in ("jpg", "png")
        self._format = image_format

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph for this step: a single EncodeImage custom op node with a `format` attribute.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        pixels_type, pixels_shape = self._get_input_type_and_shape_strs(graph, 0)
        assert pixels_type == "uint8"

        # symbolic name for the number of encoded bytes in the output
        encoded_shape = f"to_image_ppp_{self.step_num}_num_bytes"

        src = self.input_names[0]
        dst = self.output_names[0]

        converter_graph = onnx.parser.parse_graph(
            f"""\
            bgr_to_image (uint8[{pixels_shape}] {src})
                => (uint8[{encoded_shape}] {dst})
            {{
                {dst} = com.microsoft.extensions.EncodeImage ({src})
            }}
            """
        )

        # EncodeImage is a custom op so parse_graph has no schema for it and fails when trying to validate
        # the attribute. Attach `format` directly to the node instead.
        encode_node = converter_graph.node[0]
        format_attr = encode_node.attribute.add()
        format_attr.name = "format"
        format_attr.type = onnx.AttributeProto.AttributeType.STRING
        format_attr.s = self._format.encode("utf-8")

        return converter_graph
90
91
class PixelsToYCbCr(Step):
    """
    Convert RGB or BGR pixel data to YCbCr format.
    Input shape: {height, width, 3}
    Outputs: Y, Cb and Cr, each with shape {height, width}.
    Output data is float, but rounded and clipped to the range 0..255 as per the spec for YCbCr conversion.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Input data layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'PixelsToYCbCr'
        """
        super().__init__(["pixels"], ["Y", "Cb", "Cr"], name)
        assert layout in ("RGB", "BGR")
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that applies the YCbCr weights and bias and splits the result into the 3 outputs.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Split needs `num_outputs` to be specified from opset 18.

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        # input must be uint8 data with 3 dims, where the last dim is the 3 channels
        input_dims = input_shape_str.split(",")
        assert input_type_str == "uint8"
        assert len(input_dims) == 3
        assert input_dims[2] == "3"

        # https://en.wikipedia.org/wiki/YCbCr
        # exact weights from https://www.itu.int/rec/T-REC-T.871-201105-I/en
        rgb_weights = np.array([[0.299, 0.587, 0.114],
                                [-0.299 / 1.772, -0.587 / 1.772, 0.500],
                                [0.500, -0.587 / 1.402, -0.114 / 1.402]],
                               dtype=np.float32)  # fmt: skip

        # for BGR input the order of the last dim is reversed so each weight lines up with its channel
        channel_weights = rgb_weights if self._layout == "RGB" else rgb_weights[:, ::-1]

        # the weights are transposed as they are the second input to MatMul
        weights = ",".join(str(w) for w in channel_weights.T.flatten())
        bias = ",".join(str(b) for b in (0.0, 128.0, 128.0))

        # each output is {h, w}. symbolic names are used for both dims.
        dim_prefix = f"YCbCr_ppp_{self.step_num}"
        output_shape_str = f"{dim_prefix}_h, {dim_prefix}_w"

        # Split requires the number of outputs to be specified from opset 18 on
        split_attr = "axis = -1, num_outputs = 3" if onnx_opset >= 18 else "axis = -1"

        src = self.input_names[0]
        out_y, out_cb, out_cr = self.output_names[0], self.output_names[1], self.output_names[2]

        # convert to float for MatMul
        # apply weights and bias
        # round and clip so it's in the range 0..255
        # split into channels. shape will be {h, w, 1}
        # remove the trailing '1' so output is {h, w}
        return onnx.parser.parse_graph(
            f"""\
            pixels_to_YCbCr (uint8[{input_shape_str}] {src})
                => (float[{output_shape_str}] {out_y},
                    float[{output_shape_str}] {out_cb},
                    float[{output_shape_str}] {out_cr})
            {{
                kWeights = Constant <value = float[3, 3] {{{weights}}}> ()
                kBias = Constant <value = float[3] {{{bias}}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()

                f_pixels = Cast <to = 1> ({src})
                f_weighted = MatMul(f_pixels, kWeights)
                f_biased = Add(f_weighted, kBias)
                f_rounded = Round(f_biased)
                f_clipped = Clip (f_rounded, f_0, f_255)
                split_Y, split_Cb, split_Cr = Split <{split_attr}>(f_clipped)
                {out_y} = Squeeze (split_Y, i64_neg1)
                {out_cb} = Squeeze (split_Cb, i64_neg1)
                {out_cr} = Squeeze (split_Cr, i64_neg1)
            }}
            """
        )
179
180
class YCbCrToPixels(Step):
    """
    Convert YCbCr input to RGB or BGR.

    Input data can be uint8 or float but all inputs must use the same type.
    Inputs: Y, Cb and Cr, each with shape {height, width}
    Output shape: {height, width, 3}. Output data is uint8.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Output layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'YCbCrToPixels'
        """
        super().__init__(["Y", "Cb", "Cr"], ["bgr_data"], name)
        assert layout in ("RGB", "BGR")
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that merges the Y, Cb and Cr inputs and converts them to uint8 pixel data.

        Args:
            graph: Graph the step input types and shapes are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        types_and_shapes = [self._get_input_type_and_shape_strs(graph, i) for i in range(3)]
        in_types = [type_str for type_str, _ in types_and_shapes]
        in_shapes = [shape_str for _, shape_str in types_and_shapes]

        # all inputs must share one type, and each must have 2 dims
        assert all(t == "uint8" for t in in_types) or all(t == "float" for t in in_types)
        assert all(len(s.split(",")) == 2 for s in in_shapes)

        output_shape_str = f"{in_shapes[0]}, 3"

        # fmt: off
        # https://en.wikipedia.org/wiki/YCbCr
        # exact weights from https://www.itu.int/rec/T-REC-T.871-201105-I/en
        ycbcr_to_rgb_weights = np.array([[1, 0, 1.402],
                                         [1, -0.114*1.772/0.587, -0.299*1.402/0.587],
                                         [1, 1.772, 0]],
                                        dtype=np.float32)
        # fmt: on

        # reversing the first dim of the weights makes the output BGR ordered
        if self._layout == "BGR":
            layout_weights = ycbcr_to_rgb_weights[::-1, :]
        else:
            layout_weights = ycbcr_to_rgb_weights

        # the weights are transposed as they are the second input to MatMul
        weights = ",".join(str(w) for w in layout_weights.T.flatten())
        bias = ",".join(str(b) for b in (0.0, 128.0, 128.0))

        in_y, in_cb, in_cr = self.input_names[0], self.input_names[1], self.input_names[2]
        dst = self.output_names[0]

        # unsqueeze the {h, w} inputs to add channels dim. new shape is {h, w, 1}
        # merge Y, Cb, Cr data on the new channel axis
        # convert to float to apply weights etc.
        # remove bias
        # apply weights
        # round and clip to 0..255
        # convert to uint8.
        return onnx.parser.parse_graph(
            f"""\
            YCbCr_to_RGB ({in_types[0]}[{in_shapes[0]}] {in_y},
                          {in_types[1]}[{in_shapes[1]}] {in_cb},
                          {in_types[2]}[{in_shapes[2]}] {in_cr})
                => (uint8[{output_shape_str}] {dst})
            {{
                kWeights = Constant <value = float[3, 3] {{{weights}}}> ()
                kBias = Constant <value = float[3] {{{bias}}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()

                Y1 = Unsqueeze({in_y}, i64_neg1)
                Cb1 = Unsqueeze({in_cb}, i64_neg1)
                Cr1 = Unsqueeze({in_cr}, i64_neg1)
                YCbCr = Concat <axis = -1> (Y1, Cb1, Cr1)
                f_YCbCr = Cast <to = 1> (YCbCr)
                f_unbiased = Sub (f_YCbCr, kBias)
                f_pixels = MatMul (f_unbiased, kWeights)
                f_rounded = Round (f_pixels)
                clipped = Clip (f_rounded, f_0, f_255)
                {dst} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )
273
274
275#
276# Pre-processing
277#
class Resize(Step):
    """
    Resize input data. Aspect ratio is maintained.
    e.g. if image is 1200 x 600 and 300 x 300 is requested the result will be 600 x 300
    """

    def __init__(self, resize_to: Union[int, Tuple[int, int]], layout: str = "HWC", name: Optional[str] = None):
        """
        Args:
            resize_to: Target size. Can be a single value or a tuple with (target_height, target_width).
                       The aspect ratio will be maintained and neither height or width in the result will be smaller
                       than the requested value.
            layout: Input layout. 'NCHW', 'NHWC', 'CHW', 'HWC' and 'HW' are supported.
            name: Optional name. Defaults to 'Resize'
        """
        super().__init__(["image"], ["resized_image"], name)
        if not isinstance(resize_to, int):
            assert isinstance(resize_to, tuple)
            self._height, self._width = resize_to
        else:
            # single value is used for both height and width
            self._height = self._width = resize_to

        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that computes the new height and width from the input shape and resizes to them.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. From opset 18 antialiasing is enabled and Split is given `num_outputs`.

        Returns:
            The parsed onnx.GraphProto for the step.

        Raises:
            ValueError: If the layout the step was created with is not supported.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        if self._layout not in ("NHWC", "NCHW", "HWC", "CHW", "HW"):
            raise ValueError(f"Unsupported layout of {self._layout}")

        # one lower case name per axis of the layout. e.g. 'NHWC' -> n, h, w, c
        axis_names = [letter.lower() for letter in self._layout]
        assert len(dims) == len(axis_names)

        # resize will use the largest ratio so both sides won't necessarily match the requested height and width.
        # use symbolic names for the output h and w dims as we have to provide values. prefix the names to try and
        # avoid any clashes. all other dims are copied from the input shape.
        new_dim_names = {"h": f"resize_ppp_{self.step_num}_h", "w": f"resize_ppp_{self.step_num}_w"}
        output_shape_str = ", ".join(new_dim_names.get(axis, dims[idx]) for idx, axis in enumerate(axis_names))

        # names for the outputs when splitting the input shape, and the same with h and w replaced by the new sizes
        split_str = ", ".join(axis_names)
        sizes_str = ", ".join(axis + "2" if axis in new_dim_names else axis for axis in axis_names)

        # TODO: Make this configurable. Matching PIL resize for now.
        resize_attributes = 'mode = "linear", nearest_mode = "floor"'
        if onnx_opset >= 18:
            # Resize matches PIL better if antialiasing is used, but that isn't available until ONNX opset 18.
            # Allow this to be used with older opsets as well.
            resize_attributes += ', antialias = 1'

        src = self.input_names[0]
        dst = self.output_names[0]

        # Rank 3 input uses trilinear interpolation, so if input is HWC or CHW we need to add a temporary batch dim
        # to make it rank 4, which will result in Resize using the desired bilinear interpolation.
        add_batch_dim = self._layout in ("HWC", "CHW")
        if add_batch_dim:
            u64_1_str = "u64_1 = Constant <value = int64[1] {1}> ()"
            sizes_str = "u64_1, " + sizes_str
            resize_str = f"""\
                axes = Constant <value = int64[1] {{{0}}}> ()
                unsqueezed = Unsqueeze ({src}, axes)
                resized = Resize <{resize_attributes}> (unsqueezed, , , sizes_resize)
                {dst} = Squeeze (resized, axes)
                """
        else:
            u64_1_str = ""
            resize_str = f"{dst} = Resize <{resize_attributes}> ({src}, , , sizes_resize)"

        split_input_shape_attr = "axis = 0"
        split_new_sizes_attr = "axis = 0"
        if onnx_opset >= 18:
            # Split now requires the number of outputs to be specified even though that can be easily inferred...
            split_input_shape_attr += f", num_outputs = {len(dims)}"
            split_new_sizes_attr += ", num_outputs = 2"

        return onnx.parser.parse_graph(
            f"""\
            resize ({input_type_str}[{input_shape_str}] {src}) =>
                ({input_type_str}[{output_shape_str}] {dst})
            {{
                target_size = Constant <value = float[2] {{{float(self._height)}, {float(self._width)}}}> ()
                image_shape = Shape ({src})
                {split_str} = Split <{split_input_shape_attr}> (image_shape)
                hw = Concat <axis = 0> (h, w)
                f_hw = Cast <to = 1> (hw)
                ratios = Div (target_size, f_hw)
                ratio_resize = ReduceMax (ratios)
                f_hw2_exact = Mul (f_hw, ratio_resize)
                f_hw2_round = Round (f_hw2_exact)
                hw2 = Cast <to = 7> (f_hw2_round)
                h2, w2 = Split <{split_new_sizes_attr}> (hw2)
                {u64_1_str}
                sizes_resize = Concat <axis = 0> ({sizes_str})
                {resize_str}
            }}
            """
        )
398
399
class CenterCrop(Step):
    """
    Crop the center of the input to the requested dimensions.
    Currently only HWC input is handled.
    """

    def __init__(self, height: int, width: int, name: Optional[str] = None):
        """
        Args:
            height: Height of area to crop.
            width: Width of area to crop.
            name: Optional step name. Defaults to 'CenterCrop'
        """
        super().__init__(["image"], ["cropped_image"], name)
        self._height, self._width = height, width

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that slices the requested height and width out of the first two axes of the input.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        # the last dim of the input is carried over to the output unchanged
        last_dim = input_shape_str.split(",")[-1]
        output_shape_str = f"{self._height}, {self._width}, {last_dim}"

        src = self.input_names[0]
        dst = self.output_names[0]

        # the slice starts at half the difference between the input and target sizes, on axes 0 and 1
        return onnx.parser.parse_graph(
            f"""\
            crop ({input_type_str}[{input_shape_str}] {src})
                => ({input_type_str}[{output_shape_str}] {dst})
            {{
                target_crop = Constant <value = int64[2] {{{self._height}, {self._width}}}> ()
                i64_2 = Constant <value = int64[1] {{2}}> ()
                axes = Constant <value = int64[2] {{0, 1}}> ()
                x_shape = Shape ({src})
                hw = Gather (x_shape, axes)
                hw_diff = Sub (hw, target_crop)
                start_xy = Div (hw_diff, i64_2)
                end_xy = Add (start_xy, target_crop)
                {dst} = Slice ({src}, start_xy, end_xy, axes)
            }}
            """
        )
441
442
class Normalize(Step):
    """
    Normalize input data on a per-channel basis.
        `x -> (x - mean) / stddev`
    Output is float with same shape as input.
    """

    def __init__(self, normalization_values: List[Tuple[float, float]], layout: str = "CHW", name: Optional[str] = None):
        """
        Args:
            normalization_values: Tuple with (mean, stddev). One entry per channel.
                                  If single entry is provided it will be used for all channels.
                                  The values are copied, so the caller's list is never modified.
            layout: Input layout. Can be 'CHW' or 'HWC'
            name: Optional step name. Defaults to 'Normalize'
        """
        super().__init__(["data"], ["normalized_data"], name)

        # Work on a copy. `normalization_values *= 3` on a list extends the caller's list in place, which
        # would silently turn a single entry list owned by the caller into a three entry list.
        values = list(normalization_values)

        # duplicate for each channel if needed
        if len(values) == 1:
            values = values * 3

        assert len(values) == 3
        self._normalization_values = values
        assert layout == "HWC" or layout == "CHW"
        self._hwc_layout = layout == "HWC"

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that casts the input to float, subtracts the mean and divides by the stddev.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step. It is validated with onnx.checker.check_graph first.
        """
        mean0, mean1, mean2 = (entry[0] for entry in self._normalization_values)
        stddev0, stddev1, stddev2 = (entry[1] for entry in self._normalization_values)

        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        # HWC: the channels are the last dim so a shape of {3} broadcasts against it.
        # CHW: the channels are the first dim so trailing dims of 1 are needed to broadcast across h and w.
        values_shape = "3" if self._hwc_layout else "3, 1, 1"

        src = self.input_names[0]
        dst = self.output_names[0]

        normalize_graph = onnx.parser.parse_graph(
            f"""\
            normalize ({input_type_str}[{input_shape_str}] {src})
                => (float[{input_shape_str}] {dst})
            {{
                kMean = Constant <value = float[{values_shape}] {{{mean0}, {mean1}, {mean2}}}> ()
                kStddev = Constant <value = float[{values_shape}] {{{stddev0}, {stddev1}, {stddev2}}}> ()
                f_input = Cast <to = 1> ({src})
                f_sub_mean = Sub (f_input, kMean)
                {dst} = Div (f_sub_mean, kStddev)
            }}
            """
        )

        onnx.checker.check_graph(normalize_graph)
        return normalize_graph
496
497
498#
499# Utilities
500#
class ImageBytesToFloat(Step):
    """
    Convert uint8 or float values in range 0..255 to floating point values in range 0..1
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional step name. Defaults to 'ImageBytesToFloat'
        """
        super().__init__(["data"], ["float_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that converts the input to float if needed and divides it by 255.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step. It is validated with onnx.checker.check_graph first.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        src = self.input_names[0]
        dst = self.output_names[0]

        # uint8 input needs a Cast to float. for any other type use Identity, which is a no-op that the
        # optimizer will remove.
        to_float_op = "Cast <to = 1>" if input_type_str == "uint8" else "Identity"
        optional_cast = f"input_f = {to_float_op} ({src})"

        byte_to_float_graph = onnx.parser.parse_graph(
            f"""\
            byte_to_float ({input_type_str}[{input_shape_str}] {src})
                => (float[{input_shape_str}] {dst})
            {{
                f_255 = Constant <value = float[1] {{255.0}}>()

                {optional_cast}
                {dst} = Div(input_f, f_255)
            }}
            """
        )

        onnx.checker.check_graph(byte_to_float_graph)
        return byte_to_float_graph
538
539
class FloatToImageBytes(Step):
    """
    Convert floating point values to uint8 values in range 0..255.
    Typically this reverses ImageBytesToFloat by converting input data in the range 0..1, but an optional multiplier
    can be specified if the input data has a different range.
    Values will be rounded prior to clipping and conversion to uint8.
    """

    def __init__(self, multiplier: float = 255.0, name: Optional[str] = None):
        """
        Args:
            multiplier: Optional multiplier. Currently, the expected values are 255 (input data is in range 0..1), or
                        1 (input data is in range 0..255).
            name: Optional step name. Defaults to 'FloatToImageBytes'
        """
        super().__init__(["float_data"], ["pixel_data"], name)
        self._multiplier = multiplier

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Create the graph that scales (if needed), rounds, clips to 0..255 and casts the input to uint8.

        Args:
            graph: Graph the step input type and shape are looked up in.
            onnx_opset: ONNX opset in use. Not needed by this step.

        Returns:
            The parsed onnx.GraphProto for the step. It is validated with onnx.checker.check_graph first.
        """
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        assert input_type_str == "float"

        src = self.input_names[0]
        dst = self.output_names[0]

        if self._multiplier != 1.0:
            scale_input = f"""\
                f_multiplier = Constant <value = float[1] {{{self._multiplier}}}> ()
                scaled_input = Mul ({src}, f_multiplier)
                """
            scaled_input_name = 'scaled_input'
        else:
            # a multiplier of 1 changes nothing, so skip the Mul and round the input directly
            scale_input = ''
            scaled_input_name = src

        float_to_byte_graphs = onnx.parser.parse_graph(
            f"""\
            float_to_type (float[{input_shape_str}] {src})
                => (uint8[{input_shape_str}] {dst})
            {{
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}>()

                {scale_input}
                rounded = Round ({scaled_input_name})
                clipped = Clip (rounded, f_0, f_255)
                {dst} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )

        onnx.checker.check_graph(float_to_byte_graphs)
        return float_to_byte_graphs
591
592
class ChannelsLastToChannelsFirst(Transpose):
    """
    Transpose channels last data to channels first.
    Input can be NHWC or HWC.
    """

    def __init__(self, has_batch_dim: bool = False, name: Optional[str] = None):
        """
        Args:
            has_batch_dim: Set to True if the input has a batch dimension (i.e. is NHWC)
            name: Optional step name. Defaults to 'ChannelsLastToChannelsFirst'
        """
        if has_batch_dim:
            # NHWC -> NCHW
            perms = [0, 3, 1, 2]
        else:
            # HWC -> CHW
            perms = [2, 0, 1]

        super().__init__(perms, name)
607