microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
natke-patch-1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/add_pre_post_processing_to_model.py

512lines · modecode

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
3
import argparse
import enum
import os

from pathlib import Path
from typing import List, Optional, Union

import onnx

# NOTE: If you're working on this script install onnxruntime_extensions using `pip install -e .` from the repo root
# and run with `python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model`
# Running directly will result in an error from a relative import.
from .pre_post_processing import *
15
16
class ModelSource(enum.Enum):
    """Framework the original model came from. Used to select the matching pre-processing."""

    PYTORCH = 0
    TENSORFLOW = 1
    OTHER = 2
21
22
def imagenet_preprocessing(model_source: ModelSource = ModelSource.PYTORCH):
    """
    Common pre-processing for an imagenet trained model.

    - Resize so smallest side is 256
    - Centered crop to 224 x 224
    - Convert image bytes to floating point values in range 0..1
    - [Channels last to channels first (convert to ONNX layout) if model came from pytorch and has NCHW layout]
    - Normalize
      - (value - mean) / stddev
      - for a pytorch model, this applies per-channel normalization parameters
      - for a tensorflow model this simply moves the image bytes into the range -1..1
    - adds a batch dimension with a value of 1

    :param model_source: Framework the model came from. Any value other than ModelSource.PYTORCH gets the
                         tensorflow style normalization.
    :return: List of pre-processing steps, in the order they should be applied.
    """

    # These utils cover both cases of typical pytorch/tensorflow pre-processing for an imagenet trained model
    # https://github.com/keras-team/keras/blob/b80dd12da9c0bc3f569eca3455e77762cf2ee8ef/keras/applications/imagenet_utils.py#L177

    steps = [
        Resize(256),
        CenterCrop(224, 224),
        ImageBytesToFloat(),
    ]

    if model_source == ModelSource.PYTORCH:
        # pytorch model has NCHW layout
        steps.extend([
            ChannelsLastToChannelsFirst(),
            # one (mean, stddev) pair per channel
            Normalize([(0.485, 0.229), (0.456, 0.224), (0.406, 0.225)], layout="CHW"),
        ])
    else:
        # TF processing involves moving the data into the range -1..1 instead of 0..1.
        # ImageBytesToFloat converts to range 0..1, so we use 0.5 for the mean to move into the range -0.5..0.5
        # and 0.5 for the stddev to expand to -1..1
        steps.append(Normalize([(0.5, 0.5)], layout="HWC"))

    steps.append(Unsqueeze([0]))  # add batch dim

    return steps
62
63
def mobilenet(model_file: Path, output_file: Path, model_source: ModelSource, onnx_opset: int = 16):
    """
    Add pre and post processing to an imagenet trained mobilenet model and save the updated model.

    The updated model takes the encoded image bytes (jpg/png) as a 1D uint8 input called "image".
    Softmax is appended as post-processing unless the last node in the model graph is already a Softmax.

    :param model_file: The input model file path. Must exist.
    :param output_file: The output file path, where the updated model is saved to.
    :param model_source: Framework the model came from. Selects the normalization used in the pre-processing.
    :param onnx_opset: The ONNX opset to use for the pre/post processing, default(16).
    """
    model = onnx.load(str(model_file.resolve(strict=True)))
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    pipeline = PrePostProcessor(inputs, onnx_opset)

    # support user providing encoded image bytes
    preprocessing = [
        ConvertImageToBGR(),  # custom op to convert jpg/png to BGR (output is HWC)
        ReverseAxis(axis=2, dim_value=3, name="BGR_to_RGB"),
    ]  # Normalization params are for RGB ordering
    # plug in default imagenet pre-processing
    preprocessing.extend(imagenet_preprocessing(model_source))

    pipeline.add_pre_processing(preprocessing)

    # for mobilenet we convert the score to probabilities with softmax if necessary. the TF model includes Softmax
    if model.graph.node[-1].op_type != "Softmax":
        pipeline.add_post_processing([Softmax()])

    new_model = pipeline.run(model)

    onnx.save_model(new_model, str(output_file.resolve()))
