microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/onnxprocess/_session.py
355 lines · mode: code
| 1 | import copy |
| 2 | import onnx |
| 3 | import torch |
| 4 | import warnings |
| 5 | import numpy as np |
| 6 | from onnx import helper, mapping |
| 7 | from collections import namedtuple |
| 8 | from .._ortapi2 import OrtPyFunction |
| 9 | from ._builder import is_path as _is_path |
| 10 | from ._onnx_ops import ONNXElementContainer, make_model_ex |
| 11 | from ._tensor import tensor_from_onnx, tensor_from_torch, tensor_set_session |
| 12 | |
| 13 | |
def _is_numpy_object(x):
    """Tell whether ``x`` is a numpy array or a numpy scalar (``np.generic``)."""
    numpy_kinds = (np.ndarray, np.generic)
    return isinstance(x, numpy_kinds)
| 16 | |
| 17 | |
def _is_numpy_string_type(arr):
    """Tell whether the numpy object ``arr`` holds unicode ('U') or byte ('S') strings.

    Only the dtype kind is inspected, so strings stored with ``dtype=object`` report False.
    """
    kind = arr.dtype.kind
    return kind == 'U' or kind == 'S'
| 20 | |
| 21 | |
def _is_string_type(x):
    """Tell whether ``x`` is string data.

    Values that are not already numpy objects are wrapped with ``np.array`` before the
    dtype kind is checked.
    """
    as_numpy = x if _is_numpy_object(x) else np.array(x)
    return _is_numpy_string_type(as_numpy)
