microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
edgchen1/fix_ci

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/tools/pre_post_processing/utils.py

112 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import onnx
5
6from dataclasses import dataclass
7from typing import List, Union
8
9
def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
    """
    Build a tensor ValueInfoProto that can be added to a model as a new input.

    Args:
        name: Name of the value. It must not clash with any name already used in the model being updated.
        data_type: Element type as an onnx.TensorProto enum value, e.g. onnx.TensorProto.FLOAT or
                   onnx.TensorProto.UINT8.
        shape: One entry per dimension. An int gives a fixed dimension size and a str names a symbolic
               dimension. For example, ['batch_size', 256, 256] describes a rank 3 tensor whose first
               dimension is the symbolic 'batch_size'.

    Returns:
        An onnx.ValueInfoProto describing the named tensor, suitable for use as a new model input.
    """
    return onnx.helper.make_value_info(
        name,
        onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape),
    )
26
27
def create_custom_op_checker_context(onnx_opset: int):
    """
    Build an ONNX checker context that also knows about the ort-extensions custom op domain
    ("com.microsoft.extensions"), so nodes from that domain don't cause failure when running
    onnx.checker.check_graph.

    Args:
        onnx_opset: Opset version to register for the default ONNX domain.

    Returns:
        The configured ONNX checker context.
    """
    # Default ONNX domain at the requested opset, plus version 1 of the ort-extensions custom op domain.
    domain_versions = {"": onnx_opset, "com.microsoft.extensions": 1}

    checker_context = onnx.checker.C.CheckerContext()
    # Use the same IR version as onnx's default checker context.
    checker_context.ir_version = onnx.checker.DEFAULT_CONTEXT.ir_version
    checker_context.opset_imports = domain_versions
    return checker_context
45
46
# The ONNX graph parser has its own spelling for tensor element type names:
# https://github.com/onnx/onnx/blob/604af9cb28f63a6b9924237dcb91530649233db9/onnx/defs/parser.h#L72
# Every parser name used here is the lower-cased onnx.TensorProto enum member name. The mapping
# (TensorProto data type int -> parser type name) is therefore built from the member names below, in this order.
TENSOR_TYPE_TO_ONNX_TYPE = {
    int(getattr(onnx.TensorProto, enum_name)): enum_name.lower()
    for enum_name in (
        "FLOAT",
        "UINT8",
        "INT8",
        "UINT16",
        "INT16",
        "INT32",
        "INT64",
        "STRING",
        "BOOL",
        "FLOAT16",
        "DOUBLE",
        "UINT32",
        "UINT64",
        "COMPLEX64",
        "COMPLEX128",
        "BFLOAT16",
    )
}
67
68
@dataclass
class IoMapEntry:
    """Entry to map the output index from a producer step to the input index of a consumer step."""

    # Optional producer of the value:
    #   - a Step instance is used directly;
    #   - a str with the name of a previous Step is resolved to that Step by the PrePostProcessor;
    #   - None (the default) means the producer is inferred to be the immediately previous Step in the pipeline.
    # None is part of the annotation because it is the default value (no implicit Optional).
    producer: Union["Step", str, None] = None
    # Output index on the producer step.
    producer_idx: int = 0
    # Input index on the consumer step.
    consumer_idx: int = 0
82
83
def sanitize_output_names(graph: onnx.GraphProto):
    """
    Convert any usage of invalid characters '/' and ';' in graph output names to '_'.
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we
    leave the internals of the graph intact to minimize changes. Each bad output is renamed as follows:
      - an Identity node copying the original value to the new name is appended to the graph;
      - the graph.output entry is updated to the new name.

    If the sanitized name is already used by another value in the graph, or by another renamed output, a numeric
    suffix ('_1', '_2', ...) is appended so the graph never ends up with duplicate value names.

    Args:
        graph: Graph to check and update any invalid names. It is modified in place.

    Returns:
        The same graph object that was passed in, whether or not any output was renamed.
    """
    invalid_chars = ("/", ";")

    def _is_bad(value_name: str) -> bool:
        # True if the name contains any character we need to replace.
        return any(c in value_name for c in invalid_chars)

    if not any(_is_bad(o.name) for o in graph.output):
        return graph

    # Collect every value name currently defined in the graph so new names cannot collide with them.
    used_names = {i.name for i in graph.input}
    used_names.update(init.name for init in graph.initializer)
    used_names.update(o.name for o in graph.output)
    for node in graph.node:
        used_names.update(node.output)

    renames = {}
    for o in graph.output:
        old_name = o.name
        if not _is_bad(old_name):
            continue

        if old_name not in renames:
            base_name = old_name.replace("/", "_").replace(";", "_")
            new_name = base_name
            suffix = 1
            while new_name in used_names:
                new_name = f"{base_name}_{suffix}"
                suffix += 1
            used_names.add(new_name)
            renames[old_name] = new_name

            # Add Identity node to rename the output. Only one is added per distinct name, even if that name
            # appears more than once in graph.output, so a value is never produced twice.
            rename = onnx.helper.make_node("Identity", [old_name], [new_name], f"Rename {old_name}")
            graph.node.append(rename)

        o.name = renames[old_name]

    return graph
113