microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
13d9e27ccd8a0de9a1225756fbf6860a1931484f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_cuops.py

389lines · modecode

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import onnx
7import numpy
8from onnx import onnx_pb as onnx_proto
9from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
10
11
class CustomOp:
    """Base class for the signature declarations of custom operators.

    Subclasses override :meth:`get_inputs` and :meth:`get_outputs` to list
    their tensors and may override :meth:`serialize_attr` to convert
    attribute values before they are attached to an ONNX node.
    """

    # Shorthand subclasses use to declare one input/output tensor.
    io_def = onnx.helper.make_tensor_value_info

    @classmethod
    def op_type(cls):
        """
        Walk up the inheritance chain and return the name of the class
        whose direct base is CustomOp.
        :return: the class name used as the operator type
        """
        klass = cls
        while klass.__base__ is not CustomOp:
            klass = klass.__base__
        return klass.__name__

    @classmethod
    def get_inputs(cls):
        """
        :return: None in the base class; subclasses return their input
                 tensor declarations
        """
        return None

    @classmethod
    def get_outputs(cls):
        """
        :return: None in the base class; subclasses return their output
                 tensor declarations
        """
        return None

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Only basic Python types such as list or dict are supported here;
        any other type has to be serialized by the user beforehand.
        The base implementation hands the attributes back unchanged.
        :param attrs: the dict attributes
        :return: the dict of serialized data
        """
        return attrs
40
41
class GPT2Tokenizer(CustomOp):
    """Input/output declarations for the GPT2Tokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``input_text``, of shape [None]."""
        input_text = cls.io_def(
            'input_text', onnx_proto.TensorProto.STRING, [None])
        return [input_text]

    @classmethod
    def get_outputs(cls):
        """Declare two INT64 outputs of shape [None, None]."""
        int64 = onnx.TensorProto.INT64
        return [
            cls.io_def(name, int64, [None, None])
            for name in ("input_ids", 'attention_mask')
        ]
56
57
class VectorToString(CustomOp):
    """Input/output declarations for the VectorToString custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one INT64 input, ``token_ids``, with an empty shape list."""
        token_ids = cls.io_def("token_ids", onnx.TensorProto.INT64, [])
        return [token_ids]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``text``, of shape [None]."""
        text = cls.io_def('text', onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Flatten a dict-valued ``map`` attribute into text: one line per
        entry, written as the key, a tab, then the space-separated items
        of the value. Every other attribute (including a ``map`` that is
        already a string) is copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        """
        serialized = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                rows = []
                for key, items in value.items():
                    rows.append(key + "\t" + " ".join(map(str, items)))
                value = '\n'.join(rows)
            serialized[name] = value
        return serialized
81
82
class StringMapping(CustomOp):
    """Input/output declarations for the StringMapping custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``input``, with an empty shape list."""
        source = cls.io_def("input", onnx.TensorProto.STRING, [])
        return [source]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``output``, with an empty shape list."""
        target = cls.io_def('output', onnx_proto.TensorProto.STRING, [])
        return [target]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Flatten a dict-valued ``map`` attribute into text: one line per
        entry, written as key, tab, value. Every other attribute
        (including a ``map`` that is already a string) is copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        """
        serialized = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                rows = [src + "\t" + dst for src, dst in value.items()]
                value = '\n'.join(rows)
            serialized[name] = value
        return serialized
104
105
class MaskedFill(CustomOp):
    """Input/output declarations for the MaskedFill custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare a STRING ``value`` and a BOOL ``mask``, both of shape [None]."""
        value = cls.io_def("value", onnx.TensorProto.STRING, [None])
        mask = cls.io_def("mask", onnx.TensorProto.BOOL, [None])
        return [value, mask]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``output``, of shape [None]."""
        result = cls.io_def('output', onnx_proto.TensorProto.STRING, [None])
        return [result]
118
119
class StringToVector(CustomOp):
    """Input/output declarations for the StringToVector custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``text``, of shape [None]."""
        text = cls.io_def("text", onnx.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        """Declare one INT64 output, ``token_ids``, with an empty shape list."""
        token_ids = cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])
        return [token_ids]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Convert structured attributes to text.

        A dict-valued ``map`` becomes one line per entry (key, tab,
        space-separated items of the value); a list-valued ``unk``
        becomes its items joined by single spaces. Anything else is
        copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        """
        serialized = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                rows = []
                for key, items in value.items():
                    rows.append(key + "\t" + " ".join(map(str, items)))
                value = '\n'.join(rows)
            elif name == 'unk' and isinstance(value, list):
                value = ' '.join(map(str, value))
            serialized[name] = value
        return serialized
145
146
class BlingFireSentenceBreaker(CustomOp):
    """Input/output declarations for the BlingFireSentenceBreaker custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``text``, of shape [None]."""
        text = cls.io_def("text", onnx.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``sentence``, with an empty shape list."""
        sentence = cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])
        return [sentence]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Replace the ``model`` attribute, given as a file path, with the raw
        bytes read from that file; all other attributes are copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        """
        serialized = dict(attrs)
        if 'model' in serialized:
            with open(serialized['model'], "rb") as model_file:
                serialized['model'] = model_file.read()
        return serialized
167
168
class SegmentExtraction(CustomOp):
    """Input/output declarations for the SegmentExtraction custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one INT64 input, ``input``, of shape [None, None]."""
        source = cls.io_def("input", onnx.TensorProto.INT64, [None, None])
        return [source]

    @classmethod
    def get_outputs(cls):
        """Declare INT64 outputs ``position`` ([None, 2]) and ``value`` ([None])."""
        int64 = onnx_proto.TensorProto.INT64
        position = cls.io_def('position', int64, [None, 2])
        value = cls.io_def('value', int64, [None])
        return [position, value]
181
182
class BertTokenizer(CustomOp):
    """Input/output declarations for the BertTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``text``, of shape [None]."""
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        """Declare three INT64 outputs, each of shape [None]."""
        return [
            cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])
        ]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Replace the ``vocab_file`` attribute, given as a path to a UTF-8
        text file, with the text content of that file; all other
        attributes are copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        :raises OSError: if the vocabulary file cannot be opened
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as vocab_file:
                    # read() keeps the file's own line breaks. Joining
                    # readlines() with '\n' would double every newline,
                    # since each line already carries its terminator.
                    attrs_data[k_] = vocab_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
208
209
class StringECMARegexReplace(CustomOp):
    """Input/output declarations for the StringECMARegexReplace custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare three STRING inputs, each of shape [None]."""
        string_type = onnx.TensorProto.STRING
        return [
            cls.io_def(name, string_type, [None])
            for name in ("input", "pattern", "rewrite")
        ]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``output``, of shape [None]."""
        result = cls.io_def('output', onnx_proto.TensorProto.STRING, [None])
        return [result]