87
88
def superresolution(model_file: Path, output_file: Path, output_format: str, onnx_opset: int = 16):
    """
    Add pre and post processing to a super resolution model that processes the Y channel of a YCbCr image,
    and save the updated model.

    The updated model takes the encoded image bytes (jpg/png) as a 1D uint8 input called "image" and produces
    the resized image encoded in `output_format`.

    :param model_file: The input model file path. Must exist. The height and width of the first model input and
                       first model output must be fixed values.
    :param output_file: The output file path, where the updated model is saved to.
    :param output_format: The output image format, jpg or png.
    :param onnx_opset: The ONNX opset to use for the pre/post processing, default(16).
    """
    # TODO: There seems to be a split with some super resolution models processing RGB input and some processing
    # the Y channel after converting to YCbCr.
    # For the sake of this example implementation we do the trickier YCbCr processing as that involves joining the
    # Cb and Cr channels with the model output to create the resized image.
    # Model is from https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    model = onnx.load(str(model_file.resolve(strict=True)))
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    # assuming input is *CHW, infer the input sizes from the model.
    # requires the model input and output has a fixed size for the input and output height and width.
    model_input_shape = model.graph.input[0].type.tensor_type.shape
    model_output_shape = model.graph.output[0].type.tensor_type.shape
    for shape, label in ((model_input_shape, "input"), (model_output_shape, "output")):
        assert shape.dim[-1].HasField("dim_value"), f"The model {label} must have a fixed width."
        assert shape.dim[-2].HasField("dim_value"), f"The model {label} must have a fixed height."

    w_in = model_input_shape.dim[-1].dim_value
    h_in = model_input_shape.dim[-2].dim_value
    h_out = model_output_shape.dim[-2].dim_value
    w_out = model_output_shape.dim[-1].dim_value

    # pre/post processing for https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    pipeline = PrePostProcessor(inputs, onnx_opset)
    pipeline.add_pre_processing(
        [
            ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout
            Resize((h_in, w_in)),
            CenterCrop(h_in, w_in),
            # this produces Y, Cb and Cr outputs. each has shape {h_in, w_in}. only Y is input to model
            PixelsToYCbCr(layout="BGR"),
            # if you inserted this Debug step here the 3 outputs from PixelsToYCbCr would also be model outputs
            # Debug(num_inputs=3),
            ImageBytesToFloat(),  # Convert Y to float in range 0..1
            Unsqueeze([0, 1]),  # add batch and channels dim to Y so shape is {1, 1, h_in, w_in}
        ]
    )

    # Post-processing is complicated here. resize the Cb and Cr outputs from the pre-processing to match
    # the model output size, merge those with the Y' model output, convert back to BGR pixels and encode the image.

    # create the Steps we need to use in the manual connections
    pipeline.add_post_processing(
        [
            Squeeze([0, 1]),  # remove batch and channels dims from Y'
            FloatToImageBytes(name="Y1_uint8"),  # convert Y' to uint8 in range 0..255

            # Resize the Cb values (output 1 from PixelsToYCbCr)
            (Resize((h_out, w_out), "HW"),
             [IoMapEntry(producer="PixelsToYCbCr", producer_idx=1, consumer_idx=0)]),

            # the Cb and Cr values are already in the range 0..255 so multiplier is 1. we're using the step to round
            # for accuracy (a direct Cast would just truncate) and clip (to ensure range 0..255) the values post-Resize
            FloatToImageBytes(multiplier=1.0, name="Cb1_uint8"),

            # Resize the Cr values (output 2 from PixelsToYCbCr)
            (Resize((h_out, w_out), "HW"), [IoMapEntry("PixelsToYCbCr", 2, 0)]),
            FloatToImageBytes(multiplier=1.0, name="Cr1_uint8"),

            # as we're selecting outputs from multiple previous steps we need to map them to the inputs using step names
            (
                YCbCrToPixels(layout="BGR"),
                [
                    IoMapEntry("Y1_uint8", 0, 0),  # uint8 Y' with shape {h, w}
                    IoMapEntry("Cb1_uint8", 0, 1),
                    IoMapEntry("Cr1_uint8", 0, 2),
                ],
            ),
            ConvertBGRToImage(image_format=output_format),  # jpg or png are supported
        ]
    )

    new_model = pipeline.run(model)
    onnx.save_model(new_model, str(output_file.resolve()))
