microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
rel-0.7

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/utils.py

139lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5
6from dataclasses import dataclass
7from typing import List, Union
8
9
def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
    """
    Build the value info describing a new model input.

    Args:
        name: Name for the input. Must not already be in use in the model being updated.
        data_type: onnx.TensorProto data type. e.g. onnx.TensorProto.FLOAT, onnx.TensorProto.UINT8
        shape: Input shape. Use int for dimensions with known values and strings for symbolic dimensions.
               e.g. ['batch_size', 256, 256] would be a rank 3 tensor with a symbolic first dimension
               named 'batch_size'

    Returns:
        An onnx.ValueInfoProto that can be used as a new model input.
    """
    # The tensor type proto carries the element type and shape; the value info binds it to `name`.
    return onnx.helper.make_value_info(
        name,
        onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape),
    )
27
28
29# Create an onnx checker context that includes the ort-ext domain so that custom ops don't cause failure
def create_custom_op_checker_context(onnx_opset: int):
    """
    Build an ONNX checker context that also registers the ort-extensions custom op domain, so that
    custom ops don't cause failure when running onnx.checker.check_graph.

    Args:
        onnx_opset: ONNX opset to use in the checker context.

    Returns:
        ONNX checker context.
    """
    # "" is the default ONNX domain; the extensions domain is registered at version 1.
    opset_by_domain = {"": onnx_opset, "com.microsoft.extensions": 1}

    checker_context = onnx.checker.C.CheckerContext()
    checker_context.opset_imports = opset_by_domain
    # Reuse the IR version from onnx's default checker context.
    checker_context.ir_version = onnx.checker.DEFAULT_CONTEXT.ir_version
    return checker_context
46
47
# The ONNX graph parser has its own map of type names.
# https://github.com/onnx/onnx/blob/604af9cb28f63a6b9924237dcb91530649233db9/onnx/defs/parser.h#L72
# Maps the integer onnx.TensorProto data type to the type name used by the parser.
# Each parser name listed here is the lower-cased name of the matching onnx.TensorProto enum member.
TENSOR_TYPE_TO_ONNX_TYPE = {
    int(getattr(onnx.TensorProto, type_name.upper())): type_name
    for type_name in (
        "float",
        "uint8",
        "int8",
        "uint16",
        "int16",
        "int32",
        "int64",
        "string",
        "bool",
        "float16",
        "double",
        "uint32",
        "uint64",
        "complex64",
        "complex128",
        "bfloat16",
    )
}
68
69
@dataclass
class IoMapEntry:
    """Entry to map the output index from a producer step to the input index of a consumer step."""

    # Optional producer.
    #   - Uses the Step if one is provided.
    #   - If a str with a previous Step name is provided the PrePostProcessor will find the relevant Step.
    #   - If neither is provided (None) the producer is inferred to be the immediately previous Step
    #     in the pipeline.
    # None is the default, so it is included in the annotation.
    producer: Union["Step", str, None] = None
    # Output index from the producer step.
    producer_idx: int = 0
    # Input index of the consumer step.
    consumer_idx: int = 0
83
84
@dataclass
class IOEntryValuePreserver:
    """
    Used to allow an output value to have multiple consumers,
    which is only possible when IoMapEntry is used to create those additional connections.

    Generally, a connection consumes an output and an input, then the output is removed from the graph.
    This class enables one-to-many connections by making the other consumers share the same output.

    How this class works:
        1. When the IoMapEntry is created, an instance of this class is created at the same time.
        2. It records the producer and consumer steps, and the output index of the producer step.
           When the producer step is running, this IOEntryValuePreserver is activated and starts to
           preserve the output.
        3. When a graph merge happens, this class checks whether the output is still in the graph and,
           if not, adds the output.
        4. When the consumer step is running, this class is deactivated and the output is removed from
           preserved_list.
    """

    # Producer and consumer steps, given as a Step or a step name. Both default to None,
    # so None is included in the annotations.
    producer: Union["Step", str, None] = None
    consumer: Union["Step", str, None] = None
    # Output index from the producer step.
    producer_idx: int = 0
    # Whether the output is currently being preserved.
    is_active: bool = False
    # Name of the preserved output value; None until it is set.
    output: Union[str, None] = None
109
110
def sanitize_output_names(graph: onnx.GraphProto):
    """
    Convert any usage of invalid characters like '/' and ';' in value names to '_'
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we can
    leave the internals of the graph intact to minimize changes.

    Args:
        graph: Graph to check and update any invalid names. It is modified in place.

    Returns:
        The same `graph` object, whether or not any output had to be renamed.
    """

    # Map each invalid output name to its sanitized replacement.
    renames = {
        o.name: o.name.replace("/", "_").replace(";", "_")
        for o in graph.output
        if "/" in o.name or ";" in o.name
    }

    for o in graph.output:
        new_name = renames.get(o.name)
        if new_name is None:
            # Name is already valid.
            continue

        # Add Identity node to rename the output, and update the name in graph.output.
        # The internal value keeps its original name so the rest of the graph is untouched.
        rename = onnx.helper.make_node("Identity", [o.name], [new_name], f"Rename {o.name}")
        graph.node.append(rename)
        o.name = new_name

    # Always hand back the graph so both the "nothing to do" and "renamed" paths return the same thing.
    return graph
140