microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
onnxruntime_extensions/tools/pre_post_processing/utils.py
112lines · modecode
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import onnx |
| 5 | |
| 6 | from dataclasses import dataclass |
| 7 | from typing import List, Union |
| 8 | |
| 9 | |
def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
    """
    Build the value info describing a tensor that can be added to a model as a new input.

    Args:
      name: Name of the new input. It must not clash with a name already used in the model being updated.
      data_type: Element type as an onnx.TensorProto data type value,
                 e.g. onnx.TensorProto.FLOAT or onnx.TensorProto.UINT8.
      shape: Shape of the input. An int entry is a dimension with a known value and a str entry is a
             symbolic dimension. e.g. ['batch_size', 256, 256] is a rank 3 tensor whose first dimension
             is symbolic and named 'batch_size'.

    Returns:
      An onnx.ValueInfoProto that can be used as a new model input.
    """
    # The tensor type (element type + shape) is wrapped directly in a named value info entry.
    return onnx.helper.make_value_info(
        name,
        onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape),
    )
| 26 | |
| 27 | |
| 28 | # Create an onnx checker context that includes the ort-ext domain so that custom ops don't cause failure |
def create_custom_op_checker_context(onnx_opset: int):
    """
    Build an ONNX checker context that also knows about the ort-extensions custom op domain, so that
    custom ops don't cause a failure when running onnx.checker.check_graph.

    Args:
      onnx_opset: ONNX opset to use in the checker context.

    Returns:
      ONNX checker context.
    """
    # Domain -> opset version. The "" entry carries the requested ONNX opset and the
    # ort-extensions custom op domain is registered at version 1.
    domain_versions = {
        "": onnx_opset,
        "com.microsoft.extensions": 1,
    }

    checker_context = onnx.checker.C.CheckerContext()
    checker_context.opset_imports = domain_versions
    # Reuse the IR version from the checker's default context.
    checker_context.ir_version = onnx.checker.DEFAULT_CONTEXT.ir_version

    return checker_context
| 45 | |
| 46 | |
| 47 | # The ONNX graph parser has its own map of names just to be special |
| 48 | # https://github.com/onnx/onnx/blob/604af9cb28f63a6b9924237dcb91530649233db9/onnx/defs/parser.h#L72 |
# Maps the integer value of an onnx.TensorProto data type to the type name string used for it.
# Built from (TensorProto enum member name, type name string) pairs; insertion order follows the pairs below.
TENSOR_TYPE_TO_ONNX_TYPE = {
    int(getattr(onnx.TensorProto, enum_name)): type_name
    for enum_name, type_name in (
        ("FLOAT", "float"),
        ("UINT8", "uint8"),
        ("INT8", "int8"),
        ("UINT16", "uint16"),
        ("INT16", "int16"),
        ("INT32", "int32"),
        ("INT64", "int64"),
        ("STRING", "string"),
        ("BOOL", "bool"),
        ("FLOAT16", "float16"),
        ("DOUBLE", "double"),
        ("UINT32", "uint32"),
        ("UINT64", "uint64"),
        ("COMPLEX64", "complex64"),
        ("COMPLEX128", "complex128"),
        ("BFLOAT16", "bfloat16"),
    )
}
| 67 | |
| 68 | |
@dataclass
class IoMapEntry:
    """Entry to map the output index from a producer step to the input index of a consumer step."""

    # Optional producer of the value being connected.
    #   - If a Step is provided it is used directly.
    #   - If a str with a previous Step name is provided the PrePostProcessor will find the relevant Step.
    #   - If neither is provided (None, the default) the producer is inferred to be the immediately
    #     previous Step in the pipeline.
    # None is listed explicitly in the annotation so that it agrees with the default value.
    producer: Union["Step", str, None] = None
    # output index from the producer step
    producer_idx: int = 0
    # input index of the consumer step
    consumer_idx: int = 0
| 82 | |
| 83 | |
def sanitize_output_names(graph: onnx.GraphProto):
    """
    Convert any usage of invalid characters like '/' and ';' in graph output names to '_'.
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we can
    leave the internals of the graph intact to minimize changes.

    For each affected output an Identity node is appended to the graph that consumes the original value and
    produces the sanitized name, and the entry in graph.output is updated to the sanitized name.

    Args:
      graph: Graph to check and update any invalid names. The graph is modified in place.

    Returns:
      The same graph object that was passed in, whether or not any output was renamed.
    """

    # Original output name -> sanitized name, for outputs that contain an invalid character.
    renames = {
        o.name: o.name.replace("/", "_").replace(";", "_")
        for o in graph.output
        if "/" in o.name or ";" in o.name
    }

    # NOTE(review): a sanitized name is not checked against names already used in the graph.

    for o in graph.output:
        new_name = renames.get(o.name)
        if new_name is None:
            # Output name is already valid.
            continue

        # Add Identity node to rename the output, and update the name in graph.output
        rename = onnx.helper.make_node("Identity", [o.name], [new_name], f"Rename {o.name}")
        graph.node.append(rename)
        o.name = new_name

    # Always return the graph so the result is the same type on every path.
    return graph
| 113 | |