microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
3e82549bcb53bc93d04c25d64f16a52d495725c1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_cuops.py

165 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import onnx
7from onnx import onnx_pb as onnx_proto
8from ._ocos import default_opset_domain, get_library_path # noqa
9
10
class CustomOp:
    """Base class for the Python-side schema of a custom operator.

    A subclass describes one operator: its graph inputs and outputs
    (``get_inputs`` / ``get_outputs``) and how keyword attributes are turned
    into ONNX node attributes (``serialize_attr``).
    """

    @classmethod
    def op_type(cls):
        """Return the operator type name used for the ONNX node.

        The name is taken from the class that derives directly from
        ``CustomOp``, so a subclass of a concrete operator reports the
        operator type of that ancestor. Called on ``CustomOp`` itself, the
        method returns ``"CustomOp"`` instead of walking past ``object``.
        """
        rcls = cls
        while rcls.__base__ is not CustomOp:
            if rcls.__base__ is None:
                # Walked past ``object`` without meeting CustomOp as a base:
                # cls is CustomOp itself, so there is no derived operator.
                return cls.__name__
            rcls = rcls.__base__
        return rcls.__name__

    @classmethod
    def get_inputs(cls):
        """Return the operator's input value infos; the base class declares none (``None``)."""
        return None

    @classmethod
    def get_outputs(cls):
        """Return the operator's output value infos; the base class declares none (``None``)."""
        return None

    @classmethod
    def get_output(cls):
        """Backward-compatible alias of :meth:`get_outputs` (historical misspelling)."""
        return cls.get_outputs()

    @classmethod
    def serialize_attr(cls, attrs):
        """Convert user-supplied attributes into node attribute values.

        The base implementation returns ``attrs`` unchanged, so every value
        must already be something ``onnx.helper.make_node`` can encode as an
        attribute. Operators whose attributes need conversion (e.g. flattening
        a dict or loading a file) override this method.

        :param attrs: dict mapping attribute names to user-supplied values
        :return: dict mapping attribute names to serialized values
        """
        return attrs

    # Shorthand used by subclasses as cls.io_def(name, elem_type, shape).
    io_def = onnx.helper.make_tensor_value_info
36
37
class GPT2Tokenizer(CustomOp):
    """Schema of the ``GPT2Tokenizer`` custom operator.

    Input: a 1-D STRING tensor ``input_text``.
    Outputs: two rank-2 INT64 tensors, ``input_ids`` and ``attention_mask``.
    """

    @classmethod
    def get_inputs(cls):
        text = cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        int64 = onnx_proto.TensorProto.INT64
        return [cls.io_def(name, int64, [None, None])
                for name in ("input_ids", 'attention_mask')]