223
224
class BertTokenizerDecoder(CustomOp):
    """Input/output declarations for the BertTokenizerDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare INT64 inputs ``ids`` ([None]) and ``position`` ([None, None])."""
        return [
            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
            cls.io_def("position", onnx.TensorProto.INT64, [None, None])
        ]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``str``, of shape [None]."""
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Replace the ``vocab_file`` attribute, given as a path to a UTF-8
        text file, with the text content of that file; all other
        attributes are copied as is.
        :param attrs: the dict attributes
        :return: a new dict of serialized data
        :raises OSError: if the vocabulary file cannot be opened
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as vocab_file:
                    # read() keeps the file's own line breaks. Joining
                    # readlines() with '\n' would double every newline,
                    # since each line already carries its terminator.
                    attrs_data[k_] = vocab_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
249
250
class SentencepieceTokenizer(CustomOp):
    """Input/output declarations for the SentencepieceTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare the six inputs of the op, each of shape [None]."""
        types = onnx_proto.TensorProto
        declared = (
            ('inputs', types.STRING),
            ('nbest_size', types.INT64),
            ('alpha', types.FLOAT),
            ('add_bos', types.BOOL),
            ('add_eos', types.BOOL),
            ('reverse', types.BOOL),
        )
        return [cls.io_def(name, dtype, [None]) for name, dtype in declared]

    @classmethod
    def get_outputs(cls):
        """Declare INT32 ``tokens`` and INT64 ``indices``, both of shape [None]."""
        types = onnx_proto.TensorProto
        tokens = cls.io_def('tokens', types.INT32, [None])
        indices = cls.io_def('indices', types.INT64, [None])
        return [tokens, indices]
270
271
class SentencepieceDecoder(CustomOp):
    """Input/output declarations for the SentencepieceDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one INT64 input, ``ids``, of shape [None]."""
        ids = cls.io_def("ids", onnx.TensorProto.INT64, [None])
        return [ids]

    @classmethod
    def get_outputs(cls):
        """Declare one STRING output, ``str``, of shape [None]."""
        decoded = cls.io_def('str', onnx_proto.TensorProto.STRING, [None])
        return [decoded]
283
284
class Inverse(CustomOp):
    """Input/output declarations for the Inverse custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one FLOAT input, ``input``, of shape [None, None]."""
        source = cls.io_def(
            'input', onnx_proto.TensorProto.FLOAT, [None, None])
        return [source]

    @classmethod
    def get_outputs(cls):
        """Declare one FLOAT output, ``output``, of shape [None, None]."""
        result = cls.io_def(
            'output', onnx_proto.TensorProto.FLOAT, [None, None])
        return [result]
298
299
class ImageReader(CustomOp):
    """Input/output declarations for the ImageReader custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one STRING input, ``image_paths``, of shape [None]."""
        image_paths = cls.io_def(
            'image_paths', onnx_proto.TensorProto.STRING, [None])
        return [image_paths]

    @classmethod
    def get_outputs(cls):
        """Declare one UINT8 output, ``nchw_bytes``, with four unknown dims."""
        nchw_bytes = cls.io_def(
            'nchw_bytes', onnx_proto.TensorProto.UINT8, [None] * 4)
        return [nchw_bytes]
313
314
class GaussianBlur(CustomOp):
    """Input/output declarations for the GaussianBlur custom op."""

    @classmethod
    def get_inputs(cls):
        """
        Declare FLOAT ``nhwc`` (four unknown dims), INT64 ``kernel_size``
        ([None]) and DOUBLE ``sigma_xy`` ([None]).
        """
        types = onnx_proto.TensorProto
        image = cls.io_def('nhwc', types.FLOAT, [None] * 4)
        kernel_size = cls.io_def('kernel_size', types.INT64, [None])
        sigma_xy = cls.io_def('sigma_xy', types.DOUBLE, [None])
        return [image, kernel_size, sigma_xy]

    @classmethod
    def get_outputs(cls):
        """Declare one FLOAT output, ``gb_nhwc``, with four unknown dims."""
        blurred = cls.io_def(
            'gb_nhwc', onnx_proto.TensorProto.FLOAT, [None] * 4)
        return [blurred]
330
331
class ImageDecoder(CustomOp):
    """Input/output declarations for the ImageDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        """Declare one UINT8 input, ``raw_input_image``, with an empty shape list."""
        raw_image = cls.io_def(
            'raw_input_image', onnx_proto.TensorProto.UINT8, [])
        return [raw_image]

    @classmethod
    def get_outputs(cls):
        """Declare one UINT8 output, ``decoded_image``, of shape [None, None, 3]."""
        decoded = cls.io_def(
            'decoded_image', onnx_proto.TensorProto.UINT8, [None, None, 3])
        return [decoded]
