microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
rel-0.7

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/add_pre_post_processing_to_model.py

359lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import argparse
5import enum
6import onnx
7import os
8
9from pathlib import Path
10from typing import Union
11# NOTE: If you're working on this script install onnxruntime_extensions using `pip install -e .` from the repo root
12# and run with `python -m onnxruntime_extensions.tools.add_pre_post_processing_to_model`
13# Running directly will result in an error from a relative import.
14from .pre_post_processing import *
15
16
@enum.unique
class ModelSource(enum.Enum):
    """Framework the original model was exported from.

    The value is used to select framework specific pre-processing, such as the normalization applied by
    `imagenet_preprocessing`.
    """

    # model exported from pytorch
    PYTORCH = 0
    # model exported from tensorflow
    TENSORFLOW = 1
    # any other origin
    OTHER = 2
21
22
def imagenet_preprocessing(model_source: ModelSource = ModelSource.PYTORCH):
    """Build the common pre-processing steps for an imagenet trained model.

    The returned steps run in this order:
      - Resize so the smallest side is 256
      - Centered crop to 224 x 224
      - Convert image bytes to floating point values in the range 0..1
      - For a pytorch model only: channels last to channels first (the model has NCHW layout)
      - Normalize as (value - mean) / stddev
          - for a pytorch model this applies per-channel normalization parameters
          - otherwise this simply moves the values into the range -1..1
      - Add a batch dimension with a value of 1

    Args:
        model_source (ModelSource, optional): framework the model came from. Defaults to ModelSource.PYTORCH.
            Any value other than ModelSource.PYTORCH gets the tensorflow style normalization.

    Returns:
        list: the pre-processing steps, in the order they should be applied.
    """

    # These utils cover both cases of typical pytorch/tensorflow pre-processing for an imagenet trained model
    # https://github.com/keras-team/keras/blob/b80dd12da9c0bc3f569eca3455e77762cf2ee8ef/keras/applications/imagenet_utils.py#L177

    # (mean, stddev) pair for each channel of a pytorch model
    pytorch_mean_stddev = [(0.485, 0.229), (0.456, 0.224), (0.406, 0.225)]
    # ImageBytesToFloat converts to range 0..1. A mean of 0.5 moves that to -0.5..0.5 and a stddev of 0.5 expands
    # it to -1..1, which is what TF processing expects.
    tensorflow_mean_stddev = [(0.5, 0.5)]

    pipeline_steps = [Resize(256), CenterCrop(224, 224), ImageBytesToFloat()]

    if model_source != ModelSource.PYTORCH:
        # layout is left as channels last
        pipeline_steps.append(Normalize(tensorflow_mean_stddev, layout="HWC"))
    else:
        # pytorch model has NCHW layout, so switch to channels first before normalizing
        pipeline_steps += [
            ChannelsLastToChannelsFirst(),
            Normalize(pytorch_mean_stddev, layout="CHW"),
        ]

    # add batch dim
    pipeline_steps.append(Unsqueeze([0]))

    return pipeline_steps
62
63
def mobilenet(model_file: Path, output_file: Path, model_source: ModelSource, onnx_opset: int = 16):
    """Add image decoding, imagenet pre-processing and (if needed) Softmax post-processing to a mobilenet model.

    Args:
        model_file (Path): the ONNX model to update. Must exist.
        output_file (Path): where to save the updated ONNX model.
        model_source (ModelSource): framework the model came from. Selects the normalization that is applied.
        onnx_opset (int, optional): the opset version to use. Defaults to 16.
    """
    source_path = str(model_file.resolve(strict=True))
    model = onnx.load(source_path)

    # the updated model takes the encoded image bytes as its single input
    image_input = create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])
    pipeline = PrePostProcessor([image_input], onnx_opset)

    # support user providing encoded image bytes
    decode_steps = [
        ConvertImageToBGR(),  # custom op to convert jpg/png to BGR (output is HWC)
        ReverseAxis(axis=2, dim_value=3, name="BGR_to_RGB"),  # Normalization params are for RGB ordering
    ]

    # plug in default imagenet pre-processing after the decoding
    pipeline.add_pre_processing(decode_steps + imagenet_preprocessing(model_source))

    # for mobilenet we convert the score to probabilities with softmax if necessary. the TF model includes Softmax
    final_op_type = model.graph.node[-1].op_type
    if final_op_type != "Softmax":
        pipeline.add_post_processing([Softmax()])

    updated_model = pipeline.run(model)
    destination = str(output_file.resolve())
    onnx.save_model(updated_model, destination)
