microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code Commits Issues Pull requests Actions Insights Security
df7a9f337c69567dc9c58400d3ec8004bcafb794

Branches

Tags

  • No tags available.
0 Branches 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_cuops.py

520 lines · mode code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6"""
7_cuops.py: Custom operators signatures for Python usage.
8"""
9
10import onnx
11import numpy
12from onnx import onnx_pb as onnx_proto
13from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
14
15
class CustomOp:
    """Base class for the Python-side signature of a custom operator.

    Subclasses override the classmethods below to describe the operator's
    ONNX inputs, outputs, optional input defaults and attribute
    serialization.
    """

    @classmethod
    def op_type(cls):
        """Return the operator type name for ``cls``.

        This is the name of the class in ``cls``'s MRO that derives directly
        from ``CustomOp``. A further subclass of, say, ``BertTokenizer``
        therefore still reports ``'BertTokenizer'``.

        :raises TypeError: if no class in ``cls``'s MRO derives directly from
            ``CustomOp`` (e.g. ``cls`` is ``CustomOp`` itself). This replaces
            an obscure ``AttributeError`` raised after walking past ``object``.
        """
        for klass in cls.__mro__:
            if CustomOp in klass.__bases__:
                return klass.__name__
        raise TypeError(
            "{} is not a subclass of CustomOp".format(cls.__name__))

    @classmethod
    def get_inputs(cls):
        """Return the list of input value infos; ``None`` when undeclared."""
        return None

    @classmethod
    def get_outputs(cls):
        """Return the list of output value infos; ``None`` when undeclared."""
        return None

    @classmethod
    def input_default_values(cls):
        """Return a dict of default values for optional inputs, or ``None``."""
        return None

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Serialize operator attributes.

        The base implementation only supports attributes that are already
        basic Python values (such as a list or dict) and returns them
        unchanged. Any other type must be serialized by the user or by a
        subclass override.
        :param attrs: the dict of attributes
        :return: the dict of serialized data
        """
        return attrs

    # Shorthand used by subclasses (as ``cls.io_def``) to build tensor value
    # infos. Wrapped in staticmethod so it is never bound as a method, whether
    # it is accessed through the class or through an instance.
    io_def = staticmethod(onnx.helper.make_tensor_value_info)
48
49
class GPT2Tokenizer(CustomOp):
    """I/O signature of the GPT2Tokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        # A single string tensor with one dimension of unknown size.
        text = cls.io_def("input_text", onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        # Both outputs are int64 tensors with two dimensions of unknown size.
        int64 = onnx.TensorProto.INT64
        return [
            cls.io_def(name, int64, [None, None])
            for name in ("input_ids", "attention_mask")
        ]
64
65
class CLIPTokenizer(CustomOp):
    """I/O signature of the CLIPTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("input_text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        int64 = onnx.TensorProto.INT64
        ids = cls.io_def("input_ids", int64, [None, None])
        mask = cls.io_def("attention_mask", int64, [None, None])
        # Last dimension fixed at 2 (presumably a start/end pair per token).
        offsets = cls.io_def("offset_mapping", int64, [None, None, 2])
        return [ids, mask, offsets]
81
82
class RobertaTokenizer(CustomOp):
    """I/O signature of the RobertaTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        string_type = onnx_proto.TensorProto.STRING
        return [cls.io_def("input_text", string_type, [None])]

    @classmethod
    def get_outputs(cls):
        int64 = onnx.TensorProto.INT64
        layout = (
            ("input_ids", [None, None]),
            ("attention_mask", [None, None]),
            ("offset_mapping", [None, None, 2]),
        )
        return [cls.io_def(name, int64, shape) for name, shape in layout]
98
99
class BpeDecoder(CustomOp):
    """I/O signature of the BpeDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        # shape=None: no shape information is declared for the ids.
        ids = cls.io_def("ids", onnx.TensorProto.INT64, None)
        return [ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def("str", onnx_proto.TensorProto.STRING, None)
        return [text]
110
111
class VectorToString(CustomOp):
    """I/O signature and attribute serialization of the VectorToString op."""

    @classmethod
    def get_inputs(cls):
        # Declared with an empty shape list.
        return [cls.io_def("token_ids", onnx.TensorProto.INT64, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the ``map`` attribute into its text form.

        A dict ``map`` becomes one line per entry, of the form
        ``key<TAB>id id ...``, with lines joined by newlines. Every other
        attribute is passed through unchanged, including a ``map`` that is
        already a string (assumed to be pre-serialized).

        :param attrs: dict of user-supplied attributes.
        :return: dict of serialized attributes.
        """
        attr_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'map' and isinstance(v_, dict):
                attr_data[k_] = '\n'.join(
                    k + "\t" + " ".join(str(i) for i in v)
                    for k, v in v_.items())
            else:
                attr_data[k_] = v_
        return attr_data
135
136
class StringMapping(CustomOp):
    """I/O signature and attribute serialization of the StringMapping op."""

    @classmethod
    def get_inputs(cls):
        # Declared with an empty shape list.
        return [cls.io_def("input", onnx.TensorProto.STRING, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the ``map`` attribute into its text form.

        A dict ``map`` (string keys and string values) becomes one
        ``key<TAB>value`` line per entry, joined by newlines. Every other
        attribute is passed through unchanged, including a ``map`` that is
        already a string (assumed to be pre-serialized).

        :param attrs: dict of user-supplied attributes.
        :return: dict of serialized attributes.
        """
        attr_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'map' and isinstance(v_, dict):
                attr_data[k_] = '\n'.join(k + "\t" + v for k, v in v_.items())
            else:
                attr_data[k_] = v_
        return attr_data
158
159
class MaskedFill(CustomOp):
    """I/O signature of the MaskedFill custom op."""

    @classmethod
    def get_inputs(cls):
        # A string tensor and a boolean tensor, each with one dimension
        # of unknown size.
        value = cls.io_def("value", onnx.TensorProto.STRING, [None])
        mask = cls.io_def("mask", onnx.TensorProto.BOOL, [None])
        return [value, mask]

    @classmethod
    def get_outputs(cls):
        filled = cls.io_def("output", onnx_proto.TensorProto.STRING, [None])
        return [filled]
172
173
class StringToVector(CustomOp):
    """I/O signature and attribute serialization of the StringToVector op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        # Declared with an empty shape list.
        return [cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the ``map`` and ``unk`` attributes into text.

        * A dict ``map`` becomes one ``key<TAB>id id ...`` line per entry,
          with lines joined by newlines. A string ``map`` is kept as is.
        * A list ``unk`` becomes its items joined by single spaces.
        * Every other attribute is passed through unchanged.

        :param attrs: dict of user-supplied attributes.
        :return: dict of serialized attributes.
        """
        attr_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'map' and isinstance(v_, dict):
                attr_data[k_] = '\n'.join(
                    k + "\t" + " ".join(str(i) for i in v)
                    for k, v in v_.items())
            elif k_ == 'unk' and isinstance(v_, list):
                attr_data[k_] = ' '.join(str(i) for i in v_)
            else:
                attr_data[k_] = v_
        return attr_data
199
200
class BlingFireSentenceBreaker(CustomOp):
    """I/O signature and attribute serialization of BlingFireSentenceBreaker."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        # Declared with an empty shape list.
        return [cls.io_def("sentence", onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Replace the ``model`` path with that file's raw bytes.

        All other attributes are copied unchanged.
        """
        def _read_bytes(path):
            with open(path, "rb") as handle:
                return handle.read()

        return {
            name: (_read_bytes(value) if name == 'model' else value)
            for name, value in attrs.items()
        }
221
222
class SegmentExtraction(CustomOp):
    """I/O signature of the SegmentExtraction custom op."""

    @classmethod
    def get_inputs(cls):
        # A 2-D int64 tensor of unknown sizes.
        source = cls.io_def("input", onnx.TensorProto.INT64, [None, None])
        return [source]

    @classmethod
    def get_outputs(cls):
        int64 = onnx_proto.TensorProto.INT64
        position = cls.io_def("position", int64, [None, 2])
        value = cls.io_def("value", int64, [None])
        return [position, value]
235
236
class BertTokenizer(CustomOp):
    """I/O signature and attribute serialization of the BertTokenizer op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [
            cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('offset_mapping', onnx.TensorProto.INT64, [None, 2])
        ]

    @classmethod
    def serialize_attr(cls, attrs):
        """Convert user attributes into node attributes.

        * ``vocab``: the value is stored as-is under the ``vocab_file``
          attribute (presumably the vocabulary text itself).
        * ``vocab_file``: a path to a UTF-8 file; the file's text is stored
          under ``vocab_file``.
        * Every other attribute is passed through unchanged.

        :param attrs: dict of user-supplied attributes.
        :return: dict of serialized attributes.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab':
                attrs_data['vocab_file'] = v_
            elif k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # Use read(), not '\n'.join(readlines()). readlines()
                    # keeps each line's trailing newline, so joining with
                    # '\n' would put an empty line between every entry.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
265
266
class StringECMARegexReplace(CustomOp):
    """I/O signature of the StringECMARegexReplace custom op."""

    @classmethod
    def get_inputs(cls):
        # Three string tensors, each with one dimension of unknown size.
        string_type = onnx.TensorProto.STRING
        names = ("input", "pattern", "rewrite")
        return [cls.io_def(name, string_type, [None]) for name in names]

    @classmethod
    def get_outputs(cls):
        result = cls.io_def("output", onnx_proto.TensorProto.STRING, [None])
        return [result]
280
281
class BertTokenizerDecoder(CustomOp):
    """I/O signature and attribute serialization of BertTokenizerDecoder."""

    @classmethod
    def get_inputs(cls):
        return [
            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
            cls.io_def("position", onnx.TensorProto.INT64, [None, None])
        ]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Convert user attributes into node attributes.

        * ``vocab_file``: a path to a UTF-8 file; the file's text is stored
          under ``vocab_file``.
        * Every other attribute is passed through unchanged.

        :param attrs: dict of user-supplied attributes.
        :return: dict of serialized attributes.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # Use read(), not '\n'.join(readlines()). readlines()
                    # keeps each line's trailing newline, so joining with
                    # '\n' would put an empty line between every entry.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
306
307
class SentencepieceTokenizer(CustomOp):
    """I/O signature of the SentencepieceTokenizer custom op.

    The five inputs after ``inputs`` have default values listed in
    ``input_default_values``.
    """

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        spec = (
            ('inputs', tp.STRING),
            ('nbest_size', tp.INT64),
            ('alpha', tp.FLOAT),
            ('add_bos', tp.BOOL),
            ('add_eos', tp.BOOL),
            ('reverse', tp.BOOL),
        )
        return [cls.io_def(name, dtype, [None]) for name, dtype in spec]

    # Keys follow the order of the defaulted inputs in get_inputs(). Since
    # Python 3.7, dicts preserve insertion order, so that order is kept.
    @classmethod
    def input_default_values(cls):
        return {
            'nbest_size': [0],
            'alpha': [0],
            'add_bos': [False],
            'add_eos': [False],
            'reverse': [False],
        }

    @classmethod
    def get_outputs(cls):
        tp = onnx_proto.TensorProto
        tokens = cls.io_def('tokens', tp.INT32, [None])
        indices = cls.io_def('indices', tp.INT64, [None])
        return [tokens, indices]
338
339
class SentencepieceDecoder(CustomOp):
    """I/O signature of the SentencepieceDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, [None])
        return [ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def("str", onnx_proto.TensorProto.STRING, [None])
        return [text]
351
352
class TrieTokenizer(CustomOp):
    """I/O signature of the TrieTokenizer custom op.

    The input and the output both use the dimension name ``'N'`` for their
    first axis.
    """

    @classmethod
    def get_inputs(cls):
        text = cls.io_def("str", onnx_proto.TensorProto.STRING, ["N"])
        return [text]

    @classmethod
    def get_outputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, ["N", None])
        return [ids]
361
362
class TrieDetokenizer(CustomOp):
    """I/O signature of the TrieDetokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        # First axis uses the dimension name 'N'; second axis is unknown.
        ids = cls.io_def("ids", onnx.TensorProto.INT64, ["N", None])
        return [ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def("str", onnx_proto.TensorProto.STRING, [None])
        return [text]
371
372
class Inverse(CustomOp):
    """I/O signature of the Inverse custom op (2-D float in, 2-D float out)."""

    @classmethod
    def get_inputs(cls):
        matrix = cls.io_def("input", onnx_proto.TensorProto.FLOAT, [None, None])
        return [matrix]

    @classmethod
    def get_outputs(cls):
        result = cls.io_def("output", onnx_proto.TensorProto.FLOAT, [None, None])
        return [result]
386
387
class ImageReader(CustomOp):
    """I/O signature of the ImageReader custom op."""

    @classmethod
    def get_inputs(cls):
        paths = cls.io_def("image_paths", onnx_proto.TensorProto.STRING, [None])
        return [paths]

    @classmethod
    def get_outputs(cls):
        # A 4-D uint8 tensor with all dimensions of unknown size.
        four_unknown = [None] * 4
        pixels = cls.io_def("nchw_bytes", onnx_proto.TensorProto.UINT8,
                            four_unknown)
        return [pixels]
401
402
class GaussianBlur(CustomOp):
    """I/O signature of the GaussianBlur custom op."""

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        image = cls.io_def("nhwc", tp.FLOAT, [None] * 4)
        kernel = cls.io_def("kernel_size", tp.INT64, [None])
        sigma = cls.io_def("sigma_xy", tp.DOUBLE, [None])
        return [image, kernel, sigma]

    @classmethod
    def get_outputs(cls):
        blurred = cls.io_def("gb_nhwc", onnx_proto.TensorProto.FLOAT, [None] * 4)
        return [blurred]
418
419
class ImageDecoder(CustomOp):
    """I/O signature of the ImageDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        # Declared with an empty shape list.
        encoded = cls.io_def("raw_input_image", onnx_proto.TensorProto.UINT8, [])
        return [encoded]

    @classmethod
    def get_outputs(cls):
        # Two unknown dimensions followed by a fixed last dimension of 3.
        decoded = cls.io_def("decoded_image", onnx_proto.TensorProto.UINT8,
                             [None, None, 3])
        return [decoded]
433
434
class AudioDecoder(CustomOp):
    """I/O signature of the AudioDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        # Leading dimension fixed at 1, second dimension unknown.
        stream = cls.io_def("audio_stream", onnx_proto.TensorProto.UINT8, [1, None])
        return [stream]

    @classmethod
    def get_outputs(cls):
        pcm = cls.io_def("floatPCM", onnx_proto.TensorProto.FLOAT, [1, None])
        return [pcm]
447
448
class StftNorm(CustomOp):
    """I/O signature of the StftNorm custom op."""

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        # (name, element type, shape); an empty shape list is used for the
        # integer parameters.
        signature = (
            ('pcm_wave', tp.FLOAT, [1, None]),
            ('n_fft', tp.INT64, []),
            ('hop_length', tp.INT64, []),
            ('window', tp.FLOAT, [None]),
            ('frame_size', tp.INT64, []),
        )
        return [cls.io_def(name, dtype, shape) for name, dtype, shape in signature]

    @classmethod
    def get_outputs(cls):
        norm = cls.io_def('stft_norm', onnx_proto.TensorProto.FLOAT, [1, None, None])
        return [norm]
465
466
467# TODO: have a C++ impl.
468def _argsort_op(x, dim):
469 d = numpy.argsort(x, dim)
470 return d[:, ::-1]
471
472
# Expose _argsort_op as the Python-implemented custom op "ArgSort".
# Declared input types: a float tensor and an int64 tensor.
# Declared output type: an int64 tensor.
Opdef.create(
    _argsort_op,
    op_type='ArgSort',
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
    outputs=[PyCustomOpDef.dt_int64],
)
477
478
class CustomOpConverter:
    """Marker base class for attribute converters.

    ``SingleOpGraph.build_graph`` recognizes an instance of this class passed
    as its first positional argument. It calls that instance with the user's
    keyword arguments to obtain the operator attributes, so subclasses
    presumably implement ``__call__``.
    """
481
482
class SingleOpGraph:
    """Helpers that wrap one custom operator in a standalone ONNX graph."""

    @classmethod
    def get_next_id(cls):
        """Return the next value (1, 2, 3, ...) of a counter stored on ``cls``.

        Used to give generated nodes and graphs distinct names.
        """
        if not hasattr(cls, '_id_counter'):
            cls._id_counter = 0
        cls._id_counter += 1
        return cls._id_counter

    @classmethod
    def build_graph(cls, op_class, *args, **kwargs):
        """Build a graph that contains a single node of ``op_class``.

        :param op_class: a ``CustomOp`` subclass, or its name as a string.
        :param args: optionally a ``CustomOpConverter`` instance as the first
            positional argument. It is used as the converter when no ``cvt``
            keyword is given. Other positional arguments are ignored.
        :param kwargs: operator attributes. The ``cvt`` key, if present, is
            removed and used as the converter.
        :return: the graph built by ``onnx.helper.make_graph``, whose inputs
            and outputs are the op's declared signature.
        :raises KeyError: if ``op_class`` is a string naming no custom op.
        :raises TypeError: if the op class declares no inputs or outputs.
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        cvt = kwargs.pop('cvt', None)
        if cvt is None and args and isinstance(args[0], CustomOpConverter):
            cvt = args[0]

        # A converter maps the user's keyword arguments to op attributes.
        new_kwargs = kwargs if cvt is None else cvt(**kwargs)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        if inputs is None or outputs is None:
            # The CustomOp defaults return None; fail with a clear message
            # instead of "'NoneType' object is not iterable".
            raise TypeError(
                "{} does not declare its inputs/outputs".format(op_type))
        attrs = op_class.serialize_attr(new_kwargs)
        node_name = "{}_{}".format(op_type, cls.get_next_id())
        cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
                                     [o_.name for o_ in outputs],
                                     node_name,
                                     **attrs,
                                     domain=default_opset_domain())
        graph_name = "og_{}_{}".format(op_type, cls.get_next_id())
        return onnx.helper.make_graph([cuop], graph_name, inputs, outputs)

    @staticmethod
    def get_op_class(op_type):
        """Look up a ``CustomOp`` subclass defined in this module by name.

        Only ``CustomOp`` subclasses are returned. An arbitrary string
        therefore cannot select an unrelated module global (e.g. an imported
        module).

        :raises KeyError: if no ``CustomOp`` subclass with that name exists.
        """
        candidate = globals().get(op_type)
        if isinstance(candidate, type) and issubclass(candidate, CustomOp):
            return candidate
        raise KeyError("unknown custom op type: {!r}".format(op_type))
521