345
346
class SingleOpGraph:
    """Builds an ONNX graph that contains exactly one custom-op node."""

    @classmethod
    def get_next_id(cls):
        """
        Advance and return the class-level counter; the first call
        returns 1.
        :return: the next integer id
        """
        cls._id_counter = getattr(cls, '_id_counter', 0) + 1
        return cls._id_counter

    @classmethod
    def build_my_graph(cls, op_class, *args, **kwargs):
        """
        Create a graph holding a single node of the given custom op.
        :param op_class: a CustomOp subclass, or the name of one defined
                         in this module
        :param args: accepted but not used
        :param kwargs: node attributes, passed through the op class's
                       ``serialize_attr`` before being set on the node
        :return: the ONNX graph with the op's declared inputs and outputs
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        attrs = op_class.serialize_attr(kwargs)

        input_names = [tensor.name for tensor in inputs]
        output_names = [tensor.name for tensor in outputs]
        # The node name takes one id and the graph name takes the next.
        node_name = "{}_{}".format(op_type, cls.get_next_id())
        cuop = onnx.helper.make_node(op_type, input_names, output_names,
                                     node_name,
                                     **attrs,
                                     domain=default_opset_domain())

        graph_name = "og_{}_{}".format(op_type, cls.get_next_id())
        return onnx.helper.make_graph([cuop], graph_name, inputs, outputs)

    @staticmethod
    def get_op_class(op_type):
        """
        Look up a class by name among this module's globals.
        :param op_type: the class name
        :return: the object bound to that name
        """
        module_names = globals()
        return module_names[op_type]
378
379
380# TODO: have a C++ impl.
381def _argsort_op(x, dim):
382 d = numpy.argsort(x, dim)
383 return d[:, ::-1]
384
385
# Register _argsort_op as the Python custom op 'ArgSort', declared with a
# float and an int64 input and a single int64 output.
Opdef.create(
    _argsort_op,
    op_type='ArgSort',
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
    outputs=[PyCustomOpDef.dt_int64],
)