87
88
def superresolution(model_file: Path, output_file: Path, output_format: str, onnx_opset: int = 16):
    """Add pre and post processing to a super resolution model that processes the Y channel of a YCbCr image.

    The updated model takes encoded image bytes as input and produces an encoded image in `output_format`.
    The height and width of the first model input and first model output must be fixed values.

    Args:
        model_file (Path): the ONNX model to update. Must exist.
        output_file (Path): where to save the updated ONNX model.
        output_format (str): image format for the output. jpg or png are supported.
        onnx_opset (int, optional): the opset version to use. Defaults to 16.
    """
    # TODO: There seems to be a split with some super resolution models processing RGB input and some processing
    # the Y channel after converting to YCbCr.
    # For the sake of this example implementation we do the trickier YCbCr processing as that involves joining the
    # Cb and Cr channels with the model output to create the resized image.
    # Model is from https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    model = onnx.load(str(model_file.resolve(strict=True)))
    image_input = create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])

    # assuming input is *CHW, infer the input sizes from the model.
    # requires the model input and output has a fixed size for the input and output height and width.
    in_shape = model.graph.input[0].type.tensor_type.shape
    out_shape = model.graph.output[0].type.tensor_type.shape
    for shape in (in_shape, out_shape):
        for axis in (-1, -2):  # width then height
            assert shape.dim[axis].HasField("dim_value")

    h_in, w_in = in_shape.dim[-2].dim_value, in_shape.dim[-1].dim_value
    h_out, w_out = out_shape.dim[-2].dim_value, out_shape.dim[-1].dim_value

    # pre/post processing for https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    pipeline = PrePostProcessor([image_input], onnx_opset)

    pre_steps = [
        ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout
        Resize((h_in, w_in)),
        CenterCrop(h_in, w_in),
        # this produces Y, Cb and Cr outputs. each has shape {h_in, w_in}. only Y is input to model
        PixelsToYCbCr(layout="BGR"),
        # if you inserted this Debug step here the 3 outputs from PixelsToYCbCr would also be model outputs
        # Debug(num_inputs=3),
        ImageBytesToFloat(),  # Convert Y to float in range 0..1
        Unsqueeze([0, 1]),  # add batch and channels dim to Y so shape is {1, 1, h_in, w_in}
    ]
    pipeline.add_pre_processing(pre_steps)

    # Post-processing is complicated here. resize the Cb and Cr outputs from the pre-processing to match
    # the model output size, merge those with the Y` model output, and convert back to RGB.
    post_steps = [
        Squeeze([0, 1]),  # remove batch and channels dims from Y'
        FloatToImageBytes(name="Y1_uint8"),  # convert Y' to uint8 in range 0..255
    ]

    # Cb is output 1 and Cr is output 2 from PixelsToYCbCr. Both get the same resize + convert treatment.
    for chroma_output_idx, uint8_step_name in ((1, "Cb1_uint8"), (2, "Cr1_uint8")):
        post_steps.append(
            (
                Resize((h_out, w_out), "HW"),
                [IoMapEntry(producer="PixelsToYCbCr", producer_idx=chroma_output_idx, consumer_idx=0)],
            )
        )
        # the Cb and Cr values are already in the range 0..255 so multiplier is 1. we're using the step to round
        # for accuracy (a direct Cast would just truncate) and clip (to ensure range 0..255) the values post-Resize
        post_steps.append(FloatToImageBytes(multiplier=1.0, name=uint8_step_name))

    # as we're selecting outputs from multiple previous steps we need to map them to the inputs using step names
    merge_inputs = [
        IoMapEntry("Y1_uint8", 0, 0),  # uint8 Y' with shape {h, w}
        IoMapEntry("Cb1_uint8", 0, 1),
        IoMapEntry("Cr1_uint8", 0, 2),
    ]
    post_steps.append((YCbCrToPixels(layout="BGR"), merge_inputs))
    post_steps.append(ConvertBGRToImage(image_format=output_format))  # jpg or png are supported

    pipeline.add_post_processing(post_steps)

    updated_model = pipeline.run(model)
    onnx.save_model(updated_model, str(output_file.resolve()))