| 26 | |
| 27 | |
class ONNXModelUtils:
    """Helpers to inline, rename, sort and package the ONNX nodes collected by a trace."""

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` of every item with ``prefix_name`` + '_'.

        :param iterables: items with a writable ``name`` attribute.
        :param prefix_name: the prefix to prepend.
        :param inplace: rename the given items directly instead of deep copies of them.
        :return: the renamed items (the given iterable itself when ``inplace`` is True,
            otherwise a new list of copies).
        """
        new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
        for iz_ in new_iz:
            iz_.name = "{}_{}".format(prefix_name, iz_.name)
        return new_iz

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Return renamed copies of ``graph.node``; node and tensor names get ``prefix`` + '_'.

        Side effect: renamed copies of ``graph.initializer`` and ``graph.value_info`` are
        appended to ``graph_or_container`` so the renamed nodes can still resolve them.

        :param graph: the source graph; it is not modified.
        :param prefix: the name prefix, must not be None.
        :param graph_or_container: receives the renamed initializers / value_info.
        :return: a list of the renamed node copies, in the original order.
        """
        def io_rename(node, prefix_name, idx):
            new_node = copy.deepcopy(node)
            # Named nodes are prefixed as well: topological_sort identifies nodes by name, so
            # two inlined copies of one model (or a clash with an outer node) must not share one.
            if node.name:
                new_node.name = "{}_{}".format(prefix_name, node.name)
            else:
                new_node.name = "{}_op{}".format(prefix_name, idx)

            # An empty name marks an omitted optional input/output and must stay empty.
            del new_node.input[:]
            new_node.input.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
            del new_node.output[:]
            new_node.output.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
            return new_node

        assert prefix is not None, 'The graph prefix could not be None'
        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
        return [io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node)]

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Apply the ``prefix`` renaming inside the ``body`` subgraph attribute of ``node``.

        Nodes without a ``body`` attribute are returned unchanged. Otherwise the node's
        attributes are replaced in place, and the node is returned.
        """
        if all(attr.name != 'body' for attr in node.attribute):
            return node

        def _process_attr(attr, prefix_name):
            if attr.name != 'body':
                return attr
            new_attr = copy.deepcopy(attr)
            # _rename_graph re-adds renamed copies of the nodes, value_info and initializers,
            # so the originals are cleared first (stale unprefixed initializers would remain).
            del new_attr.g.initializer[:]
            del new_attr.g.value_info[:]
            del new_attr.g.node[:]
            new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
            cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
            cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
            return new_attr

        attr_list = [_process_attr(attr_, prefix) for attr_ in node.attribute]
        del node.attribute[:]
        node.attribute.extend(attr_list)
        return node

    @classmethod
    def unfold_model_node(cls, container: 'ONNXElementContainer'):
        """Inline the nodes that carry a ``model`` attribute into plain ONNX nodes.

        Each such node is replaced by the renamed nodes of its model's graph, prefixed with
        the node's name. The renamed initializers/value_info are appended to ``container``,
        and the model's opset imports are registered on the root container.

        :return: the non-model nodes followed by the inlined nodes (not yet sorted).
        """
        # The opset set of the root container is the one used for the model (one opset_import).
        top_container = container
        while top_container.parent is not None:
            top_container = top_container.parent

        nodes = container.nodes
        model_nodes = {node.name: node for node in nodes if hasattr(node, 'model')}
        onnx_nodes = [nd_ for nd_ in nodes if nd_.name not in model_nodes]

        for node in model_nodes.values():
            renamed_nodes = cls._rename_graph(node.model.graph, node.name, container)
            onnx_nodes.extend(cls._process_node_body(nd_, node.name) for nd_ in renamed_nodes)
            top_container.node_domain_version_pair_sets.update(
                (opset_.domain, opset_.version) for opset_ in node.model.opset_import)
        return onnx_nodes

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Order ``nodes`` so that every node follows the producers of its inputs.

        :param container: its ``initializers`` names count as available tensors.
        :param nodes: the nodes to sort; node names must be unique.
        :param inputs: graph inputs (objects with ``name``).
        :param outputs: graph outputs (objects with ``name``); each needs a producer.
        :return: a new list of the nodes reachable from the inputs, initializers and
            Constant nodes, in topological order.
        :raises RuntimeError: if a consumed tensor or a graph output has no producer, or
            the graph contains a cycle.
        """
        nodes = list(nodes)
        DynNode = namedtuple('DynNode', ['name', 'output'])
        # Pseudo node "producing" the graph inputs and initializers; it is the first DFS
        # root and is never emitted.
        placeholder = DynNode(name='placeholder',
                              output=[nm_.name for nm_ in inputs] +
                                     [it_.name for it_ in container.initializers])
        # Constant nodes consume nothing, so they are DFS roots as well.
        input_nodes = [placeholder] + [nd_ for nd_ in nodes if nd_.op_type == 'Constant']

        op_output_map = {}
        for nd_ in nodes + input_nodes:
            for ky_ in nd_.output:
                op_output_map[ky_] = nd_

        # Producer node name -> consumer nodes.
        edges = {}
        for op in nodes:
            for x in op.input:
                if x == '':
                    continue  # omitted optional input
                try:
                    predecessor = op_output_map[x]
                except KeyError:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
                edges.setdefault(predecessor.name, []).append(op)

        for y_ in outputs:
            try:
                producer = op_output_map[y_.name]
            except KeyError:
                raise RuntimeError(
                    "cannot find an operator to produce the graph output: {}".format(y_.name)) from None
            edges.setdefault(producer.name, [])

        visited = set()
        unfinished_nodes = set()
        finish_order = []

        def _enter(node):
            # Put a node on the current DFS path and return an iterator over its successors.
            if node.name in unfinished_nodes:
                raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))
            unfinished_nodes.add(node.name)
            if node.name in edges:
                assert node.name != '', 'this topological-sort depends on the unique node name.'
            return iter(edges.get(node.name, ()))

        # Iterative post-order DFS: deep graphs would exceed Python's recursion limit otherwise.
        for root in input_nodes:
            if root.name in visited:
                continue
            stack = [(root, _enter(root))]
            while stack:
                node, successors = stack[-1]
                for successor in successors:
                    if successor.name not in visited:
                        stack.append((successor, _enter(successor)))
                        break
                else:
                    stack.pop()
                    unfinished_nodes.remove(node.name)
                    visited.add(node.name)
                    if node is not placeholder:
                        finish_order.append(node)

        # Reverse post-order is a topological order.
        finish_order.reverse()
        return finish_order

    @staticmethod
    def value_info_from_numpy(name, value):
        """Build a ValueInfoProto named ``name`` with the element type and shape of ``value``.

        Numpy unicode/byte string arrays map to ``TensorProto.STRING``.
        """
        dtype = onnx.onnx_pb.TensorProto.STRING if \
            _is_numpy_string_type(value) else mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
        return helper.make_tensor_value_info(name, dtype, shape=value.shape)

    @staticmethod
    def model_from_ops(container, ops, ts_from, ts_to):
        """Wrap ``ops`` into a standalone model whose inputs are ``ts_from`` and outputs ``ts_to``.

        Only the container initializers that ``ops`` read are embedded. The asserts check that
        the declared inputs/outputs match the ones implied by ``ops``.
        """
        all_inputs = []
        all_outputs = []
        iz_needed = set()
        iz_set = set(iz_.name for iz_ in container.initializer)
        for op in ops:
            iz_needed.update(it_ for it_ in op.input if it_ in iz_set)
            all_inputs.extend(it_ for it_ in op.input if it_ != '' and it_ not in iz_set)
            all_outputs.extend(op.output)

        # Tensors produced and consumed inside ``ops`` are internal, not graph boundaries.
        intersections = set(all_inputs).intersection(set(all_outputs))
        assert set(all_inputs).difference(intersections) == set(ts_.name for ts_ in ts_from), \
            "The input list is different from the calculated from the op nodes"
        assert set(all_outputs).difference(intersections) == set(ts_.name for ts_ in ts_to), \
            "The output list is different from the calculated from the op nodes"

        final_iz = [iz_ for iz_ in container.initializers if iz_.name in iz_needed]
        graph = helper.make_graph(ops, 'dyngraph', ts_from, ts_to, final_iz)
        oxml = make_model_ex(graph,
                             container.node_domain_version_pair_sets,
                             container.target_opset)
        return oxml
| 182 | |
| 183 | |
class ONNXTraceSession:
    """Records traced tensor computations into an ONNX graph.

    A session is created by :meth:`trace_for_onnx`, which pushes it on the class-wide
    ``activated_sessions`` stack and registers it with ``tensor_set_session``. Leaving the
    ``with`` block unregisters it and pops it off the stack.
    """
    # Stack shared by all sessions; the last entry is the session currently tracing.
    activated_sessions = []

    def __init__(self, target_opset):
        self.container = ONNXElementContainer(target_opset)
        self.inputs = []
        self.outputs = []

    def __enter__(self):
        assert len(self.activated_sessions) > 0 and self.activated_sessions[-1] is self, "trace not started?"
        return self

    # need this exit to close the session
    def __exit__(self, exec_type, exec_value, exec_tb):
        tensor_set_session(None)
        # The pop must not live inside the assert: asserts are stripped under ``python -O``
        # and the session would never leave the stack.
        popped = self.activated_sessions.pop()
        assert popped is self, "the activated session stack is out of order."

    @classmethod
    def trace_for_onnx(cls, *inputs, names=None, target_opset=11) -> 'ONNXTraceSession':
        """
        Starting the trace all tensor computation for ONNX graph generation.
        :param inputs: the input tensors, each could be a torch.Tensor, a numpy object,
            string data or a Python value convertible by torch.tensor.
        :param names: The input names of the ONNX graph; missing or empty names become
            "input{idx}". The list passed in is not modified.
        :param target_opset: The ONNX model opset_version
        :return: A tracing session object, in most case, it should be used in the with statement.
        """
        self = cls(target_opset)
        self.activated_sessions.append(self)
        tensor_set_session(self)

        itensors = []
        for raw_ in inputs:
            x_ = cls._as_traceable_input(raw_)
            if isinstance(x_, torch.Tensor):
                itensors.append(tensor_from_torch(x_, None))
            else:
                itensors.append(tensor_from_onnx(x_, None, None))
        for tensor_, name_ in zip(itensors, cls._normalize_input_names(names, len(inputs))):
            tensor_.name = name_
        self.inputs = itensors
        return self

    @staticmethod
    def _as_traceable_input(x):
        """Convert one trace input to a numpy object or a torch.Tensor.

        Torch tensors and non-string numpy objects pass through; string data becomes a numpy
        array; anything else is converted with ``torch.tensor``.
        """
        if isinstance(x, torch.Tensor):
            # torch has no string dtype; the numpy-based string probe would fail on tensors
            # that require grad or do not live on the CPU.
            return x
        if _is_string_type(x):
            return np.array(x)
        if isinstance(x, (np.ndarray, np.generic)):
            return x
        return torch.tensor(x)

    @staticmethod
    def _normalize_input_names(names, count):
        """Return a new list of ``count`` input names; missing/empty entries become "input{idx}".

        ``names`` itself is never modified. A warning is issued only when names were given
        explicitly and their number differs from ``count``; extra names are ignored.
        """
        given = [] if names is None else list(names)
        if names is not None and len(given) != count:
            warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
        given.extend([''] * (count - len(given)))
        return [nm_ if nm_ else "input{}".format(idx_) for idx_, nm_ in enumerate(given[:count])]

    @staticmethod
    def _collect_run_nodes(nodes, from_names, to_names, initializer_names):
        """Return the shortest suffix of ``nodes`` computing ``to_names`` from ``from_names``.

        Tensors in ``from_names`` and ``initializer_names`` are treated as available; empty
        names (omitted optional inputs) are ignored.

        :raises AssertionError: if an output has no producer or an input cannot be computed
            from the suffix.
        """
        available = set(from_names)
        available.update(initializer_names)
        pending_outputs = set(to_names)
        missing_ts_set = set()
        node_num = len(nodes) - 1
        while node_num >= 0:
            node = nodes[node_num]
            for ot_ in node.output:
                # A tensor can be a requested output and the input of a later node at once.
                missing_ts_set.discard(ot_)
                pending_outputs.discard(ot_)
            missing_ts_set.update(it_ for it_ in node.input if it_ and it_ not in available)
            # Stop only once every input is resolved AND every requested output is found.
            if not missing_ts_set and not pending_outputs:
                break
            node_num -= 1

        assert len(pending_outputs) == 0, "Some output cannot be in the node list."
        assert len(missing_ts_set) == 0, "Some input cannot be in the node list."
        return nodes[node_num:]

    def runops(self, ts_from, ts_to):
        """Run the recorded nodes that compute ``ts_to`` from ``ts_from`` with onnxruntime.

        On failure the generated model is saved to 'mt_debmodel.onnx' in the current
        directory for debugging before the error propagates.

        :return: a tuple (list of output values, the ModelProto that was run).
        """
        collected_nodes = self._collect_run_nodes(
            self.container.nodes,
            [ts_.name for ts_ in ts_from],
            [ts_.name for ts_ in ts_to],
            [iz_.name for iz_ in self.container.initializer])
        vi_input = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                    for ts_ in ts_from]
        vi_output = [ONNXModelUtils.value_info_from_numpy(ts_.name, ts_.numpy())
                     for ts_ in ts_to]
        oxml = ONNXModelUtils.model_from_ops(self.container,
                                             collected_nodes,
                                             vi_input,
                                             vi_output)
        result = None
        try:
            oxfunc = OrtPyFunction.from_model(oxml)
            result = oxfunc(*[ts_.numpy() for ts_ in ts_from])
        finally:
            if result is None:
                onnx.save_model(oxml, 'mt_debmodel.onnx')

        return result if isinstance(result, (list, tuple)) else [result], oxml

    def get_inputs(self):
        return self.inputs

    def stack_container(self):
        """Push a child container of the current one and make it current."""
        assert self.container is not None, "Stacked container must be in another one."
        sub_container = ONNXElementContainer(self.container.target_opset, self.container)
        self.container = sub_container
        return self.container

    def pop_container(self):
        """Make the parent of the current container current again."""
        assert self.container.parent is not None, "Cannot pop the root container."
        self.container = self.container.parent
        return self.container

    @staticmethod
    def build_graph(container, ts_inputs, ts_outputs, graph_name=None):
        """Build a GraphProto from the nodes of ``container``.

        Constant nodes whose output is one of ``ts_inputs`` are removed from the container
        first. Model nodes are then inlined and all nodes sorted topologically. Input
        value_info already present in the container is reused for the graph inputs.
        """
        # some constant ops are created to simulate the tensors generated from the runtime in the loop,
        # so we need to remove the node here
        to_del = []
        input_names = {it_.name: None for it_ in ts_inputs}
        for idx_, nd_ in enumerate(container.nodes):
            if nd_.op_type == 'Constant' and list(nd_.output)[0] in input_names:
                to_del.append(idx_)

        for idx_ in to_del[::-1]:
            container.nodes.pop(idx_)

        graph_name = container.get_unique_operator_name('subg') if not graph_name else graph_name
        nodes = ONNXModelUtils.unfold_model_node(container)
        nodes = ONNXModelUtils.topological_sort(container, nodes, ts_inputs, ts_outputs)

        for vi_ in container.value_info:
            if vi_.name in input_names:
                input_names[vi_.name] = vi_

        inputs = [helper.make_tensor_value_info(si.name, si.onnx_type, si.get_shape())
                  if input_names.get(si.name) is None else input_names[si.name] for si in ts_inputs]
        outputs = [helper.make_tensor_value_info(so.name, so.onnx_type,
                                                 so.get_shape()) for so in ts_outputs]

        graph = helper.make_graph(nodes, graph_name, inputs,
                                  outputs, container.initializers)
        return graph

    def build_model(self, model_name=None, doc_string=None) -> 'onnx.ModelProto':
        """Build the ModelProto of the traced graph ('tcm' is the default model name)."""
        model_name = 'tcm' if model_name is None else model_name
        doc_string = '' if doc_string is None else doc_string
        container = self.container
        graph = self.build_graph(container, self.inputs, self.outputs, model_name)
        onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
                                   container.target_opset, doc_string=doc_string)
        return onnx_model

    def save_as_onnx(self, file_like_or_path, outputs, model_name=None, doc_string=None):
        """
        Build the ONNX model from the traced computation graph.
        :param file_like_or_path: an io.BytesIO like object or a file path; None to skip saving.
        :param outputs: the output tensor(s) of the ONNX graph, a single tensor or a
            list/tuple of tensors; only used when the session has no outputs yet.
        :param model_name: The ONNX model internal name
        :param doc_string: The doc string for the model
        :return: A ONNX ModelProto object.
        """
        if len(self.outputs) == 0 and outputs is None:
            raise RuntimeError("No output of the graph specified.")

        if len(self.outputs) == 0:
            self.outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]

        m = self.build_model(model_name, doc_string)

        if file_like_or_path is not None:
            serialized = m.SerializeToString()
            if _is_path(file_like_or_path):
                with open(file_like_or_path, 'wb') as f:
                    f.write(serialized)
            else:
                file_like_or_path.write(serialized)
                file_like_or_path.flush()

        return m
| 356 | |