microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2cf9bab611e9ad563822dee69c44e23bd017fadc

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/pnp/_utils.py

302 lines · mode: code

1import copy
2import onnx
3from onnx import helper, numpy_helper
4from collections import namedtuple
5
6
class _Container:
    """Accumulator for graph pieces and opset versions gathered while merging models."""

    def __init__(self):
        # Enclosing container; _unfold_model_node walks this chain up to the
        # outermost container (the one whose parent is None).
        self.parent = None
        self.initializer = []
        self.value_info = []
        self.nodes = []
        # Maps opset domain -> version; recording a domain again overwrites the
        # previously stored version.
        self.node_domain_version_pair_sets = {}

    def add_model(self, oxml):
        """Append the graph contents of ``oxml`` and record its opset imports.

        :param oxml: a ModelProto-like object exposing ``graph`` and ``opset_import``.
        :return: ``self``, so construction and loading can be chained.
        """
        graph = oxml.graph
        for bucket, items in ((self.initializer, graph.initializer),
                              (self.value_info, graph.value_info),
                              (self.nodes, graph.node)):
            bucket.extend(items)
        for opset_ in oxml.opset_import:
            self.node_domain_version_pair_sets[opset_.domain] = opset_.version
        return self
22
23
class ONNXModelUtils:
    """Helpers to inline, chain and reorder ONNX graphs.

    Names coming from an embedded graph are qualified with a prefix (see
    ``merge_name``) so that several graphs can share one model without clashes.
    """

    @staticmethod
    def merge_name(prefix, name):
        """Return ``name`` qualified by ``prefix``, as ``"<prefix>_<name>"``."""
        return "{}_{}".format(prefix, name)

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` field of every item in ``iterables``.

        :param iterables: objects (typically protobuf messages) with a ``name`` field.
        :param prefix_name: prefix handed to ``merge_name``.
        :param inplace: when False (default) the items are deep-copied first and
            the originals are left untouched.
        :return: ``iterables`` itself when ``inplace`` is True, else a new list.
        """
        new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
        for iz_ in new_iz:
            iz_.name = ONNXModelUtils.merge_name(prefix_name, iz_.name)
        return new_iz

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Return prefixed copies of ``graph``'s nodes and register its renamed tensors.

        Each node is deep-copied. Its name (``op<idx>`` when it has none) and
        every non-empty input/output name get ``prefix``. Empty names mark
        omitted optional inputs/outputs and stay ''. Renamed copies of the
        graph's initializers and value_info are appended to
        ``graph_or_container`` (a ``_Container`` or a ``GraphProto``).
        ``graph`` itself is not modified.
        """
        assert prefix is not None, 'The graph prefix could not be None'

        def _qualify(tensor_name):
            return cls.merge_name(prefix, tensor_name) if tensor_name else ''

        def io_rename(node, idx):
            new_node = copy.deepcopy(node)
            new_node.name = cls.merge_name(prefix, node.name if node.name else "op{}".format(idx))
            del new_node.input[:]
            new_node.input.extend(_qualify(nm_) for nm_ in node.input)
            del new_node.output[:]
            new_node.output.extend(_qualify(nm_) for nm_ in node.output)
            return new_node

        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
        return [io_rename(nd_, idx_) for idx_, nd_ in enumerate(graph.node)]

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Prefix the names inside every graph-valued attribute of ``node``.

        ``_rename_graph`` prefixes a node's own inputs and outputs. References
        made from inside a subgraph (such as a ``body``) to outer-scope tensors
        must receive the same prefix. The whole subgraph is therefore renamed
        consistently: nodes, initializers, value_info, inputs and outputs.
        Subgraphs nested inside it are processed recursively.
        ``node.attribute`` is rewritten in place.

        :return: ``node`` itself.
        """
        if not any(attr.HasField('g') for attr in node.attribute):
            return node

        def _process_attr(attr):
            if not attr.HasField('g'):
                return attr
            new_attr = copy.deepcopy(attr)
            # _rename_graph appends renamed copies of these fields into new_attr.g.
            # Clear the copied originals first so nothing is duplicated under
            # both the old and the new name.
            del new_attr.g.initializer[:]
            del new_attr.g.value_info[:]
            del new_attr.g.node[:]
            renamed = cls._rename_graph(attr.g, prefix, new_attr.g)
            new_attr.g.node.extend([cls._process_node_body(nd_, prefix) for nd_ in renamed])
            cls._rename_iter(new_attr.g.input, prefix, inplace=True)
            cls._rename_iter(new_attr.g.output, prefix, inplace=True)
            return new_attr

        attr_list = [_process_attr(attr_) for attr_ in node.attribute]
        del node.attribute[:]
        node.attribute.extend(attr_list)
        return node

    @staticmethod
    def get_model_name_abbr(node):
        """Return ``'m_' + <text after the last '_' in node.name>``."""
        no = node.name.split('_')[-1]
        return 'm_' + no

    @staticmethod
    def get_model_id_from_arg0(nodes, node):
        """Return the scalar model id held by the Constant node feeding ``node.input[0]``.

        :param nodes: candidate producer nodes to search.
        :param node: the consumer node (a ``_ModelFunctionCall``).
        :raises AssertionError: unless exactly one Constant in ``nodes`` produces
            ``node.input[0]``.
        """
        arg0_name = node.input[0]
        c_node = [n_ for n_ in nodes
                  if n_.op_type == 'Constant' and n_.output[0] == arg0_name]
        assert len(c_node) == 1, \
            "internal error, expected exactly one Constant node producing '{}', found {}.".format(
                arg0_name, len(c_node))
        tensor_value = onnx.helper.get_attribute_value(c_node[0].attribute[0])
        return numpy_helper.to_array(tensor_value).item()

    @classmethod
    def _unfold_model_node(cls, container, name, model, io_mapping=None):
        """Return ``model``'s nodes renamed with the ``name`` prefix.

        Renamed initializers and value_info of ``model`` are appended to
        ``container``. ``model``'s opsets are recorded on the outermost
        container reachable through ``parent`` links. ``io_mapping`` is
        accepted for signature compatibility but is not used here.
        """
        top_container = container
        while top_container.parent is not None:  # opsets are tracked on the outermost container only
            top_container = top_container.parent

        renamed_nodes = cls._rename_graph(model.graph, name, container)
        onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]

        top_container.node_domain_version_pair_sets.update(
            (opset_.domain, opset_.version) for opset_ in model.opset_import)
        return onnx_nodes

    @classmethod
    def unfold_model(cls, oxml, id_to_model, io_mapping=None):
        """Inline every ``_ModelFunctionCall`` node of ``oxml`` with the model it calls.

        Input 0 of a call node is produced by a Constant holding the model id.
        The remaining inputs are the arguments of the called model. Each call
        is replaced, in this order, by:

        * Identity nodes binding the arguments to the prefixed model inputs;
        * the prefixed nodes of the called model;
        * Identity nodes binding the prefixed model outputs to the call's
          outputs. Model outputs beyond the call's output count are left unbound.

        :param oxml: the ModelProto to unfold. It is deep-copied, not modified.
        :param id_to_model: mapping model id -> ModelProto to inline.
        :param io_mapping: optional callable
            ``(input_nodes, output_nodes) -> (input_nodes, output_nodes)`` that
            may rewrite the binding Identity nodes.
        :return: the unfolded ModelProto.
        :raises RuntimeError: if a model id is missing from ``id_to_model``.
        """
        container = _Container().add_model(oxml)
        graph_nodes = list(oxml.graph.node)
        nodes = []
        for _node in graph_nodes:
            if _node.op_type != '_ModelFunctionCall':
                nodes.append(_node)
                continue

            model_id = cls.get_model_id_from_arg0(graph_nodes, _node)
            if model_id not in id_to_model:
                raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))

            prefix = cls.get_model_name_abbr(_node)
            nest_model = id_to_model[model_id]

            input_mapping = []
            for idx_, in_ in enumerate(nest_model.graph.input):
                _renamed_in = cls.merge_name(prefix, in_.name)
                input_mapping.append(onnx.helper.make_node(
                    'Identity',
                    [_node.input[idx_ + 1]],  # the first arg is model id, skip it.
                    [_renamed_in],
                    name='i_' + _renamed_in))
            nds = cls._unfold_model_node(container, prefix, nest_model, io_mapping)
            output_mapping = []
            # zip() stops at the shorter side: extra nested outputs stay unbound.
            for out_, call_out_ in zip(nest_model.graph.output, _node.output):
                _renamed_out = cls.merge_name(prefix, out_.name)
                output_mapping.append(onnx.helper.make_node(
                    'Identity', [_renamed_out], [call_out_], name='o_' + _renamed_out))
            if io_mapping is not None:
                assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
                input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
            # attention: the order of the list operations is important, which avoids the topological sort.
            nodes.extend(input_mapping)
            nodes.extend(nds)
            nodes.extend(output_mapping)

        # Initializers exposed directly as graph outputs must survive pruning.
        intlzs = cls._remove_unused_initializers(
            nodes, container.initializer, {out_.name for out_ in oxml.graph.output})
        oxml = copy.deepcopy(oxml)
        del oxml.graph.node[:]
        oxml.graph.node.extend(nodes)
        del oxml.graph.initializer[:]
        oxml.graph.initializer.extend(intlzs)

        # _unfold_model_node records the inlined models' opsets on the container.
        # Import every domain the outer model lacks, so ops from the inlined
        # models stay resolvable. '' and 'ai.onnx' both denote the default domain.
        imported = {opset_.domain or 'ai.onnx' for opset_ in oxml.opset_import}
        for domain, version in container.node_domain_version_pair_sets.items():
            if (domain or 'ai.onnx') not in imported:
                imported.add(domain or 'ai.onnx')
                oxml.opset_import.append(helper.make_opsetid(domain, version))
        return oxml

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Return ``nodes`` ordered so that every producer precedes its consumers.

        :param container: provides ``initializer``. Initializer names, like the
            graph ``inputs``, count as already-available tensors.
        :param nodes: the nodes to sort. Node names must be unique, because the
            traversal tracks nodes by name.
        :param inputs: graph inputs (objects with a ``name``).
        :param outputs: graph outputs (objects with a ``name``). Each one must be
            produced by a node, a graph input or an initializer.
        :raises RuntimeError: when a node consumes a tensor nothing produces, or
            when the graph contains a cycle.
        """
        nodes = list(nodes)
        DynNode = namedtuple('DynNode', ['name', 'output'])
        placeholder = DynNode(name='placeholder',
                              output=[nm_.name for nm_ in inputs] +
                                     [it_.name for it_ in container.initializer])
        # DFS roots: a pseudo-node "producing" graph inputs and initializers, plus
        # every node whose inputs are all empty (Constant and other source ops).
        # No edge reaches those nodes, so they would otherwise be dropped.
        input_nodes = [placeholder] + [nd_ for nd_ in nodes if not any(nd_.input)]

        op_output_map = {}
        for nd_ in nodes + input_nodes:
            for ky_ in nd_.output:
                op_output_map[ky_] = nd_

        edges = {}
        for op in nodes:
            for x in op.input:
                if x == '':
                    continue
                try:
                    predecessor = op_output_map[x]
                except KeyError:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
                edges.setdefault(predecessor.name, []).append(op)

        for y_ in outputs:
            edges.setdefault(op_output_map[y_.name].name, [])

        visited = set()
        unfinished_nodes = set()
        finished = []  # DFS post-order; reversed at the end

        def _enter(node):
            # Put ``node`` on the current DFS path and return its stack frame.
            if node.name in edges:
                assert node.name != '', 'this topological-sort depends on the unique node name.'
            unfinished_nodes.add(node.name)
            return node, iter(edges.get(node.name, ()))

        # Iterative DFS: deep graphs would overflow Python's recursion limit.
        for root in input_nodes:
            if root.name in visited:
                continue
            stack = [_enter(root)]
            while stack:
                node, successors = stack[-1]
                for successor in successors:
                    if successor.name in visited:
                        continue
                    if successor.name in unfinished_nodes:
                        raise RuntimeError(
                            "ONNX Graph is not a DAG, the cycle is found at {}".format(successor.name))
                    stack.append(_enter(successor))
                    break
                else:  # all successors done: ``node`` is finished
                    stack.pop()
                    unfinished_nodes.remove(node.name)
                    visited.add(node.name)
                    if node is not placeholder:
                        finished.append(node)

        finished.reverse()
        return finished

    @staticmethod
    def _collect_input_names(nodes):
        """Return every tensor name consumed by ``nodes``.

        Inputs of nodes nested in graph-valued attributes (e.g. a ``body``) are
        included, at any depth.
        """
        names = set()
        pending = list(nodes)
        while pending:
            nd_ = pending.pop()
            names.update(nd_.input)
            for attr in nd_.attribute:
                if attr.HasField('g'):
                    pending.extend(attr.g.node)
                for sub_graph in attr.graphs:
                    pending.extend(sub_graph.node)
        return names

    @staticmethod
    def _remove_unused_initializers(nodes, initializers, reserved_names=None):
        """Return the initializers consumed by ``nodes`` (or their subgraphs) or listed in ``reserved_names``.

        The relative order of ``initializers`` is preserved.
        """
        if reserved_names is None:
            reserved_names = set()
        used = ONNXModelUtils._collect_input_names(nodes)
        return [intlz_ for intlz_ in initializers
                if intlz_.name in used or intlz_.name in reserved_names]

    @classmethod
    def join_models(cls, *models, io_mapping=None):
        """Chain ``models`` into one model, feeding each model into the next.

        The names of model ``k`` (1-based) get the prefix ``g<k>``. The joined
        graph takes the inputs of the first model and exposes the outputs of the
        last one.

        By default, input ``i`` of model ``k+1`` is wired to output ``i`` of
        model ``k``. ``io_mapping``, when given, receives one
        ``ModelPort(input, output)`` per model (lists of prefixed names). It
        returns a dict ``{downstream input name: tensor name to use instead}``.
        Inputs it does not map fall back to the positional default.

        Opset imports are merged per domain. The default domain (''/'ai.onnx')
        must have the same version in every model.
        """
        # generate the prefix id for the embedding graph to avoid the name conflict
        mdl_prefix = ["g{}".format(_i + 1) for _i in range(len(models))]

        inputs = cls._rename_iter(models[0].graph.input, mdl_prefix[0])
        outputs = cls._rename_iter(models[-1].graph.output, mdl_prefix[-1])

        port_mapping = {}
        if io_mapping is not None:
            assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
            ModelPort = namedtuple('ModelPort', "input output")
            ports = [ModelPort([cls.merge_name(pfx_, _x.name) for _x in mdl_.graph.input],
                               [cls.merge_name(pfx_, _y.name) for _y in mdl_.graph.output])
                     for pfx_, mdl_ in zip(mdl_prefix, models)]
            port_mapping = io_mapping(ports)
        for _idx in range(len(models) - 1):
            upstream, downstream = models[_idx], models[_idx + 1]
            for _i, _x in enumerate(downstream.graph.input):
                iname = cls.merge_name(mdl_prefix[_idx + 1], _x.name)
                if iname not in port_mapping:
                    port_mapping[iname] = cls.merge_name(mdl_prefix[_idx], upstream.graph.output[_i].name)

        # The container only gathers the *renamed* initializers/value_info. Adding
        # the raw models as well would leave stale, un-prefixed duplicates behind.
        nodes = []
        container = _Container()
        for pfx_, mdl_ in zip(mdl_prefix, models):
            nodes += cls._rename_graph(mdl_.graph, pfx_, container)

        for _n in nodes:
            if any(_i in port_mapping for _i in _n.input):
                new_input = [port_mapping[_i] if _i in port_mapping else _i for _i in _n.input]
                del _n.input[:]
                _n.input.extend(new_input)

        name = "_".join([_mdl.graph.name for _mdl in models])
        domains = set()
        _opset = []  # the default domain, once seen, is kept at index 0
        for _mdl in models:
            for _ops in _mdl.opset_import:
                domain = _ops.domain if _ops.domain else "ai.onnx"
                if domain in domains:
                    if domain == "ai.onnx":
                        assert _ops.version == _opset[0].version, \
                            f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
                else:
                    domains.add(domain)
                    if domain == "ai.onnx":
                        _opset.insert(0, _ops)
                    else:
                        _opset.append(_ops)

        # Initializers exposed directly as graph outputs must survive pruning.
        inits = cls._remove_unused_initializers(
            nodes, container.initializer, {out_.name for out_ in outputs})
        g = helper.make_graph(nodes, name, inputs, outputs,
                              initializer=inits,
                              value_info=container.value_info)

        if hasattr(helper, 'make_model_gen_version'):
            # make_model_gen_version doesn't accept the custom domain.
            m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
            m.opset_import.extend(_opset[1:])
        else:
            m = helper.make_model(g, opset_imports=_opset)
        return m
303