163
164
@enum.unique
class NLPTaskType(enum.Enum):
    """Downstream task an NLP model performs.

    Member names are also accepted as strings by `transformers_and_bert`, which looks them up by name.
    """

    TokenClassification = enum.auto()
    QuestionAnswering = enum.auto()
    SequenceClassification = enum.auto()
    NextSentencePrediction = enum.auto()
170
171
@enum.unique
class TokenizerType(enum.Enum):
    """Tokenizer to add as the pre-processing step of an NLP model.

    Member names are also accepted as strings by `transformers_and_bert`, which looks them up by name.
    """

    BertTokenizer = enum.auto()
    SentencePieceTokenizer = enum.auto()
175
176
def transformers_and_bert(
    input_model_file: Path,
    output_model_file: Path,
    vocab_file: Path,
    tokenizer_type: Union[TokenizerType, str],
    task_type: Union[NLPTaskType, str],
    onnx_opset: int = 16,
    add_debug_before_postprocessing=False,
):
    """Construct the pipeline for an end to end model with pre and post processing.

    The final model can take text as input and output the result in text format for a model like QA.

    Args:
        input_model_file (Path): the model file needed to be updated.
        output_model_file (Path): where to save the final onnx model.
        vocab_file (Path): the vocab file for the tokenizer.
        tokenizer_type (Union[TokenizerType, str]): the tokenizer to add as pre-processing. A string is looked up
            by member name in TokenizerType.
        task_type (Union[NLPTaskType, str]): the task type of the model. A string is looked up by member name in
            NLPTaskType.
        onnx_opset (int, optional): the opset version to use. Defaults to 16.
        add_debug_before_postprocessing (bool, optional): whether to add a debug step before post processing.
            Defaults to False.

    Raises:
        KeyError: if `task_type` or `tokenizer_type` is a string that is not a member name of its enum.
    """
    if isinstance(task_type, str):
        task_type = NLPTaskType[task_type]
    if isinstance(tokenizer_type, str):
        tokenizer_type = TokenizerType[tokenizer_type]

    onnx_model = onnx.load(str(input_model_file.resolve(strict=True)))
    # hardcode batch size to 1
    inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]

    pipeline = PrePostProcessor(inputs, onnx_opset)

    # these tasks are flagged to the tokenizer as taking a pair of sentences
    sentence_pair_tasks = (NLPTaskType.QuestionAnswering, NLPTaskType.NextSentencePrediction)
    tokenizer_args = TokenizerParam(
        vocab_or_file=vocab_file,
        do_lower_case=True,
        tweaked_bos_id=0,
        is_sentence_pair=task_type in sentence_pair_tasks,
    )

    if tokenizer_type == TokenizerType.SentencePieceTokenizer:
        tokenizer = SentencePieceTokenizer(tokenizer_args)
    else:
        tokenizer = BertTokenizer(tokenizer_args)

    preprocessing = [
        tokenizer,
        # uncomment this line to debug
        # Debug(2),
    ]

    # For verifying results without postprocessing
    postprocessing = [Debug()] if add_debug_before_postprocessing else []
    if task_type == NLPTaskType.QuestionAnswering:
        # NOTE(review): the producer is looked up by the name "BertTokenizer" - confirm this combination is only
        # expected to be used with TokenizerType.BertTokenizer.
        postprocessing.append(
            (
                BertTokenizerQADecoder(tokenizer_args),
                [
                    # input_ids
                    IoMapEntry("BertTokenizer", producer_idx=0, consumer_idx=2)
                ],
            )
        )
    elif task_type == NLPTaskType.SequenceClassification:
        postprocessing.append(ArgMax())
    # the other tasks don't need postprocessing or we don't support it yet.

    pipeline.add_pre_processing(preprocessing)
    pipeline.add_post_processing(postprocessing)

    new_model = pipeline.run(onnx_model)
    onnx.save_model(new_model, str(output_model_file.resolve()))