163
164
def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jpg',
                   onnx_opset: int = 16, num_classes: int = 80, input_shape: Optional[List[int]] = None):
    """
    Add pre and post processing to a YOLO object detection model and save the updated model.

    SSD-like and Faster-RCNN-like models already include NMS. You can find those in the onnx model zoo.

    A pure detection model accepts a fixed-size (say 1,3,640,640) image as input, and outputs a list of bounding
    boxes, the number of which is determined by the anchors.

    This function targets YOLO detection models. It should theoretically support YOLOv3 to YOLOv8 models.
    You should ensure the model has only one input, and the input shape is [1, 3, h, w].
    The model has either one or more outputs.
    If the model has one output, the output shape is [1, num_boxes, coor+(obj)+cls]
    or [1, coor+(obj)+cls, num_boxes].
    If the model has more than one output, you should ensure the first output shape is
    [1, num_boxes, coor+(obj)+cls] or [1, coor+(obj)+cls, num_boxes].
    Note: (obj) means it's optional.

    :param model_file: The input model file path.
    :param output_file: The output file path, where the finalized model saved to.
    :param output_format: The output image format, jpg or png.
    :param onnx_opset: The opset version of onnx model, default(16).
    :param num_classes: The number of classes, default(80).
    :param input_shape: The shape of input image (height,width), default will be asked from model input.
    """
    model = onnx.load(str(model_file.resolve(strict=True)))
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    model_input_shape = model.graph.input[0].type.tensor_type.shape
    model_output_shape = model.graph.output[0].type.tensor_type.shape

    # We will use the input_shape to create the model if provided by user.
    if input_shape is not None:
        assert len(input_shape) == 2, "The input_shape should be [h, w]."
        w_in = input_shape[1]
        h_in = input_shape[0]
    else:
        assert (model_input_shape.dim[-1].HasField("dim_value") and
                model_input_shape.dim[-2].HasField("dim_value")), "please provide input_shape in the command args."

        w_in = model_input_shape.dim[-1].dim_value
        h_in = model_input_shape.dim[-2].dim_value

    # Yolov5(v3,v7) has an output of shape (batchSize, 25200, 85) (Num classes + box[x,y,w,h] + confidence[c])
    # Yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h])
    # https://github.com/ultralytics/ultralytics/blob/e5cb35edfc3bbc9d7d7db8a6042778a751f0e39e/examples/YOLOv8-CPP-Inference/inference.cpp#L31-L33
    # We always want the box info to be the last dim for each of iteration.
    # For new variants like YoloV8, we need to add an transpose op to permute output back.
    need_transpose = False

    # last two dims of the first model output. -1 is used for a dim without a fixed value.
    output_shape = [model_output_shape.dim[i].dim_value if model_output_shape.dim[i].HasField("dim_value") else -1
                    for i in [-2, -1]]
    if output_shape[0] != -1 and output_shape[1] != -1:
        need_transpose = output_shape[0] < output_shape[1]
    else:
        assert len(model.graph.input) == 1, "Doesn't support adding pre and post-processing for multi-inputs model."
        try:
            import numpy as np
            import onnxruntime
        except ImportError as e:
            raise ImportError(
                "Please install onnxruntime and numpy to run this script. eg 'pip install onnxruntime numpy'.\n"
                "Because we need to execute the model to determine the output shape in order to add the correct "
                "post-processing") from e

        # Generate a random input to run the model and infer the output shape.
        session = onnxruntime.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])
        input_name = session.get_inputs()[0].name
        elem_type = model.graph.input[0].type.tensor_type.elem_type
        # prefer the helper function when this onnx version provides it, and fall back to the onnx.mapping table
        # otherwise.
        if hasattr(onnx.helper, "tensor_dtype_to_np_dtype"):
            input_type = onnx.helper.tensor_dtype_to_np_dtype(elem_type)
        else:
            input_type = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]
        inp = {input_name: np.random.rand(1, 3, h_in, w_in).astype(dtype=input_type)}
        outputs = session.run(None, inp)[0]
        assert len(outputs.shape) == 3 and outputs.shape[0] == 1, "shape of the first model output is not (1, n, m)"
        if outputs.shape[1] < outputs.shape[2]:
            need_transpose = True

        # Validate the dim that holds coor+(obj)+cls. That is dim 1 when the output needs to be transposed,
        # e.g. YOLOv8 with (1, 84, 8400), and dim 2 otherwise.
        num_features = outputs.shape[1] if need_transpose else outputs.shape[2]
        assert num_features in (num_classes + 4, num_classes + 5), \
            "The output shape is neither (1, num_boxes, num_classes+4(reg)) nor (1, num_boxes, num_classes+5(reg+obj))"

    pipeline = PrePostProcessor(inputs, onnx_opset)
    # preprocess steps are responsible for converting any jpg/png image to CHW BGR float32 tensor
    # jpg-->BGR(Image Tensor)-->Resize (scaled Image)-->LetterBox (Fix sized Image)-->(from HWC to)CHW-->float32-->1CHW
    pipeline.add_pre_processing(
        [
            ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout
            # Resize an arbitrary sized image to a fixed size in not_larger policy
            Resize((h_in, w_in), policy='not_larger'),
            LetterBox(target_shape=(h_in, w_in)),  # padding or cropping the image to (h_in, w_in)
            ChannelsLastToChannelsFirst(),  # HWC to CHW
            ImageBytesToFloat(),  # Convert to float in range 0..1
            Unsqueeze([0]),  # add batch, CHW --> 1CHW
        ]
    )
    # NMS and drawing boxes
    post_processing_steps = [
        Squeeze([0]),  # - Squeeze to remove batch dimension
        SplitOutBoxAndScore(num_classes=num_classes),  # Separate bounding box and confidence outputs
        SelectBestBoundingBoxesByNMS(),  # Apply NMS to suppress bounding boxes
        (ScaleBoundingBoxes(),  # Scale bounding box coords back to original image
         [
             # A connection from original image to ScaleBoundingBoxes
             # A connection from the resized image to ScaleBoundingBoxes
             # A connection from the LetterBoxed image to ScaleBoundingBoxes
             # We can use the three image to calculate the scale factor and offset.
             # With scale and offset, we can scale the bounding box back to the original image.
             utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=1),
             utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
             utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
         ]),
        # DrawBoundingBoxes on the original image
        # Model imported from pytorch has CENTER_XYWH format
        # two mode for how to color box,
        # 1. colour_by_classes=True, (colour_by_classes), 2. colour_by_classes=False,(colour_by_confidence)
        (DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=num_classes, colour_by_classes=True),
         [
             utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
             utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
         ]),
        # Encode to jpg/png
        ConvertBGRToImage(image_format=output_format),
    ]
    # transpose to (num_boxes, coor+conf) if needed. inserted directly after the Squeeze step.
    if need_transpose:
        post_processing_steps.insert(1, Transpose([1, 0]))

    pipeline.add_post_processing(post_processing_steps)

    new_model = pipeline.run(model)
    # Run shape inferencing over the new model. The returned copy of the model is intentionally discarded
    # and new_model is what gets saved, which matches the previous behavior of this function.
    onnx.shape_inference.infer_shapes(new_model)
    onnx.save_model(new_model, str(output_file.resolve()))
