microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/tools/pre_post_processing/utils.py
139lines · modecode
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import onnx |
| 5 | |
| 6 | from dataclasses import dataclass |
| 7 | from typing import List, Union |
| 8 | |
| 9 | |
def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
    """
    Build a named, typed tensor value description for use as a new model input.

    Args:
        name: Name for the input. Must not already be in use in the model being updated.
        data_type: onnx.TensorProto data type. e.g. onnx.TensorProto.FLOAT, onnx.TensorProto.UINT8
        shape: Input shape. Use int for dimensions with known values and strings for symbolic dimensions.
               e.g. ['batch_size', 256, 256] would be a rank 3 tensor with a symbolic first dimension
               named 'batch_size'

    Returns:
        An onnx.ValueInfoProto that can be used as a new model input.
    """
    helper = onnx.helper
    # Describe the tensor (element type + dims) first, then attach the name to that type.
    return helper.make_value_info(name, helper.make_tensor_type_proto(elem_type=data_type, shape=shape))
| 27 | |
| 28 | |
def create_custom_op_checker_context(onnx_opset: int):
    """
    Create an ONNX checker context that also registers the ort-extensions custom op domain, so that custom ops
    don't cause failure when running onnx.checker.check_graph.

    Args:
        onnx_opset: ONNX opset to use in the checker context.

    Returns:
        ONNX checker context.
    """
    # Default ONNX domain ("") at the requested opset, plus the ort-extensions domain at version 1.
    opset_by_domain = {"": onnx_opset, "com.microsoft.extensions": 1}

    checker_ctx = onnx.checker.C.CheckerContext()
    # Reuse the IR version from ONNX's default checker context.
    checker_ctx.ir_version = onnx.checker.DEFAULT_CONTEXT.ir_version
    checker_ctx.opset_imports = opset_by_domain
    return checker_ctx
| 46 | |
| 47 | |
# The ONNX graph parser has its own map of type names, so we keep a matching table here.
# https://github.com/onnx/onnx/blob/604af9cb28f63a6b9924237dcb91530649233db9/onnx/defs/parser.h#L72
# Keys are the integer onnx.TensorProto data type values; values are the type names used in parser text.
TENSOR_TYPE_TO_ONNX_TYPE = {
    int(elem_type): parser_name
    for elem_type, parser_name in (
        (onnx.TensorProto.FLOAT, "float"),
        (onnx.TensorProto.UINT8, "uint8"),
        (onnx.TensorProto.INT8, "int8"),
        (onnx.TensorProto.UINT16, "uint16"),
        (onnx.TensorProto.INT16, "int16"),
        (onnx.TensorProto.INT32, "int32"),
        (onnx.TensorProto.INT64, "int64"),
        (onnx.TensorProto.STRING, "string"),
        (onnx.TensorProto.BOOL, "bool"),
        (onnx.TensorProto.FLOAT16, "float16"),
        (onnx.TensorProto.DOUBLE, "double"),
        (onnx.TensorProto.UINT32, "uint32"),
        (onnx.TensorProto.UINT64, "uint64"),
        (onnx.TensorProto.COMPLEX64, "complex64"),
        (onnx.TensorProto.COMPLEX128, "complex128"),
        (onnx.TensorProto.BFLOAT16, "bfloat16"),
    )
}
| 68 | |
| 69 | |
@dataclass
class IoMapEntry:
    """Entry to map the output index from a producer step to the input index of a consumer step."""

    # Optional producer of the value.
    #   - A Step instance is used directly.
    #   - A str with a previous Step name: the PrePostProcessor will find the relevant Step.
    #   - None (the default): the producer is inferred to be the immediately previous Step in the pipeline.
    # None is part of the annotation because it is the default value.
    producer: Union["Step", str, None] = None
    # output index from the producer step
    producer_idx: int = 0
    # input index of the consumer step
    consumer_idx: int = 0
| 83 | |
| 84 | |
@dataclass
class IOEntryValuePreserver:
    """
    Used to allow an output value to have multiple consumers,
    which is only possible when IoMapEntry is used to create those additional connections.

    Generally, a connection consumes an output and an input, then the output is removed from the graph.
    This class enables one-to-many connections by making the other consumers share the same output.

    How this class works:
        1. When the IoMapEntry is created, an instance of this class is created simultaneously.
        2. It records the producer and consumer steps, and the output index of the producer step.
           When the producer step is running, this IOEntryValuePreserver is activated and starts to preserve
           the output.
        3. When a graph merge happens, this class checks whether the output is still in the graph; if not,
           it adds the output.
        4. When the consumer step is running, this class is deactivated and the output is removed from
           preserved_list.
    """

    # Producer and consumer steps. Each may be a Step, a Step name, or None (the default).
    producer: Union["Step", str, None] = None
    consumer: Union["Step", str, None] = None
    # output index from the producer step
    producer_idx: int = 0
    # whether the output is currently being preserved; starts inactive
    is_active: bool = False
    # name of the preserved output value; None until one is assigned
    output: Union[str, None] = None
| 109 | |
| 110 | |
def sanitize_output_names(graph: onnx.GraphProto):
    """
    Convert any usage of invalid characters like '/' and ';' in value names to '_'
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we can
    leave the internals of the graph intact to minimize changes.

    Args:
        graph: Graph to check and update any invalid names. Updated in-place.

    Returns:
        The same `graph` object that was passed in, whether or not any output was renamed.
    """

    # Map each offending graph output name to its sanitized replacement.
    renames = {
        o.name: o.name.replace("/", "_").replace(";", "_")
        for o in graph.output
        if "/" in o.name or ";" in o.name
    }

    # NOTE: two distinct names can sanitize to the same string (e.g. 'a/b' and 'a;b' both become 'a_b').
    # That collision is not detected here.
    for o in graph.output:
        new_name = renames.get(o.name)
        if new_name is None:
            # Name is already valid, so leave this output untouched.
            continue

        # Add Identity node to rename the output, and update the name in graph.output.
        # The internal value keeps its original name; only the graph output sees the new one.
        rename = onnx.helper.make_node("Identity", [o.name], [new_name], f"Rename {o.name}")
        graph.node.append(rename)
        o.name = new_name

    # Always return the graph so callers get the same result type on every path.
    return graph
| 140 | |