microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
367f59c6fad2e820b9b1bb5807065c5d6c3886d0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/onnxprocess/_session.py

355lines · modecode

1import copy
2import onnx
3import torch
4import warnings
5import numpy as np
6from onnx import helper, mapping
7from collections import namedtuple
8from .._ortapi2 import OrtPyFunction
9from ._builder import is_path as _is_path
10from ._onnx_ops import ONNXElementContainer, make_model_ex
11from ._tensor import tensor_from_onnx, tensor_from_torch, tensor_set_session
12
13
def _is_numpy_object(x):
    """Return True when *x* is a NumPy array or a NumPy scalar."""
    numpy_kinds = (np.ndarray, np.generic)
    return isinstance(x, numpy_kinds)
16
17
def _is_numpy_string_type(arr):
    """Return True when the dtype kind of *arr* is unicode ('U') or bytes ('S')."""
    kind = arr.dtype.kind
    return kind == 'U' or kind == 'S'
20
21
def _is_string_type(x):
    """Return True when *x*, viewed as a NumPy value, has a string dtype.

    Values that are not already NumPy arrays/scalars are converted with
    ``np.array`` before their dtype kind is inspected.
    """
    candidate = x if _is_numpy_object(x) else np.array(x)
    return _is_numpy_string_type(candidate)