291
292
class NLPTaskType(enum.Enum):
    """Downstream task of an NLP model. Used to select the post-processing to add."""

    TokenClassification = enum.auto()
    QuestionAnswering = enum.auto()
    SequenceClassification = enum.auto()
    NextSentencePrediction = enum.auto()
298
299
class TokenizerType(enum.Enum):
    """Tokenizer to add as the pre-processing for an NLP model."""

    BertTokenizer = enum.auto()
    SentencePieceTokenizer = enum.auto()
303
304
def transformers_and_bert(
    input_model_file: Path,
    output_model_file: Path,
    vocab_file: Path,
    tokenizer_type: Union[TokenizerType, str],
    task_type: Union[NLPTaskType, str],
    onnx_opset: int = 16,
    add_debug_before_postprocessing=False,
):
    """construct the pipeline for a end2end model with pre and post processing. The final model can take text as inputs
    and output the result in text format for model like QA.

    Args:
        input_model_file (Path): the model file needed to be updated.
        output_model_file (Path): where to save the final onnx model.
        vocab_file (Path): the vocab file for the tokenizer.
        tokenizer_type (Union[TokenizerType, str]): the tokenizer to add as pre-processing. A string is looked up
            by name in TokenizerType.
        task_type (Union[NLPTaskType, str]): the task type of the model. A string is looked up by name in
            NLPTaskType.
        onnx_opset (int, optional): the opset version to use. Defaults to 16.
        add_debug_before_postprocessing (bool, optional): whether to add a debug step before post processing.
            Defaults to False.

    Raises:
        KeyError: if task_type or tokenizer_type is a string that is not the name of an enum member.
    """
    if isinstance(task_type, str):
        task_type = NLPTaskType[task_type]
    if isinstance(tokenizer_type, str):
        tokenizer_type = TokenizerType[tokenizer_type]

    onnx_model = onnx.load(str(input_model_file.resolve(strict=True)))
    # hardcode batch size to 1
    inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]

    pipeline = PrePostProcessor(inputs, onnx_opset)
    # these tasks take a pair of sentences as input
    is_sentence_pair = task_type in (NLPTaskType.QuestionAnswering, NLPTaskType.NextSentencePrediction)
    tokenizer_args = TokenizerParam(
        vocab_or_file=vocab_file,
        do_lower_case=True,
        tweaked_bos_id=0,
        is_sentence_pair=is_sentence_pair,
    )

    if tokenizer_type == TokenizerType.SentencePieceTokenizer:
        tokenizer = SentencePieceTokenizer(tokenizer_args)
    else:
        tokenizer = BertTokenizer(tokenizer_args)

    preprocessing = [
        tokenizer,
        # uncomment this line to debug
        # Debug(2),
    ]

    # For verify results with out postprocessing
    postprocessing = [Debug()] if add_debug_before_postprocessing else []
    if task_type == NLPTaskType.QuestionAnswering:
        # NOTE(review): the producer is referenced by the name "BertTokenizer", so this presumably only connects
        # when the BertTokenizer step is used for pre-processing - confirm before using with SentencePieceTokenizer.
        postprocessing.append((BertTokenizerQADecoder(tokenizer_args), [
            # input_ids
            utils.IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)]))
    elif task_type == NLPTaskType.SequenceClassification:
        postprocessing.append(ArgMax())
    # the other tasks don't need postprocessing or we don't support it yet.

    pipeline.add_pre_processing(preprocessing)
    pipeline.add_post_processing(postprocessing)

    new_model = pipeline.run(onnx_model)
    onnx.save_model(new_model, str(output_model_file.resolve()))
