microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/_cuops.py
165 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. See License.txt in the project root for |
| 3 | # license information. |
| 4 | ############################################################################### |
| 5 | |
| 6 | import onnx |
| 7 | from onnx import onnx_pb as onnx_proto |
| 8 | from ._ocos import default_opset_domain, get_library_path # noqa |
| 9 | |
| 10 | |
class CustomOp:
    """Base class describing the schema of one custom ONNX operator.

    Subclasses declare the operator's graph inputs and outputs by overriding
    :meth:`get_inputs` and :meth:`get_outputs`, and may override
    :meth:`serialize_attr` to convert Python attribute values.
    """

    # Shortcut used by subclasses to build their input/output value infos.
    # Wrapped in staticmethod so the helper is never bound as a method, whether
    # it is reached through the class (``cls.io_def``) or through an instance.
    io_def = staticmethod(onnx.helper.make_tensor_value_info)

    @classmethod
    def op_type(cls):
        """Return the ONNX op type name for this class.

        The op type is the name of the class in ``cls``'s MRO that derives
        directly from :class:`CustomOp`, so a subclass of a concrete operator
        reports that operator's name. When no such class exists (e.g. on
        :class:`CustomOp` itself), ``cls.__name__`` is returned instead of
        walking off the end of the hierarchy.
        """
        for klass in cls.__mro__:
            if CustomOp in klass.__bases__:
                return klass.__name__
        return cls.__name__

    @classmethod
    def get_inputs(cls):
        """Return the list of input value infos; ``None`` unless overridden."""
        return None

    @classmethod
    def get_outputs(cls):
        """Return the list of output value infos; ``None`` unless overridden."""
        return None

    @classmethod
    def get_output(cls):
        """Misspelled alias of :meth:`get_outputs`, kept for backward compatibility."""
        return cls.get_outputs()

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize operator attributes.

        The base implementation returns ``attrs`` unchanged. Only basic Python
        values (such as lists) are meant to pass through as-is; any other type
        must be serialized by an overriding subclass or by the caller.

        :param attrs: dict of attribute name to Python value
        :return: the dict of serialized attribute values
        """
        return attrs
| 36 | |
| 37 | |
class GPT2Tokenizer(CustomOp):
    """Schema of the GPT2Tokenizer custom operator."""

    @classmethod
    def get_inputs(cls):
        """Declare a single rank-1 string input named ``input_text``."""
        text_input = cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
        return [text_input]

    @classmethod
    def get_outputs(cls):
        """Declare two rank-2 int64 outputs: token ids and attention mask."""
        ids_output = cls.io_def("input_ids", onnx.TensorProto.INT64, [None, None])
        mask_output = cls.io_def('attention_mask', onnx.TensorProto.INT64, [None, None])
        return [ids_output, mask_output]
| 47 | |
| 48 | |
class VectorToString(CustomOp):
    """Schema of the VectorToString custom operator."""

    @classmethod
    def get_inputs(cls):
        # NOTE(review): shape [] declares a rank-0 value info; confirm this is
        # what the kernel expects rather than an unknown-rank placeholder.
        return [cls.io_def("token_ids", onnx.TensorProto.INT64, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('text', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the operator attributes.

        A dict-valued ``map`` attribute is flattened into text with one line per
        entry: the key, a tab, then the entry's values separated by spaces.
        Lines are joined with newlines. Every other attribute is passed through
        unchanged.

        :param attrs: dict of attribute name to Python value
        :return: a new dict of serialized attribute values
        """
        attr_data = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                # str(key) lets non-string keys (e.g. ints) serialize instead of
                # failing with TypeError on string concatenation.
                attr_data[name] = '\n'.join(
                    str(key) + "\t" + " ".join(str(i) for i in ids)
                    for key, ids in value.items())
            else:
                attr_data[name] = value
        return attr_data
| 67 | |
| 68 | |
class StringToVector(CustomOp):
    """Schema of the StringToVector custom operator."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        # NOTE(review): shape [] declares a rank-0 value info; confirm this is
        # what the kernel expects rather than an unknown-rank placeholder.
        return [cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the operator attributes.

        * A dict-valued ``map`` is flattened into text with one line per entry:
          the key, a tab, then the entry's values separated by spaces. Lines are
          joined with newlines.
        * A list-valued ``unk`` is joined into one space-separated string.
        * Every other attribute is passed through unchanged.

        :param attrs: dict of attribute name to Python value
        :return: a new dict of serialized attribute values
        """
        attr_data = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                # str(key) lets non-string keys (e.g. ints) serialize instead of
                # failing with TypeError on string concatenation.
                attr_data[name] = '\n'.join(
                    str(key) + "\t" + " ".join(str(i) for i in ids)
                    for key, ids in value.items())
            elif name == 'unk' and isinstance(value, list):
                attr_data[name] = ' '.join(str(i) for i in value)
            else:
                attr_data[name] = value
        return attr_data