26
27
class ONNXModelUtils:
    """Helpers for renaming, unfolding, ordering and packaging ONNX graph nodes."""

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` of every element with ``"<prefix_name>_"``.

        :param iterables: objects exposing a writable ``name`` attribute.
        :param prefix_name: the prefix put in front of each name.
        :param inplace: rename the given objects themselves when true,
            otherwise rename deep copies and leave the originals untouched.
        :return: the collection holding the renamed objects.
        """
        if inplace:
            renamed = iterables
        else:
            renamed = [copy.deepcopy(item) for item in iterables]
        for item in renamed:
            item.name = "{}_{}".format(prefix_name, item.name)
        return renamed

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Copy the nodes of *graph* with every tensor name prefixed.

        Renamed copies of the graph's initializers and value_info entries are
        appended to *graph_or_container*.

        :param graph: the graph whose nodes are renamed.
        :param prefix: the name prefix, must not be None.
        :param graph_or_container: receiver of the renamed initializers and
            value_info entries.
        :return: the list of renamed node copies.
        """
        def prefixed(tensor_names, prefix_name):
            # Empty names are kept empty rather than prefixed.
            return ["{}_{}".format(prefix_name, nm) if nm else '' for nm in tensor_names]

        def io_rename(node, prefix_name, idx):
            clone = copy.deepcopy(node)
            if not node.name:
                # Unnamed nodes get a name built from their position.
                clone.name = "{}_op{}".format(prefix_name, idx)

            new_inputs = prefixed(node.input, prefix_name)
            new_outputs = prefixed(node.output, prefix_name)
            del clone.input[:]
            clone.input.extend(new_inputs)
            del clone.output[:]
            clone.output.extend(new_outputs)
            return clone

        assert prefix is not None, 'The graph prefix could not be None'
        renamed_initializers = cls._rename_iter(graph.initializer, prefix)
        graph_or_container.initializer.extend(renamed_initializers)
        renamed_value_info = cls._rename_iter(graph.value_info, prefix)
        graph_or_container.value_info.extend(renamed_value_info)
        return [io_rename(nd, prefix, position) for position, nd in enumerate(graph.node)]

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Apply the prefix renaming to the graph held by a node's 'body' attribute.

        Nodes without an attribute named 'body' are returned unchanged.

        :param node: the node to process; its attribute list is rewritten in place.
        :param prefix: the name prefix applied inside the body graph.
        :return: the same *node* object.
        """
        has_body = any(attr.name == 'body' for attr in node.attribute)
        if not has_body:
            return node

        def _process_attr(attr, prefix_name):
            if attr.name != 'body':
                return attr

            new_attr = copy.deepcopy(attr)
            # value_info and node are rebuilt from the renamed graph; the copied
            # initializer list is not cleared, so renamed ones are appended to it.
            del new_attr.g.value_info[:]
            del new_attr.g.node[:]
            new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
            cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
            cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
            return new_attr

        rewritten = [_process_attr(attr_, prefix) for attr_ in node.attribute]
        del node.attribute[:]
        node.attribute.extend(rewritten)
        return node

    @classmethod
    def unfold_model_node(cls, container: ONNXElementContainer):
        """Replace every node carrying a ``model`` attribute by that model's nodes.

        The inlined nodes are renamed with the owning node's name as prefix and
        the model's opset imports are merged into the outermost container.

        :param container: the container whose nodes are unfolded.
        :return: the plain nodes followed by the inlined (renamed) nodes.
        """
        root = container
        while root.parent is not None:  # only one opset_import in the model.
            root = root.parent

        nodes = container.nodes
        model_nodes = {}
        for nd_ in nodes:
            if hasattr(nd_, 'model'):
                model_nodes[nd_.name] = nd_
        onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]

        for model_node in model_nodes.values():
            renamed = cls._rename_graph(model_node.model.graph, model_node.name, container)
            onnx_nodes.extend(cls._process_node_body(nd_, model_node.name) for nd_ in renamed)

            opset_pairs = [(opset_.domain, opset_.version)
                           for opset_ in model_node.model.opset_import]
            root.node_domain_version_pair_sets.update(opset_pairs)
        return onnx_nodes

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Order *nodes* so that every tensor is produced before it is consumed.

        Graph inputs and the container's initializers are modelled as the
        outputs of a synthetic 'placeholder' node; the walk starts from it and
        from every 'Constant' node. Only nodes reachable from those start
        points are returned.

        :param container: object exposing ``initializers``.
        :param nodes: list of nodes with ``name``, ``op_type``, ``input`` and ``output``.
        :param inputs: graph inputs, each exposing ``name``.
        :param outputs: graph outputs, each exposing ``name``.
        :return: the sorted list of nodes (the placeholder is excluded).
        :raises RuntimeError: when a consumed tensor has no producer, or when
            a cycle is detected.
        """
        DynNode = namedtuple('DynNode', ['name', 'output'])
        source_tensors = [nm_.name for nm_ in inputs] + \
            [it_.name for it_ in container.initializers]
        placeholder = DynNode(name='placeholder', output=source_tensors)
        constant_nodes = [nd_ for nd_ in nodes if nd_.op_type == 'Constant']
        input_nodes = [placeholder] + constant_nodes

        # tensor name -> node producing it
        op_output_map = {}
        for nd_ in nodes + input_nodes:
            for tensor_name in nd_.output:
                op_output_map[tensor_name] = nd_

        # producer node name -> consumer nodes
        edges = {}
        for op in nodes:
            for tensor_name in op.input:
                if tensor_name == '':
                    continue
                if tensor_name not in op_output_map:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(
                            op.name, tensor_name)) from None
                producer = op_output_map[tensor_name]
                edges.setdefault(producer.name, []).append(op)

        # Nodes producing graph outputs must own an (possibly empty) edge list.
        for y_ in outputs:
            edges.setdefault(op_output_map[y_.name].name, [])

        visited = set()
        in_progress = set()
        finish_order = []

        def visit(node):
            name = node.name
            if name in visited:
                return

            if name in in_progress:
                raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(name))

            in_progress.add(name)
            if name in edges:  # if the node's output is not in the Graph output.
                assert name != '', 'this topological-sort depends on the unique node name.'
                for successor in edges[name]:
                    visit(successor)

            in_progress.remove(name)
            visited.add(name)
            if node is not placeholder:
                finish_order.append(node)

        for nd_ in input_nodes:
            visit(nd_)

        # Nodes were recorded as they finished; the reverse is the execution order.
        finish_order.reverse()
        return finish_order

    @staticmethod
    def value_info_from_numpy(name, value):
        """Build a tensor value_info named *name* from the dtype and shape of *value*."""
        if _is_numpy_string_type(value):
            elem_type = onnx.onnx_pb.TensorProto.STRING
        else:
            elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
        return helper.make_tensor_value_info(name, elem_type, shape=value.shape)

    @staticmethod
    def model_from_ops(container, ops, ts_from, ts_to):
        """Wrap *ops* into a model named 'dyngraph' with the given inputs and outputs.

        :param container: supplies the initializers, opset pairs and target opset.
        :param ops: the nodes forming the graph.
        :param ts_from: value_info objects used as graph inputs.
        :param ts_to: value_info objects used as graph outputs.
        :return: the model built by ``make_model_ex``.
        :raises AssertionError: when the inputs/outputs derived from *ops* do
            not match *ts_from* / *ts_to*.
        """
        iz_set = set(iz_.name for iz_ in container.initializer)
        consumed = []
        produced = []
        iz_needed = set()
        for op in ops:
            for tensor_name in op.input:
                if tensor_name in iz_set:
                    iz_needed.add(tensor_name)
                elif tensor_name != '':
                    consumed.append(tensor_name)
            produced.extend(op.output)

        consumed_set = set(consumed)
        produced_set = set(produced)
        # Tensors both produced and consumed are internal to the graph.
        internal = consumed_set.intersection(produced_set)
        assert consumed_set.difference(internal) == set(ts_.name for ts_ in ts_from), \
            "The input list is different from the calculated from the op nodes"
        assert produced_set.difference(internal) == set(ts_.name for ts_ in ts_to), \
            "The output list is different from the calculated from the op nodes"

        final_iz = [iz_ for iz_ in container.initializers if iz_.name in iz_needed]
        graph = helper.make_graph(ops, 'dyngraph', ts_from, ts_to, final_iz)
        return make_model_ex(graph,
                             container.node_domain_version_pair_sets,
                             container.target_opset)
182
183
class ONNXTraceSession:
    """Record traced tensor operations in a container and export them as an ONNX model.

    A session is normally created with :meth:`trace_for_onnx` and used in a
    ``with`` statement, so that :meth:`__exit__` unregisters it again.
    """

    # Stack of the sessions started by trace_for_onnx and not yet exited.
    # It is a class attribute, hence shared by every session.
    activated_sessions = []

    def __init__(self, target_opset):
        """
        :param target_opset: The ONNX opset version handed to the element container.
        """
        self.container = ONNXElementContainer(target_opset)
        self.inputs = []
        self.outputs = []

    def __enter__(self):
        """Return the session itself; it must be the most recently started one."""
        assert len(self.activated_sessions) > 0 and self.activated_sessions[-1] is self, "trace not started?"
        return self

    # need this exit to close the session
    def __exit__(self, exec_type, exec_value, exec_tb):
        """Clear the active tensor session and remove this session from the stack."""
        tensor_set_session(None)
        # The pop must not live inside the assert expression: asserts are
        # removed under ``python -O`` and the stack would never shrink.
        popped = self.activated_sessions.pop()
        assert popped is self

    @classmethod
    def trace_for_onnx(cls, *inputs, names=None, target_opset=11) -> 'ONNXTraceSession':
        """
        Starting the trace all tensor computation for ONNX graph generation.
        :param inputs: the input tensor, could a torch.Tensor or a numpy ndarray.
            String values are converted to numpy arrays, any other value is
            converted with torch.tensor.
        :param names: The input names the ONNX graph. The sequence is copied,
            the caller's object is never modified. Missing or empty names
            default to "input<index>".
        :param target_opset: The ONNX model opset_version
        :return: A tracing session object, in most case, it should be used in the with statement.
        """
        session = ONNXTraceSession(target_opset)
        session.activated_sessions.append(session)
        tensor_set_session(session)

        def _to_traceable(x):
            # torch tensors are checked first: probing them for a string dtype
            # would convert them to numpy, which is wasteful and raises for
            # tensors that require grad or do not live on the CPU.
            if isinstance(x, torch.Tensor):
                return x
            if _is_string_type(x):
                return np.array(x)
            if isinstance(x, (np.ndarray, np.generic)):
                return x
            return torch.tensor(x)

        try:
            converted = [_to_traceable(x) for x in inputs]
            itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
                        else tensor_from_onnx(i_, None, None) for i_ in converted]
        except Exception:
            # Undo the registration (same steps as __exit__); the caller never
            # receives the session and so cannot close it.
            tensor_set_session(None)
            session.activated_sessions.pop()
            raise

        # Work on a private copy so a caller-owned list is not altered and a
        # tuple of names is accepted as well.
        names = [] if names is None else list(names)
        if len(inputs) != len(names):
            warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
            names.extend([''] * (len(inputs) - len(names)))
        # After the padding there is at least one name slot per input tensor.
        for idx_, tensor in enumerate(itensors):
            tensor.name = names[idx_] if names[idx_] else "input{}".format(idx_)
        session.inputs = itensors
        return session

    def runops(self, ts_from, ts_to):
        """Run the most recently recorded nodes that compute *ts_to* from *ts_from*.

        The recorded nodes are walked backwards from the last one until every
        requested output has been found and every consumed tensor is either an
        input, an initializer, or produced by a collected node.

        :param ts_from: the input tensors; each exposes ``name`` and ``numpy()``.
        :param ts_to: the output tensors; each exposes ``name`` and ``numpy()``.
        :return: a tuple ``(results, model)`` where results is a list or tuple.
        :raises AssertionError: when the outputs or inputs cannot be resolved
            from the recorded nodes.
        """
        nodes = self.container.nodes
        inset = set(ts_.name for ts_ in ts_from)
        inset.update(iz_.name for iz_ in self.container.initializer)
        outset = set(ts_.name for ts_ in ts_to)
        missing_ts_set = set()
        node_num = len(nodes) - 1
        while node_num >= 0:
            node = nodes[node_num]
            for ot_ in node.output:
                # A tensor can be a requested output and an input of a later
                # node at the same time, so both sets are updated.
                missing_ts_set.discard(ot_)
                outset.discard(ot_)
            for it_ in node.input:
                # '' marks an omitted optional input, nothing produces it.
                if it_ and it_ not in inset:
                    missing_ts_set.add(it_)
            # Keep walking while a requested output is still to be found.
            if not missing_ts_set and not outset:
                break
            node_num -= 1

        assert len(outset) == 0, "Some output cannot be in the node list."
        assert len(missing_ts_set) == 0, "Some input cannot be in the node list."
        collected_nodes = nodes[node_num:]
        vi_input = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                    for ts_ in ts_from]
        vi_output = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                     for ts_ in ts_to]
        oxml = ONNXModelUtils.model_from_ops(self.container,
                                             collected_nodes,
                                             vi_input,
                                             vi_output)
        result = None
        try:
            oxfunc = OrtPyFunction.from_model(oxml)
            result = oxfunc(*[ts_.numpy() for ts_ in ts_from])
        finally:
            # No result was produced: keep the model on disk for debugging.
            if result is None:
                onnx.save_model(oxml, 'mt_debmodel.onnx')

        return result if isinstance(result, (list, tuple)) else [result], oxml

    def get_inputs(self):
        """Return the list of traced input tensors."""
        return self.inputs

    def stack_container(self):
        """Make a new child container of the current one the active container."""
        assert self.container is not None, "Stacked container must be in another one."
        sub_container = ONNXElementContainer(self.container.target_opset, self.container)
        self.container = sub_container
        return self.container

    def pop_container(self):
        """Make the parent of the current container active again and return it."""
        assert self.container.parent is not None, "Cannot pop the root container."
        self.container = self.container.parent
        return self.container

    @staticmethod
    def build_graph(container, ts_inputs, ts_outputs, graph_name=None):
        """Build a graph from the nodes recorded in *container*.

        Note that matching Constant nodes are removed from ``container.nodes``.

        :param container: the element container holding nodes, initializers
            and value_info.
        :param ts_inputs: the tensors used as graph inputs.
        :param ts_outputs: the tensors used as graph outputs.
        :param graph_name: the graph name; a unique one is requested from the
            container when it is empty.
        :return: the graph created by ``helper.make_graph``.
        """
        # some constant ops are created to simulate the tensors generated from the runtime in the loop,
        # so we need to remove the node here
        input_names = {it_.name: None for it_ in ts_inputs}
        to_del = [idx_ for idx_, nd_ in enumerate(container.nodes)
                  if nd_.op_type == 'Constant' and list(nd_.output)[0] in input_names]

        # Pop from the back so the remaining indices stay valid.
        for idx_ in reversed(to_del):
            container.nodes.pop(idx_)

        graph_name = container.get_unique_operator_name('subg') if not graph_name else graph_name
        nodes = ONNXModelUtils.unfold_model_node(container)
        nodes = ONNXModelUtils.topological_sort(container, nodes, ts_inputs, ts_outputs)

        # Prefer an already recorded value_info over a freshly built one.
        for vi_ in container.value_info:
            if vi_.name in input_names:
                input_names[vi_.name] = vi_

        inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
                  if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
        outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
                                                 so.get_shape()) for so in ts_outputs]

        graph = helper.make_graph(nodes, graph_name, inputs,
                                  outputs, container.initializers)
        return graph

    def build_model(self, model_name=None, doc_string=None) -> onnx.ModelProto:
        """Build the model from the current container, ``self.inputs`` and ``self.outputs``.

        :param model_name: used as the graph name, 'tcm' when None.
        :param doc_string: The doc string for the model, '' when None.
        :return: the model created by ``make_model_ex``.
        """
        model_name = 'tcm' if model_name is None else model_name
        doc_string = '' if doc_string is None else doc_string
        container = self.container
        graph = self.build_graph(container, self.inputs, self.outputs, model_name)
        onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
                                   container.target_opset, doc_string=doc_string)
        return onnx_model

    def save_as_onnx(self, file_like_or_path, outputs, model_name=None, doc_string=None):
        """
        Build the ONNX model from the traced computation graph.
        :param file_like_or_path: an io.BytesIO like object or a file path,
            nothing is written when it is None.
        :param outputs: the output tensor to be specified as the ONNX graph output,
            Could be a list or tuple if there are multiple output tensors.
            It is only used when the session has no outputs yet.
        :param model_name: The ONNX model internal name
        :param doc_string: The doc string for the model
        :return: A ONNX ModelProto object.
        :raises RuntimeError: when neither the session nor *outputs* provides an output.
        """
        if len(self.outputs) == 0 and outputs is None:
            raise RuntimeError("No output of the graph specified.")

        if len(self.outputs) == 0:
            self.outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]

        m = self.build_model(model_name, doc_string)

        if file_like_or_path is not None:
            if _is_path(file_like_or_path):
                with open(file_like_or_path, 'wb') as f:
                    f.write(m.SerializeToString())
            else:
                # Caller-owned stream: write and flush, but leave it open.
                f = file_like_or_path
                f.write(m.SerializeToString())
                f.flush()

        return m
356