366
367
def main():
    """Parse the command line arguments and add pre/post processing to the model for the requested model type.

    The updated model is written next to the original model with the '.onnx' suffix replaced by
    '.with_pre_post_processing.onnx'.
    """
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="""Add pre and post processing to a model.

        Currently supports updating:
          Vision models:
            - super resolution with YCbCr input
            - imagenet trained mobilenet
            - object detection with YOLOv3-YOLOV8

          NLP models:
            - MobileBert with different tasks
            - XLM-Roberta with classification task

        For Vision models:
          To customize, the logic in the `mobilenet`, `superresolution` and `yolo_detection` functions can be used as a guide.
          Create a pipeline and add the required pre/post processing 'Steps' in the order required. Configure
          individual steps as needed.

        For NLP models:
          `transformers_and_bert` can be used for MobileBert QuestionAnswering/Classification tasks,
          or serve as a guide of how to add pre/post processing to a transformer model.
          Usually pre-processing includes adding a tokenizer. Post-processing includes conversion of output_ids to text.

          You might need to pass the tokenizer model file (bert vocab file or SentencePieceTokenizer model)
          and task_type to the function.

        The updated model will be written in the same location as the original model,
        with '.onnx' updated to '.with_pre_post_processing.onnx'

        Example usage:
          object detection:
            - python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model -t yolo --num_classes 80 --input_shape 640,640 yolov8n.onnx
        """,
    )

    parser.add_argument(
        "-t",
        "--model_type",
        type=str,
        required=True,
        choices=[
            "superresolution",
            "mobilenet",
            "yolo",
            "transformers",
        ],
        help="Model type.",
    )

    parser.add_argument(
        "-s",
        "--model_source",
        type=str,
        required=False,
        choices=["pytorch", "tensorflow"],
        default="pytorch",
        help="""
        Framework that model came from. In some cases there are known differences that can be taken into account when
        adding the pre/post processing to the model. Currently this equates to choosing different normalization
        behavior for mobilenet models.
        """,
    )

    parser.add_argument(
        "--output_format",
        type=str,
        required=False,
        choices=["jpg", "png"],
        default="png",
        help="Image output format for superresolution or yolo model to produce.",
    )

    parser.add_argument(
        "--num_classes",
        type=int,
        default=80,
        help="Number of classes in object detection model.",
    )

    parser.add_argument(
        "--input_shape",
        type=str,
        default="",
        help="To specify input image shape(height,width) for the model. Such as \"224,224\", "
             "Tools will ask onnx model for input shape if input_shape is not specified.",
    )

    parser.add_argument(
        "--nlp_task_type",
        type=str,
        choices=["QuestionAnswering",
                 "SequenceClassification",
                 "NextSentencePrediction"],
        required=False,
        help="The downstream task for NLP model.",
    )

    parser.add_argument(
        "--vocab_file",
        type=Path,
        required=False,
        help="Tokenizer model file for BertTokenizer or SentencePieceTokenizer.",
    )

    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["BertTokenizer",
                 "SentencePieceTokenizer"],
        required=False,
        help="Tokenizer type to add as pre-processing for NLP model.",
    )

    parser.add_argument(
        "--opset", type=int, required=False, default=16,
        help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing.",
    )

    parser.add_argument("model", type=Path, help="Provide path to ONNX model to update.")

    args = parser.parse_args()

    model_path = args.model.resolve(strict=True)
    new_model_path = model_path.with_suffix(".with_pre_post_processing.onnx")

    if args.model_type == "mobilenet":
        source = ModelSource.PYTORCH if args.model_source == "pytorch" else ModelSource.TENSORFLOW
        mobilenet(model_path, new_model_path, source, args.opset)
    elif args.model_type == "superresolution":
        superresolution(model_path, new_model_path, args.output_format, args.opset)
    elif args.model_type == "yolo":
        input_shape = None
        if args.input_shape != "":
            input_shape = [int(x) for x in args.input_shape.split(",")]
        yolo_detection(model_path, new_model_path, args.output_format, args.opset, args.num_classes, input_shape)
    else:
        if args.vocab_file is None or args.nlp_task_type is None or args.tokenizer_type is None:
            parser.error("Please provide vocab file/nlp_task_type/tokenizer_type.")
        # pass by keyword so each value is bound to the intended parameter of transformers_and_bert
        transformers_and_bert(
            model_path,
            new_model_path,
            vocab_file=args.vocab_file,
            tokenizer_type=args.tokenizer_type,
            task_type=args.nlp_task_type,
            onnx_opset=args.opset,
        )
509
510
# Only run the command line tool when executed as a script/module, not when imported.
if __name__ == "__main__":
    main()
513