238
239
def main():
    """Command line entry point: parse the arguments and add pre/post processing to the given model.

    The updated model is written next to the original model, with the '.onnx' suffix replaced by
    '.with_pre_post_processing.onnx'.

    For the 'transformers' model type, --vocab_file, --nlp_task_type and --tokenizer_type are all required and the
    parser exits with an error if any of them is missing.
    """
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="""Add pre and post processing to a model.

        Currently supports updating:
          Vision models:
            - super resolution with YCbCr input
            - imagenet trained mobilenet
          NLP models:

            - MobileBert with different tasks
            - XLM-Roberta with classification task

        For Vision models:
          To customize, the logic in the `mobilenet` and `superresolution` functions can be used as a guide.
          Create a pipeline and add the required pre/post processing 'Steps' in the order required. Configure
          individual steps as needed.

        For NLP models:
          `transformers_and_bert` can be used for MobileBert QuestionAnswering/Classification tasks,
          or serve as a guide of how to add pre/post processing to a transformer model.
          Usually pre-processing includes adding a tokenizer. Post-processing includes conversion of output_ids to text.

          You might need to pass the tokenizer model file (bert vocab file or SentencePieceTokenizer model)
          and task_type to the function.

        The updated model will be written in the same location as the original model,
        with '.onnx' updated to '.with_pre_post_processing.onnx'
        """,
    )

    parser.add_argument(
        "-t",
        "--model_type",
        type=str,
        required=True,
        choices=[
            "superresolution",
            "mobilenet",
            "transformers",
        ],
        help="Model type.",
    )

    parser.add_argument(
        "-s",
        "--model_source",
        type=str,
        required=False,
        choices=["pytorch", "tensorflow"],
        default="pytorch",
        help="""
        Framework that model came from. In some cases there are known differences that can be taken into account when
        adding the pre/post processing to the model. Currently this equates to choosing different normalization
        behavior for mobilenet models.
        """,
    )

    parser.add_argument(
        "--output_format",
        type=str,
        required=False,
        choices=["jpg", "png"],
        default="png",
        help="Image output format for superresolution model to produce.",
    )

    parser.add_argument(
        "--nlp_task_type",
        type=str,
        choices=["QuestionAnswering",
                 "SequenceClassification",
                 "NextSentencePrediction"],
        required=False,
        help="The downstream task for NLP model.",
    )

    parser.add_argument(
        "--vocab_file",
        type=Path,
        required=False,
        help="Tokenizer model file for BertTokenizer or SentencePieceTokenizer.",
    )

    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["BertTokenizer",
                 "SentencePieceTokenizer"],
        required=False,
        # previously this repeated the --vocab_file help text, which describes a file rather than a type
        help="Type of tokenizer to add as pre-processing for NLP model.",
    )

    parser.add_argument(
        "--opset", type=int, required=False, default=16,
        help="ONNX opset to use. Minimum allowed is 16. Opset 18 is required for Resize with anti-aliasing.",
    )

    parser.add_argument("model", type=Path, help="Provide path to ONNX model to update.")

    args = parser.parse_args()

    model_path = args.model.resolve(strict=True)
    new_model_path = model_path.with_suffix(".with_pre_post_processing.onnx")

    if args.model_type == "mobilenet":
        source = ModelSource.PYTORCH if args.model_source == "pytorch" else ModelSource.TENSORFLOW
        mobilenet(model_path, new_model_path, source, args.opset)
    elif args.model_type == "superresolution":
        superresolution(model_path, new_model_path,
                        args.output_format, args.opset)
    else:
        if args.vocab_file is None or args.nlp_task_type is None or args.tokenizer_type is None:
            parser.error("Please provide vocab file/nlp_task_type/tokenizer_type.")

        # Pass by keyword: the signature order is (vocab_file, tokenizer_type, task_type), and passing
        # tokenizer_type and vocab_file positionally in the wrong order silently swaps them.
        # Also forward --opset so it's honored for NLP models like it is for the vision models.
        transformers_and_bert(
            model_path,
            new_model_path,
            vocab_file=args.vocab_file,
            tokenizer_type=args.tokenizer_type,
            task_type=args.nlp_task_type,
            onnx_opset=args.opset,
        )
357
# Only run the command line entry point when this module is executed as a program (e.g. with `python -m`),
# not when it is imported.
if __name__ == "__main__":
    main()
360