microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/pnp/_utils.py
302 lines · mode: code
| 1 | import copy |
| 2 | import onnx |
| 3 | from onnx import helper, numpy_helper |
| 4 | from collections import namedtuple |
| 5 | |
| 6 | |
| 7 | class _Container: |
| 8 | def __init__(self): |
| 9 | self.parent = None |
| 10 | self.initializer=[] |
| 11 | self.value_info=[] |
| 12 | self.nodes = [] |
| 13 | self.node_domain_version_pair_sets = {} |
| 14 | |
| 15 | def add_model(self, oxml): |
| 16 | self.initializer.extend(oxml.graph.initializer) |
| 17 | self.value_info.extend(oxml.graph.value_info) |
| 18 | self.nodes.extend(oxml.graph.node) |
| 19 | self.node_domain_version_pair_sets.update( |
| 20 | [(opset_.domain, opset_.version) for opset_ in oxml.opset_import]) |
| 21 | return self |
| 22 | |
| 23 | |
class ONNXModelUtils:
    """Helpers that rename, inline (unfold) and chain (join) ONNX models.

    Names coming from an embedded model are qualified with a prefix via
    `merge_name` so several models can share one graph without name clashes.
    """

    # Node attributes holding a subgraph that must be renamed together with the
    # node that owns it: Loop/Scan `body` and the two If branches.
    _SUBGRAPH_ATTRS = ('body', 'then_branch', 'else_branch')

    @staticmethod
    def merge_name(prefix, name):
        """Return *name* qualified by *prefix*, i.e. ``"<prefix>_<name>"``."""
        return "{}_{}".format(prefix, name)

    @staticmethod
    def _rename_iter(iterables, prefix_name, inplace=False):
        """Prefix the ``name`` of every item in *iterables*.

        Items are deep-copied first and the copies returned, unless *inplace*
        is true, in which case the items themselves are renamed and
        *iterables* is returned.
        """
        new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
        for iz_ in new_iz:
            iz_.name = ONNXModelUtils.merge_name(prefix_name, iz_.name)
        return new_iz

    @classmethod
    def _rename_graph(cls, graph, prefix, graph_or_container):
        """Return prefixed deep copies of ``graph.node``.

        Prefixed copies of ``graph.initializer`` and ``graph.value_info`` are
        appended to *graph_or_container* (anything with ``initializer`` and
        ``value_info`` lists, e.g. a GraphProto or a `_Container`). Unnamed
        nodes become ``<prefix>_op<index>``; empty input/output names (omitted
        optional arguments) stay empty. *graph* itself is not modified.
        """
        def io_rename(node, prefix_name, idx):
            new_node = copy.deepcopy(node)
            local_name = node.name if node.name else "op{}".format(idx)
            new_node.name = cls.merge_name(prefix_name, local_name)

            del new_node.input[:]
            new_node.input.extend(
                cls.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
            del new_node.output[:]
            new_node.output.extend(
                cls.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
            return new_node

        assert prefix is not None, 'The graph prefix could not be None'
        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
        return [io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node)]

    @classmethod
    def _process_node_body(cls, node, prefix):
        """Prefix every name inside the subgraph attributes of *node*.

        Attributes named in `_SUBGRAPH_ATTRS` are replaced by renamed copies
        (nodes, initializers, value infos, inputs and outputs), and subgraphs
        nested inside them are processed recursively. Outer-scope references
        made from a subgraph therefore match the prefixed outer graph. *node*
        is updated in place and returned.
        """
        if all(attr.name not in cls._SUBGRAPH_ATTRS for attr in node.attribute):
            return node

        def _process_attr(attr, prefix_name):
            if attr.name not in cls._SUBGRAPH_ATTRS:
                return attr
            new_attr = copy.deepcopy(attr)
            # _rename_graph re-appends prefixed copies of these from attr.g, so
            # clear the deep-copied originals first (initializers included) to
            # avoid keeping un-prefixed duplicates.
            del new_attr.g.initializer[:]
            del new_attr.g.value_info[:]
            del new_attr.g.node[:]
            renamed = cls._rename_graph(attr.g, prefix_name, new_attr.g)
            new_attr.g.node.extend([cls._process_node_body(nd_, prefix_name) for nd_ in renamed])
            cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
            cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
            return new_attr

        attr_list = [_process_attr(attr_, prefix) for attr_ in node.attribute]
        del node.attribute[:]
        node.attribute.extend(attr_list)
        return node

    @staticmethod
    def get_model_name_abbr(node):
        """Return ``'m_' + <text after the last '_' in node.name>``."""
        no = node.name.split('_')[-1]
        return 'm_' + no

    @staticmethod
    def get_model_id_from_arg0(nodes, node):
        """Return the model id stored in the Constant node feeding ``node.input[0]``.

        Raises:
            RuntimeError: if not exactly one Constant node in *nodes* produces it.
        """
        arg0_name = node.input[0]
        c_node = [n_ for n_ in nodes
                  if n_.op_type == 'Constant' and n_.output[0] == arg0_name]
        if len(c_node) != 1:
            raise RuntimeError(
                "internal error: expected exactly one Constant node producing the model id "
                "'{}', found {}.".format(arg0_name, len(c_node)))
        tensor_value = onnx.helper.get_attribute_value(c_node[0].attribute[0])
        return numpy_helper.to_array(tensor_value).item()

    @classmethod
    def _unfold_model_node(cls, container, name, model, io_mapping=None):
        """Return the nodes of *model* renamed with prefix *name*.

        Renamed initializers/value infos go into *container*. The model's opset
        imports are recorded on the outermost container (reached through
        ``parent``), since the resulting model has a single opset_import list.
        *io_mapping* is accepted for interface compatibility and is unused here.
        """
        top_container = container
        while top_container.parent is not None:
            top_container = top_container.parent

        renamed_nodes = cls._rename_graph(model.graph, name, container)
        onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]

        top_container.node_domain_version_pair_sets.update(
            [(opset_.domain, opset_.version) for opset_ in model.opset_import])
        return onnx_nodes

    @classmethod
    def unfold_model(cls, oxml, id_to_model, io_mapping=None):
        """Inline every ``_ModelFunctionCall`` node of *oxml*.

        The called model is looked up in *id_to_model* by the id held in the
        Constant node that feeds the call's first input. Its nodes are prefixed
        with `get_model_name_abbr` of the call node and wired to the call's
        remaining inputs and its outputs through Identity nodes. Nested-model
        outputs beyond the call's declared outputs are not exported.

        Args:
            oxml: the ModelProto containing the call nodes; it is not modified.
            id_to_model: mapping model id -> ModelProto to inline.
            io_mapping: optional callable ``(input_mapping, output_mapping)`` ->
                replacement ``(input_mapping, output_mapping)`` node lists.

        Returns:
            A modified deep copy of *oxml*.

        Raises:
            RuntimeError: if a call references an id missing from *id_to_model*.
        """
        container = _Container().add_model(oxml)
        graph_nodes = list(oxml.graph.node)  # built once; searched for every call node
        nodes = []
        for _node in graph_nodes:
            if _node.op_type != '_ModelFunctionCall':
                nodes.append(_node)
                continue

            model_id = cls.get_model_id_from_arg0(graph_nodes, _node)
            if model_id not in id_to_model:
                raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))

            prefix = cls.get_model_name_abbr(_node)
            nest_model = id_to_model[model_id]

            input_mapping = []
            for idx_, in_ in enumerate(nest_model.graph.input):
                _renamed_in = cls.merge_name(prefix, in_.name)
                input_mapping.append(onnx.helper.make_node(
                    'Identity',
                    [_node.input[idx_ + 1]],  # the first arg is model id, skip it.
                    [_renamed_in],
                    name='i_' + _renamed_in))
            nds = cls._unfold_model_node(container, prefix, nest_model, io_mapping)
            output_mapping = []
            for out_, call_out_ in zip(nest_model.graph.output, _node.output):
                _renamed_out = cls.merge_name(prefix, out_.name)
                output_mapping.append(onnx.helper.make_node(
                    'Identity', [_renamed_out], [call_out_], name='o_' + _renamed_out))
            if io_mapping is not None:
                assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
                input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
            # attention: the order of the list operations is important, which avoids the topological sort.
            nodes.extend(input_mapping)
            nodes.extend(nds)
            nodes.extend(output_mapping)

        intlzs = cls._remove_unused_initializers(nodes, container.initializer)
        oxml = copy.deepcopy(oxml)
        del oxml.graph.node[:]
        oxml.graph.node.extend(nodes)
        del oxml.graph.initializer[:]
        oxml.graph.initializer.extend(intlzs)
        return oxml

    @classmethod
    def topological_sort(cls, container, nodes, inputs, outputs):
        """Return *nodes* ordered so every producer precedes its consumers.

        The traversal starts from the graph inputs, the initializers in
        ``container.initializer`` and the Constant nodes. Nodes not reachable
        from these are omitted. Node names must be unique and non-empty. The
        search is iterative, so long chains do not hit the recursion limit.

        Raises:
            RuntimeError: if a node input or a graph output has no producer, or
                the graph contains a cycle.
        """
        DynNode = namedtuple('DynNode', ['name', 'output'])
        placeholder = DynNode(name='placeholder',
                              output=[nm_.name for nm_ in inputs] +
                              [it_.name for it_ in container.initializer])
        input_nodes = [placeholder] + [nd_ for nd_ in nodes if nd_.op_type == 'Constant']

        op_output_map = {}
        for nd_ in list(nodes) + input_nodes:
            for ky_ in nd_.output:
                op_output_map[ky_] = nd_

        edges = {}
        for op in nodes:
            for x in op.input:
                if x == '':
                    continue
                try:
                    predecessor = op_output_map[x]
                except KeyError:
                    raise RuntimeError(
                        "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
                edges.setdefault(predecessor.name, []).append(op)

        for y_ in outputs:
            try:
                producer = op_output_map[y_.name]
            except KeyError:
                raise RuntimeError(
                    "cannot find an operator to produce the graph output: {}".format(y_.name)) from None
            edges.setdefault(producer.name, [])

        visited = set()
        unfinished_nodes = set()
        finished = []  # DFS post-order; reversed at the end

        def _enter(node):
            # Mark *node* as on the current DFS path and return its stack frame.
            if node.name in unfinished_nodes:
                raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))
            unfinished_nodes.add(node.name)
            if node.name in edges:
                assert node.name != '', 'this topological-sort depends on the unique node name.'
            return node, iter(edges.get(node.name, ()))

        for root in input_nodes:
            if root.name in visited:
                continue
            stack = [_enter(root)]
            while stack:
                node, successors = stack[-1]
                for successor in successors:
                    if successor.name not in visited:
                        stack.append(_enter(successor))
                        break
                else:  # every successor is finished: finish this node
                    stack.pop()
                    unfinished_nodes.remove(node.name)
                    visited.add(node.name)
                    if node is not placeholder:
                        finished.append(node)

        finished.reverse()
        return finished

    @staticmethod
    def _remove_unused_initializers(nodes, initializers, reserved_names=None):
        """Return the *initializers* that are referenced or reserved, in order.

        An initializer is kept if its name is in *reserved_names* or is an input
        of any node in *nodes*. Inputs of nodes inside subgraph attributes
        (``g`` / ``graphs``, e.g. Loop and If bodies) also count, because
        subgraphs may read outer-scope initializers by name.
        """
        if reserved_names is None:
            reserved_names = set()
        used_names = set()
        pending = list(nodes)
        while pending:
            nd_ = pending.pop()
            used_names.update(nd_.input)
            for attr in getattr(nd_, 'attribute', ()):
                sub_graph = getattr(attr, 'g', None)
                if sub_graph is not None:
                    pending.extend(sub_graph.node)
                for sub_ in getattr(attr, 'graphs', ()):
                    pending.extend(sub_.node)

        return [intlz_ for intlz_ in initializers
                if intlz_.name in used_names or intlz_.name in reserved_names]

    @classmethod
    def join_models(cls, *models, io_mapping=None):
        """Chain *models* into a single model.

        Model ``i`` (0-based) is prefixed with ``g<i+1>``. Unless *io_mapping*
        says otherwise, input ``k`` of each model is fed by output ``k`` of the
        previous model. The joined graph has the inputs of the first model and
        the outputs of the last one.

        Args:
            models: the ONNX models to chain, in execution order.
            io_mapping: optional callable taking a list of ``ModelPort(input, output)``
                tuples (prefixed names per model) and returning a dict mapping
                consumer-input names to producer-output names.

        Raises:
            ValueError: if no model is given.
            RuntimeError: if a model input is neither mapped by *io_mapping* nor
                matched positionally by an output of the previous model.
        """
        if not models:
            raise ValueError("join_models requires at least one model")
        # generate the prefix id for the embedding graph to avoid the name conflict
        mdl_prefix = ["g{}".format(_i + 1) for _i in range(len(models))]

        inputs = cls._rename_iter(models[0].graph.input, mdl_prefix[0])
        outputs = cls._rename_iter(models[-1].graph.output, mdl_prefix[-1])

        port_mapping = {}
        if io_mapping is not None:
            assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
            ModelPort = namedtuple('ModelPort', "input output")
            ports = [ModelPort([cls.merge_name(prefix_, _x.name) for _x in _m.graph.input],
                               [cls.merge_name(prefix_, _y.name) for _y in _m.graph.output])
                     for prefix_, _m in zip(mdl_prefix, models)]
            port_mapping = io_mapping(ports)
        for _idx in range(len(models) - 1):
            prev_outputs = models[_idx].graph.output
            for _i, _x in enumerate(models[_idx + 1].graph.input):
                iname = cls.merge_name(mdl_prefix[_idx + 1], _x.name)
                if iname in port_mapping:
                    continue
                if _i >= len(prev_outputs):
                    raise RuntimeError(
                        "{}: the previous model has no output at position {}; "
                        "map it explicitly through io_mapping".format(iname, _i))
                port_mapping[iname] = cls.merge_name(mdl_prefix[_idx], prev_outputs[_i].name)

        # Collect only the prefixed copies: the joined graph never refers to the
        # models' original (un-prefixed) names, so their value_info would be stale.
        nodes = []
        container = _Container()
        for prefix_, _m in zip(mdl_prefix, models):
            nodes += cls._rename_graph(_m.graph, prefix_, container)

        for _n in nodes:
            if any(_i in port_mapping for _i in _n.input):
                new_input = [port_mapping[_i] if _i in port_mapping else _i for _i in _n.input]
                del _n.input[:]
                _n.input.extend(new_input)

        name = "_".join([_mdl.graph.name for _mdl in models])
        domains = set()
        _opset = []
        for _mdl in models:
            for _ops in _mdl.opset_import:
                domain = _ops.domain if _ops.domain else "ai.onnx"
                if domain in domains:
                    if domain == "ai.onnx":
                        assert _ops.version == _opset[0].version, \
                            f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
                else:
                    domains.add(domain)
                    # keep ai.onnx first: make_model_gen_version below only takes that one
                    if domain == "ai.onnx":
                        _opset.insert(0, _ops)
                    else:
                        _opset.append(_ops)

        inits = cls._remove_unused_initializers(nodes, container.initializer)
        g = helper.make_graph(nodes, name, inputs, outputs,
                              initializer=inits,
                              value_info=container.value_info)

        if hasattr(helper, 'make_model_gen_version'):
            # make_model_gen_version doesn't accept the custom domain.
            m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
            m.opset_import.extend(_opset[1:])
        else:
            m = helper.make_model(g, opset_imports=_opset)
        return m
| 303 | |