microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions. Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
natke-patch-1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_cuops.py

471lines · modecode

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
5
6import onnx
7import numpy
8from onnx import onnx_pb as onnx_proto
9from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
10
11
class CustomOp:
    """Base class describing the ONNX signature of a custom operator.

    Subclasses override :meth:`get_inputs` / :meth:`get_outputs` to return
    lists of value infos built with :attr:`io_def`, and may override
    :meth:`serialize_attr` to turn Python attribute values into something
    that can be stored on an ONNX node.
    """

    @classmethod
    def op_type(cls):
        """Return the op type name used for this class's ONNX node.

        The name is that of the class in ``cls``'s MRO which lists
        ``CustomOp`` directly among its bases, so a subclass of a concrete op
        still reports the concrete op's name. When no such class exists
        (e.g. called on ``CustomOp`` itself) ``cls.__name__`` is returned,
        rather than walking past ``object`` and failing on ``None.__base__``.
        """
        for klass in cls.__mro__:
            if CustomOp in klass.__bases__:
                return klass.__name__
        return cls.__name__

    @classmethod
    def get_inputs(cls):
        """Return the list of input value infos; ``None`` in the base class."""
        return None

    @classmethod
    def get_outputs(cls):
        """Return the list of output value infos; ``None`` in the base class."""
        return None

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Serialize node attributes. The base implementation only supports
        basic Python values and returns them unchanged; any other types must
        be serialized by the subclass (or by the user).
        :param attrs: the dict of attributes
        :return: the dict of serialized data
        """
        return attrs

    # Shorthand for onnx.helper.make_tensor_value_info(name, elem_type, shape).
    # staticmethod keeps it unbound, so access through an instance does not
    # pass the instance as the tensor name; access through the class
    # (cls.io_def) behaves exactly as a plain function.
    io_def = staticmethod(onnx.helper.make_tensor_value_info)
40
41
class GPT2Tokenizer(CustomOp):
    """Signature of the GPT2Tokenizer custom op: one string input, two int64 outputs."""

    @classmethod
    def get_inputs(cls):
        """Return a single rank-1 string value info named ``input_text``."""
        text = cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        """Return rank-2 int64 value infos ``input_ids`` and ``attention_mask``."""
        int64 = onnx_proto.TensorProto.INT64
        return [
            cls.io_def(name, int64, [None, None])
            for name in ("input_ids", 'attention_mask')
        ]
56
57
class CLIPTokenizer(CustomOp):
    """Signature of the CLIPTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Return a single rank-1 string value info named ``input_text``."""
        return [cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        """Return int64 value infos: ``input_ids`` and ``attention_mask``
        (rank 2) and ``offset_mapping`` (rank 3, last dimension fixed at 2)."""
        int64 = onnx_proto.TensorProto.INT64
        ids = cls.io_def("input_ids", int64, [None, None])
        mask = cls.io_def('attention_mask', int64, [None, None])
        offsets = cls.io_def('offset_mapping', int64, [None, None, 2])
        return [ids, mask, offsets]
73
74
class RobertaTokenizer(CustomOp):
    """Signature of the RobertaTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Return a single rank-1 string value info named ``input_text``."""
        text_info = cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
        return [text_info]

    @classmethod
    def get_outputs(cls):
        """Return int64 value infos for ids, attention mask and offset mapping."""
        specs = (
            ("input_ids", [None, None]),
            ('attention_mask', [None, None]),
            ('offset_mapping', [None, None, 2]),
        )
        return [cls.io_def(name, onnx_proto.TensorProto.INT64, shape)
                for name, shape in specs]
90
91
class BpeDecoder(CustomOp):
    """Signature of the BpeDecoder custom op: int64 ``ids`` in, string ``str`` out."""

    @classmethod
    def get_inputs(cls):
        # NOTE(review): shape [] declares a rank-0 value info; confirm that is
        # intended for the ids input.
        ids_info = cls.io_def("ids", onnx_proto.TensorProto.INT64, [])
        return [ids_info]

    @classmethod
    def get_outputs(cls):
        text_info = cls.io_def('str', onnx_proto.TensorProto.STRING, [])
        return [text_info]
102
103
class VectorToString(CustomOp):
    """Signature and attribute encoding for the VectorToString custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("token_ids", onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Encode a dict-valued ``map`` attribute as text; pass others through.

        A dict ``{key: iterable}`` becomes one line per entry: the key, a tab,
        then the iterable's items converted with ``str`` and space-separated.
        Lines are joined with newlines. Every other value, including a ``map``
        that is already a string, is copied unchanged.
        """
        def encode_map(mapping):
            return '\n'.join(
                key + "\t" + " ".join(str(item) for item in values)
                for key, values in mapping.items()
            )

        encoded = {}
        for name, value in attrs.items():
            is_map_dict = name == 'map' and isinstance(value, dict)
            encoded[name] = encode_map(value) if is_map_dict else value
        return encoded
127
128
class StringMapping(CustomOp):
    """Signature and attribute encoding for the StringMapping custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("input", onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Encode a dict ``map`` of str -> str as ``key<TAB>value`` lines.

        Lines are joined with newlines; any other attribute (or a ``map``
        already given as a string) is passed through unchanged.
        """
        serialized = {}
        for name, value in attrs.items():
            if name != 'map' or not isinstance(value, dict):
                serialized[name] = value
                continue
            lines = [src + "\t" + dst for src, dst in value.items()]
            serialized[name] = '\n'.join(lines)
        return serialized
150
151
class MaskedFill(CustomOp):
    """Signature of the MaskedFill custom op (string ``value`` + bool ``mask``)."""

    @classmethod
    def get_inputs(cls):
        """Return rank-1 value infos: string ``value`` and bool ``mask``."""
        tensor_proto = onnx_proto.TensorProto
        value_info = cls.io_def("value", tensor_proto.STRING, [None])
        mask_info = cls.io_def("mask", tensor_proto.BOOL, [None])
        return [value_info, mask_info]

    @classmethod
    def get_outputs(cls):
        """Return a rank-1 string value info named ``output``."""
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [None])]
164
165
class StringToVector(CustomOp):
    """Signature and attribute encoding for the StringToVector custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Encode ``map`` (dict) and ``unk`` (list) attributes as text.

        ``map`` dict entries become ``key<TAB>v1 v2 ...`` lines joined by
        newlines; an ``unk`` list becomes its items space-separated. All other
        values are passed through unchanged.
        """
        def encode(name, value):
            if name == 'map' and isinstance(value, dict):
                return '\n'.join(
                    key + "\t" + " ".join(str(item) for item in values)
                    for key, values in value.items())
            if name == 'unk' and isinstance(value, list):
                return ' '.join(str(item) for item in value)
            return value

        return {name: encode(name, value) for name, value in attrs.items()}
191
192
class BlingFireSentenceBreaker(CustomOp):
    """Signature and attribute encoding for the BlingFireSentenceBreaker op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Replace the ``model`` attribute (a file path) with the file's bytes.

        All other attributes are copied unchanged.
        """
        serialized = dict(attrs)
        if 'model' in serialized:
            with open(serialized['model'], "rb") as model_file:
                serialized['model'] = model_file.read()
        return serialized
213
214
class SegmentExtraction(CustomOp):
    """Signature of the SegmentExtraction custom op."""

    @classmethod
    def get_inputs(cls):
        """Return a rank-2 int64 value info named ``input``."""
        return [cls.io_def("input", onnx_proto.TensorProto.INT64, [None, None])]

    @classmethod
    def get_outputs(cls):
        """Return int64 value infos ``position`` (N x 2) and ``value`` (N)."""
        int64 = onnx_proto.TensorProto.INT64
        position = cls.io_def('position', int64, [None, 2])
        value = cls.io_def('value', int64, [None])
        return [position, value]
227
228
class BertTokenizer(CustomOp):
    """Signature and attribute encoding for the BertTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Return a rank-1 string value info named ``text``."""
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        """Return rank-1 int64 value infos for ids, token types and mask."""
        return [
            cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None])
        ]

    @classmethod
    def serialize_attr(cls, attrs):
        """Inline the vocabulary file named by the ``vocab_file`` attribute.

        ``vocab_file`` is treated as a path; its value is replaced by the
        file's UTF-8 text (text mode, so line endings are normalized to
        newlines). All other attributes are passed through unchanged.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # readlines() keeps each line's trailing '\n', so joining
                    # its result with '\n' would put a blank line after every
                    # vocab entry; read() keeps exactly one separator per line.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
254
255
class StringECMARegexReplace(CustomOp):
    """Signature of the StringECMARegexReplace custom op."""

    @classmethod
    def get_inputs(cls):
        """Return rank-1 string value infos ``input``, ``pattern``, ``rewrite``."""
        return [
            cls.io_def(name, onnx_proto.TensorProto.STRING, [None])
            for name in ("input", "pattern", "rewrite")
        ]

    @classmethod
    def get_outputs(cls):
        """Return a rank-1 string value info named ``output``."""
        output_info = cls.io_def('output', onnx_proto.TensorProto.STRING, [None])
        return [output_info]
269
270
class BertTokenizerDecoder(CustomOp):
    """Signature and attribute encoding for the BertTokenizerDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        """Return int64 value infos ``ids`` (rank 1) and ``position`` (rank 2)."""
        return [
            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
            cls.io_def("position", onnx.TensorProto.INT64, [None, None])
        ]

    @classmethod
    def get_outputs(cls):
        """Return a rank-1 string value info named ``str``."""
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Inline the vocabulary file named by the ``vocab_file`` attribute.

        ``vocab_file`` is treated as a path; its value is replaced by the
        file's UTF-8 text (text mode, so line endings are normalized to
        newlines). All other attributes are passed through unchanged.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # readlines() keeps each line's trailing '\n', so joining
                    # its result with '\n' would put a blank line after every
                    # vocab entry; read() keeps exactly one separator per line.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
295
296
class SentencepieceTokenizer(CustomOp):
    """Signature of the SentencepieceTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        """Return six rank-1 value infos: the strings plus tokenizer options."""
        tensor_proto = onnx_proto.TensorProto
        specs = (
            ('inputs', tensor_proto.STRING),
            ('nbest_size', tensor_proto.INT64),
            ('alpha', tensor_proto.FLOAT),
            ('add_bos', tensor_proto.BOOL),
            ('add_eos', tensor_proto.BOOL),
            ('reverse', tensor_proto.BOOL),
        )
        return [cls.io_def(name, elem_type, [None]) for name, elem_type in specs]

    @classmethod
    def get_outputs(cls):
        """Return rank-1 value infos: int32 ``tokens`` and int64 ``indices``."""
        tokens = cls.io_def('tokens', onnx_proto.TensorProto.INT32, [None])
        indices = cls.io_def('indices', onnx_proto.TensorProto.INT64, [None])
        return [tokens, indices]
316
317
class SentencepieceDecoder(CustomOp):
    """Signature of the SentencepieceDecoder custom op: int64 ids in, strings out."""

    @classmethod
    def get_inputs(cls):
        ids_info = cls.io_def("ids", onnx_proto.TensorProto.INT64, [None])
        return [ids_info]

    @classmethod
    def get_outputs(cls):
        text_info = cls.io_def('str', onnx_proto.TensorProto.STRING, [None])
        return [text_info]
329
330
class Inverse(CustomOp):
    """Signature of the Inverse custom op: rank-2 float in, rank-2 float out."""

    @classmethod
    def get_inputs(cls):
        matrix_shape = [None, None]
        return [cls.io_def('input', onnx_proto.TensorProto.FLOAT, matrix_shape)]

    @classmethod
    def get_outputs(cls):
        matrix_shape = [None, None]
        return [cls.io_def('output', onnx_proto.TensorProto.FLOAT, matrix_shape)]
344
345
class ImageReader(CustomOp):
    """Signature of the ImageReader custom op: string paths in, rank-4 uint8 out."""

    @classmethod
    def get_inputs(cls):
        paths = cls.io_def('image_paths', onnx_proto.TensorProto.STRING, [None])
        return [paths]

    @classmethod
    def get_outputs(cls):
        pixels = cls.io_def('nchw_bytes', onnx_proto.TensorProto.UINT8,
                            [None] * 4)
        return [pixels]
359
360
class GaussianBlur(CustomOp):
    """Signature of the GaussianBlur custom op."""

    @classmethod
    def get_inputs(cls):
        """Return value infos: rank-4 float ``nhwc``, rank-1 int64
        ``kernel_size`` and rank-1 double ``sigma_xy``."""
        tensor_proto = onnx_proto.TensorProto
        image = cls.io_def('nhwc', tensor_proto.FLOAT, [None] * 4)
        kernel = cls.io_def('kernel_size', tensor_proto.INT64, [None])
        sigma = cls.io_def('sigma_xy', tensor_proto.DOUBLE, [None])
        return [image, kernel, sigma]

    @classmethod
    def get_outputs(cls):
        """Return a rank-4 float value info named ``gb_nhwc``."""
        return [cls.io_def('gb_nhwc', onnx_proto.TensorProto.FLOAT, [None] * 4)]
376
377
class ImageDecoder(CustomOp):
    """Signature of the ImageDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        # NOTE(review): shape [] declares a rank-0 value info for the encoded
        # bytes; confirm that is intended.
        encoded = cls.io_def('raw_input_image', onnx_proto.TensorProto.UINT8, [])
        return [encoded]

    @classmethod
    def get_outputs(cls):
        """Return a uint8 value info of shape (?, ?, 3) named ``decoded_image``."""
        decoded = cls.io_def('decoded_image', onnx_proto.TensorProto.UINT8,
                             [None, None, 3])
        return [decoded]
391
392
class AudioDecoder(CustomOp):
    """Signature of the AudioDecoder custom op: (1, ?) uint8 in, (1, ?) float out."""

    @classmethod
    def get_inputs(cls):
        stream = cls.io_def('audio_stream', onnx_proto.TensorProto.UINT8, [1, None])
        return [stream]

    @classmethod
    def get_outputs(cls):
        pcm = cls.io_def('floatPCM', onnx_proto.TensorProto.FLOAT, [1, None])
        return [pcm]
405
406
class StftNorm(CustomOp):
    """Signature of the StftNorm custom op."""

    @classmethod
    def get_inputs(cls):
        """Return value infos: float ``pcm_wave`` (1, ?), rank-0 int64
        ``n_fft`` / ``hop_length`` / ``frame_size``, rank-1 float ``window``."""
        float32 = onnx_proto.TensorProto.FLOAT
        int64 = onnx_proto.TensorProto.INT64
        specs = (
            ('pcm_wave', float32, [1, None]),
            ('n_fft', int64, []),
            ('hop_length', int64, []),
            ('window', float32, [None]),
            ('frame_size', int64, []),
        )
        return [cls.io_def(name, elem_type, shape) for name, elem_type, shape in specs]

    @classmethod
    def get_outputs(cls):
        """Return a float value info of shape (1, ?, ?) named ``stft_norm``."""
        return [cls.io_def('stft_norm', onnx_proto.TensorProto.FLOAT, [1, None, None])]
423
424
class SingleOpGraph:
    """Helpers that wrap one custom op into a standalone ONNX graph."""

    @classmethod
    def get_next_id(cls):
        """Increment and return the class's counter used to make names unique.

        The counter is stored on ``cls``; a subclass starts from the value it
        inherits and then keeps its own copy.
        """
        cls._id_counter = getattr(cls, '_id_counter', 0) + 1
        return cls._id_counter

    @classmethod
    def build_my_graph(cls, op_class, *args, **kwargs):
        """Build a graph containing a single node of ``op_class``.

        :param op_class: an op class, or its name as a string (resolved with
            :meth:`get_op_class`)
        :param args: accepted but not used
        :param kwargs: node attributes, passed through
            ``op_class.serialize_attr`` before being set on the node
        :return: graph whose inputs/outputs are the op's declared value infos
        """
        klass = cls.get_op_class(op_class) if isinstance(op_class, str) else op_class

        name = klass.op_type()
        graph_inputs = klass.get_inputs()
        graph_outputs = klass.get_outputs()
        node_attrs = klass.serialize_attr(kwargs)

        input_names = [info.name for info in graph_inputs]
        output_names = [info.name for info in graph_outputs]
        # The node and the graph each draw their own id, node first.
        node = onnx.helper.make_node(
            name,
            input_names,
            output_names,
            f"{name}_{cls.get_next_id()}",
            **node_attrs,
            domain=default_opset_domain(),
        )
        graph_name = f"og_{name}_{cls.get_next_id()}"
        return onnx.helper.make_graph([node], graph_name, graph_inputs, graph_outputs)

    @staticmethod
    def get_op_class(op_type):
        """Return the module-level object named ``op_type`` (KeyError if absent)."""
        module_namespace = globals()
        return module_namespace[op_type]
456
457
458# TODO: have a C++ impl.
459def _argsort_op(x, dim):
460 d = numpy.argsort(x, dim)
461 return d[:, ::-1]
462
463
# Register the numpy-based _argsort_op as the Python custom op 'ArgSort'.
# It takes a float tensor and an int64 value (used as the sort axis) and
# produces an int64 tensor.
Opdef.create(
    _argsort_op,
    op_type='ArgSort',
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
    outputs=[PyCustomOpDef.dt_int64],
)
468
469
class CustomOpConverter:
    """Empty placeholder class; it declares no attributes or methods here."""
472