| 89 | |
| 90 | |
class BlingFireSentenceBreaker(CustomOp):
    """Schema of the BlingFireSentenceBreaker custom operator."""

    @classmethod
    def get_inputs(cls):
        text_input = cls.io_def("text", onnx.TensorProto.STRING, [None])
        return [text_input]

    @classmethod
    def get_outputs(cls):
        sentence_output = cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])
        return [sentence_output]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the operator attributes.

        The ``model`` attribute is treated as a file path: the file is read in
        binary mode and its raw bytes become the attribute value. All other
        attributes are passed through unchanged.

        :param attrs: dict of attribute name to Python value
        :return: a new dict of serialized attribute values
        """
        def _read_bytes(path):
            # Read the whole file as bytes; the file is closed on exit.
            with open(path, "rb") as model_file:
                return model_file.read()

        return {name: (_read_bytes(value) if name == 'model' else value)
                for name, value in attrs.items()}
| 110 | # TODO: list the schemas of all custom operators here: |
| 111 | # ... |
| 112 | # ... |
class SentencepieceTokenizer(CustomOp):
    """Schema of the SentencepieceTokenizer custom operator."""

    @classmethod
    def get_inputs(cls):
        """Declare the six rank-1 inputs, in positional order."""
        proto = onnx_proto.TensorProto
        specs = (
            ('inputs', proto.STRING),
            ('nbest_size', proto.INT64),
            ('alpha', proto.FLOAT),
            ('add_bos', proto.BOOL),
            ('add_eos', proto.BOOL),
            ('reverse', proto.BOOL),
        )
        return [cls.io_def(name, elem_type, [None]) for name, elem_type in specs]

    @classmethod
    def get_outputs(cls):
        """Declare the rank-1 ``tokens`` (int32) and ``indices`` (int64) outputs."""
        proto = onnx_proto.TensorProto
        specs = (
            ('tokens', proto.INT32),
            ('indices', proto.INT64),
        )
        return [cls.io_def(name, elem_type, [None]) for name, elem_type in specs]
| 132 | |
| 133 | |
class SingleOpGraph:
    """Builds an ONNX graph containing exactly one custom-operator node."""

    @classmethod
    def get_next_id(cls):
        """Increment and return a counter stored on the class (first value is 1).

        Used to give generated node and graph names numeric suffixes.
        """
        if not hasattr(cls, '_id_counter'):
            cls._id_counter = 0
        cls._id_counter += 1
        return cls._id_counter

    @classmethod
    def build_my_graph(cls, op_class, *args, **kwargs):
        """Build a graph whose only node is the custom operator ``op_class``.

        :param op_class: a CustomOp subclass, or the class name of one defined
            in this module
        :param args: unused; accepted for backward compatibility
        :param kwargs: operator attributes, converted by ``op_class.serialize_attr``
        :return: the graph from ``onnx.helper.make_graph``; its inputs/outputs are
            the op's declared value infos, connected by name to the single node
        :raises KeyError: if ``op_class`` is a name that does not resolve to a
            CustomOp subclass
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        attrs = op_class.serialize_attr(kwargs)
        # The node is wired to the graph I/O by reusing the value-info names.
        cuop = onnx.helper.make_node(op_type,
                                     [i_.name for i_ in inputs],
                                     [o_.name for o_ in outputs],
                                     "{}_{}".format(op_type, cls.get_next_id()),
                                     **attrs,
                                     domain=default_opset_domain())
        # get_next_id() is called a second time here, so the graph name's suffix
        # differs from the node name's suffix.
        graph = onnx.helper.make_graph([cuop],
                                       "og_{}_{}".format(op_type, cls.get_next_id()),
                                       inputs,
                                       outputs)
        return graph

    @staticmethod
    def get_op_class(op_type):
        """Look up a CustomOp subclass defined in this module by its class name.

        :param op_type: the class name, e.g. ``'GPT2Tokenizer'``
        :return: the matching class
        :raises KeyError: if no module-level name ``op_type`` exists, or it does
            not name a CustomOp subclass (e.g. ``'onnx'`` or ``'SingleOpGraph'``)
        """
        op_class = globals().get(op_type)
        # Reject arbitrary module globals: only operator schema classes are valid.
        if not (isinstance(op_class, type) and issubclass(op_class, CustomOp)):
            raise KeyError(
                "{!r} is not a custom operator defined in this module".format(op_type))
        return op_class
| 166 | |