47
48
class VectorToString(CustomOp):
    """Schema of the ``VectorToString`` custom operator (INT64 ``token_ids`` -> STRING ``text``)."""

    @classmethod
    def get_inputs(cls):
        token_ids = cls.io_def("token_ids", onnx_proto.TensorProto.INT64, [])
        return [token_ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def('text', onnx_proto.TensorProto.STRING, [])
        return [text]

    @classmethod
    def serialize_attr(cls, attrs):
        """Flatten a dict-valued ``map`` attribute into tab/newline separated text.

        Each ``key -> iterable`` entry becomes one line: the key, a tab, then
        the items converted with ``str`` and joined by single spaces; lines are
        joined with ``\\n``. Every other attribute is passed through unchanged.
        """
        def _encode(name, value):
            if name != 'map' or not isinstance(value, dict):
                return value
            rows = []
            for key, ids in value.items():
                rows.append(key + "\t" + " ".join(map(str, ids)))
            return '\n'.join(rows)

        return {name: _encode(name, value) for name, value in attrs.items()}
67
68
class StringToVector(CustomOp):
    """Schema of the ``StringToVector`` custom operator (STRING ``text`` -> INT64 ``token_ids``)."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Convert ``map`` and ``unk`` attributes into plain strings.

        - ``map`` given as a dict: one line per entry, ``key`` + tab + the
          items converted with ``str`` and space-joined; lines joined by ``\\n``.
        - ``unk`` given as a list: its items converted with ``str`` and
          space-joined.

        Every other attribute (or these names with other value types) is
        passed through unchanged.
        """
        serialized = dict(attrs)
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                serialized[name] = '\n'.join(
                    key + "\t" + " ".join(map(str, ids)) for key, ids in value.items())
            elif name == 'unk' and isinstance(value, list):
                serialized[name] = ' '.join(map(str, value))
        return serialized
89
90
class BlingFireSentenceBreaker(CustomOp):
    """Schema of the ``BlingFireSentenceBreaker`` custom operator (STRING ``text`` -> STRING ``sentence``)."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Replace the ``model`` attribute (a path to a file) with that file's raw bytes.

        Every other attribute is passed through unchanged.
        """
        result = {}
        for name, value in attrs.items():
            if name != 'model':
                result[name] = value
                continue
            with open(value, "rb") as fp:
                result[name] = fp.read()
        return result


# TODO: list the schemas of all remaining custom operators here.
# ...
# ...
class SentencepieceTokenizer(CustomOp):
    """Schema of the ``SentencepieceTokenizer`` custom operator.

    Inputs: STRING ``inputs`` plus the control tensors ``nbest_size`` (INT64),
    ``alpha`` (FLOAT), ``add_bos``, ``add_eos`` and ``reverse`` (BOOL).
    Outputs: INT32 ``tokens`` and INT64 ``indices``. Every tensor is declared 1-D.
    """

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        # NOTE(review): the control inputs are declared 1-D like the text
        # input; confirm the kernel expects that rather than scalars.
        specs = (
            ('inputs', tp.STRING),
            ('nbest_size', tp.INT64),
            ('alpha', tp.FLOAT),
            ('add_bos', tp.BOOL),
            ('add_eos', tp.BOOL),
            ('reverse', tp.BOOL),
        )
        return [cls.io_def(name, dtype, [None]) for name, dtype in specs]

    @classmethod
    def get_outputs(cls):
        tp = onnx_proto.TensorProto
        return [
            cls.io_def('tokens', tp.INT32, [None]),
            cls.io_def('indices', tp.INT64, [None]),
        ]
132
133
class SingleOpGraph:
    """Helpers that wrap one custom operator into a standalone ONNX graph."""

    @classmethod
    def get_next_id(cls):
        """Increment and return a class-level counter (the first call returns 1).

        Used to give generated nodes and graphs distinct names.
        """
        if not hasattr(cls, '_id_counter'):
            cls._id_counter = 0
        cls._id_counter += 1
        return cls._id_counter

    @classmethod
    def build_my_graph(cls, op_class, *args, **kwargs):
        """Build an ONNX graph containing a single node of ``op_class``.

        :param op_class: a CustomOp subclass, or the name of one as a string
        :param args: accepted for backward compatibility; not used
        :param kwargs: node attributes, converted by ``op_class.serialize_attr``
        :return: the graph produced by ``onnx.helper.make_graph``
        :raises KeyError: if ``op_class`` is a string naming no custom operator
        :raises TypeError: if the operator does not declare its inputs or outputs
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        if inputs is None or outputs is None:
            # CustomOp.get_inputs() returns None by default; fail with a clear
            # message (and before serialize_attr may read files) instead of
            # "'NoneType' object is not iterable".
            raise TypeError(
                "custom op '{}' must define get_inputs() and get_outputs()".format(op_type))
        attrs = op_class.serialize_attr(kwargs)
        node = onnx.helper.make_node(op_type,
                                     [i_.name for i_ in inputs],
                                     [o_.name for o_ in outputs],
                                     "{}_{}".format(op_type, cls.get_next_id()),
                                     **attrs,
                                     domain=default_opset_domain())
        return onnx.helper.make_graph([node],
                                      "og_{}_{}".format(op_type, cls.get_next_id()),
                                      inputs,
                                      outputs)

    @staticmethod
    def get_op_class(op_type):
        """Return the concrete CustomOp subclass visible in this module under ``op_type``.

        Only CustomOp subclasses (not CustomOp itself) are returned, so other
        module globals such as imported modules or helper classes cannot be
        selected by name.

        :raises KeyError: if no such operator class exists
        """
        op_class = globals().get(op_type)
        if (not isinstance(op_class, type)
                or not issubclass(op_class, CustomOp)
                or op_class is CustomOp):
            raise KeyError("'{}' is not a known custom operator".format(op_type))
        return op_class
166