microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
922b7cc387ab361cfc9105a256bdaf54d7796601

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_cuops.py

500 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6"""
7_cuops.py: Custom operators signatures for Python usage.
8"""
9
10import onnx
11import numpy
12from onnx import onnx_pb as onnx_proto
13from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
14
15
class CustomOp:
    """Base class for the Python-side signature of a custom operator.

    Subclasses override :meth:`get_inputs` and :meth:`get_outputs` (and, when
    needed, :meth:`input_default_values` and :meth:`serialize_attr`) to
    describe the node that ``SingleOpGraph.build_graph`` creates.
    """

    @classmethod
    def op_type(cls):
        """Return the ONNX op type name for this class.

        The name comes from the class that derives directly from CustomOp, so
        a further subclass of e.g. ``GPT2Tokenizer`` still reports
        ``"GPT2Tokenizer"``. ``CustomOp`` itself reports ``"CustomOp"``.
        """
        # Walk the MRO instead of following ``__base__`` blindly: that loop
        # ran past ``object`` (AttributeError on ``None.__base__``) when
        # called on CustomOp itself or when ``__base__`` picked a mixin class.
        for klass in cls.__mro__:
            if CustomOp in klass.__bases__:
                return klass.__name__
        return cls.__name__

    @classmethod
    def get_inputs(cls):
        """Return the list of input ValueInfoProtos; None on the base class."""
        return None

    @classmethod
    def get_outputs(cls):
        """Return the list of output ValueInfoProtos; None on the base class."""
        return None

    @classmethod
    def input_default_values(cls):
        """Return a dict of default values keyed by input name, or None."""
        return None

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Convert user-supplied attributes into node attributes.

        The base implementation returns ``attrs`` unchanged, so the values must
        already be types that ``onnx.helper.make_node`` accepts as attributes.
        Subclasses override this to convert richer inputs (dicts, file paths).
        :param attrs: the dict of attributes
        :return: the dict of serialized attributes
        """
        return attrs

    # Shorthand for onnx.helper.make_tensor_value_info(name, elem_type, shape),
    # used by subclasses to declare their inputs and outputs.
    io_def = onnx.helper.make_tensor_value_info
48
49
class GPT2Tokenizer(CustomOp):
    """Signature of the GPT2Tokenizer op.

    Input: ``input_text``, a 1-D string tensor.
    Outputs: ``input_ids`` and ``attention_mask``, both 2-D int64 tensors.
    """

    @classmethod
    def get_inputs(cls):
        text = cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        # Both outputs share the same dtype and dynamic 2-D shape.
        return [cls.io_def(name, onnx.TensorProto.INT64, [None, None])
                for name in ("input_ids", "attention_mask")]
64
65
class CLIPTokenizer(CustomOp):
    """Signature of the CLIPTokenizer op.

    Input: ``input_text`` (1-D string tensor).
    Outputs: ``input_ids`` and ``attention_mask`` (2-D int64) plus
    ``offset_mapping`` (3-D int64 whose last dimension is 2).
    """

    @classmethod
    def get_inputs(cls):
        string_type = onnx_proto.TensorProto.STRING
        return [cls.io_def('input_text', string_type, [None])]

    @classmethod
    def get_outputs(cls):
        int64 = onnx.TensorProto.INT64
        outputs = [cls.io_def(name, int64, [None, None])
                   for name in ("input_ids", "attention_mask")]
        outputs.append(cls.io_def('offset_mapping', int64, [None, None, 2]))
        return outputs
81
82
class RobertaTokenizer(CustomOp):
    """Signature of the RobertaTokenizer op.

    Same interface as CLIPTokenizer: a 1-D string input and three int64
    outputs (``input_ids``, ``attention_mask``, ``offset_mapping``).
    """

    @classmethod
    def get_inputs(cls):
        inputs = []
        inputs.append(
            cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None]))
        return inputs

    @classmethod
    def get_outputs(cls):
        # (name, shape) pairs; every output is int64.
        output_specs = (
            ("input_ids", [None, None]),
            ("attention_mask", [None, None]),
            ("offset_mapping", [None, None, 2]),
        )
        return [cls.io_def(name, onnx.TensorProto.INT64, shape)
                for name, shape in output_specs]
98
99
class BpeDecoder(CustomOp):
    """Signature of the BpeDecoder op: int64 ``ids`` in, string ``str`` out."""

    @classmethod
    def get_inputs(cls):
        # A shape of None leaves the tensor's rank unspecified.
        ids = cls.io_def("ids", onnx.TensorProto.INT64, None)
        return [ids]

    @classmethod
    def get_outputs(cls):
        decoded = cls.io_def('str', onnx_proto.TensorProto.STRING, None)
        return [decoded]
110
111
class VectorToString(CustomOp):
    """Signature of the VectorToString op: int64 ``token_ids`` to string ``text``."""

    @classmethod
    def get_inputs(cls):
        token_ids = cls.io_def("token_ids", onnx.TensorProto.INT64, [])
        return [token_ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def('text', onnx_proto.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def serialize_attr(cls, attrs):
        """Convert a dict-valued ``map`` attribute into its text form.

        Each ``{key: values}`` entry becomes the key, a tab, then the values
        converted with ``str`` and separated by single spaces; entries are
        joined with newlines. A string ``map`` and all other attributes are
        passed through unchanged.
        """
        serialized = {}
        for name, value in attrs.items():
            if name == 'map' and isinstance(value, dict):
                value = '\n'.join(
                    key + "\t" + " ".join([str(i) for i in ids])
                    for key, ids in value.items())
            serialized[name] = value
        return serialized
135
136
class StringMapping(CustomOp):
    """Signature of the StringMapping op: string ``input`` to string ``output``."""

    @classmethod
    def get_inputs(cls):
        source = cls.io_def("input", onnx.TensorProto.STRING, [])
        return [source]

    @classmethod
    def get_outputs(cls):
        target = cls.io_def('output', onnx_proto.TensorProto.STRING, [])
        return [target]

    @classmethod
    def serialize_attr(cls, attrs):
        """Render a dict-valued ``map`` as newline-separated ``key<TAB>value`` lines.

        Keys and values are concatenated directly, so both must be strings.
        A string ``map`` and every other attribute pass through unchanged.
        """
        def encode(name, value):
            if name == 'map' and isinstance(value, dict):
                return '\n'.join(src + "\t" + dst for src, dst in value.items())
            return value

        return {name: encode(name, value) for name, value in attrs.items()}
158
159
class MaskedFill(CustomOp):
    """Signature of the MaskedFill op.

    Inputs: ``value`` (1-D string) and ``mask`` (1-D bool).
    Output: ``output`` (1-D string).
    """

    @classmethod
    def get_inputs(cls):
        tensor_proto = onnx.TensorProto
        value = cls.io_def("value", tensor_proto.STRING, [None])
        mask = cls.io_def("mask", tensor_proto.BOOL, [None])
        return [value, mask]

    @classmethod
    def get_outputs(cls):
        result = cls.io_def('output', onnx_proto.TensorProto.STRING, [None])
        return [result]
172
173
class StringToVector(CustomOp):
    """Signature of the StringToVector op: string ``text`` to int64 ``token_ids``."""

    @classmethod
    def get_inputs(cls):
        text = cls.io_def("text", onnx.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        token_ids = cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])
        return [token_ids]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize the ``map`` and ``unk`` attributes into text.

        * dict ``map``: one line per entry, the key, a tab, then the
          space-separated ``str`` of each value; lines joined with newlines.
        * list ``unk``: its items converted with ``str`` and space-separated.
        * anything else (including a string ``map``) is passed through.
        """
        def encode(name, value):
            if name == 'map' and isinstance(value, dict):
                return '\n'.join(
                    key + "\t" + " ".join([str(i) for i in ids])
                    for key, ids in value.items())
            if name == 'unk' and isinstance(value, list):
                return ' '.join(str(i) for i in value)
            return value

        return {name: encode(name, value) for name, value in attrs.items()}
199
200
class BlingFireSentenceBreaker(CustomOp):
    """Signature of the BlingFireSentenceBreaker op: string ``text`` to string ``sentence``."""

    @classmethod
    def get_inputs(cls):
        text = cls.io_def("text", onnx.TensorProto.STRING, [None])
        return [text]

    @classmethod
    def get_outputs(cls):
        sentence = cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])
        return [sentence]

    @classmethod
    def serialize_attr(cls, attrs):
        """Embed the ``model`` file's raw bytes; pass other attributes through.

        ``model`` is treated as a file path and opened in binary mode.
        """
        serialized = {}
        for name, value in attrs.items():
            if name != 'model':
                serialized[name] = value
                continue
            with open(value, "rb") as blob:
                serialized[name] = blob.read()
        return serialized
221
222
class SegmentExtraction(CustomOp):
    """Signature of the SegmentExtraction op.

    Input: ``input`` (2-D int64).
    Outputs: ``position`` (int64, shape [N, 2]) and ``value`` (1-D int64).
    """

    @classmethod
    def get_inputs(cls):
        source = cls.io_def("input", onnx.TensorProto.INT64, [None, None])
        return [source]

    @classmethod
    def get_outputs(cls):
        int64 = onnx_proto.TensorProto.INT64
        position = cls.io_def('position', int64, [None, 2])
        value = cls.io_def('value', int64, [None])
        return [position, value]
235
236
class BertTokenizer(CustomOp):
    """Signature of the BertTokenizer op.

    Input: ``text`` (1-D string).
    Outputs: ``input_ids``, ``token_type_ids``, ``attention_mask`` (1-D int64)
    and ``offset_mapping`` (int64, shape [N, 2]).
    """

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [
            cls.io_def('input_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('token_type_ids', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('attention_mask', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('offset_mapping', onnx.TensorProto.INT64, [None, 2])
        ]

    @classmethod
    def serialize_attr(cls, attrs):
        """Prepare the node attributes.

        * ``vocab``: vocabulary text already in memory; stored unchanged under
          the ``vocab_file`` attribute name.
        * ``vocab_file``: path to a UTF-8 vocabulary file; the file's text is
          embedded as the ``vocab_file`` attribute.
        * every other attribute is passed through unchanged.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab':
                attrs_data['vocab_file'] = v_
            elif k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # Embed the text with its own line breaks. The previous
                    # '\n'.join(readlines()) kept each line's terminator and so
                    # inserted an empty line between every vocabulary entry.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
265
266
class StringECMARegexReplace(CustomOp):
    """Signature of the StringECMARegexReplace op.

    Inputs ``input``, ``pattern`` and ``rewrite`` and the single ``output``
    are all 1-D string tensors.
    """

    @classmethod
    def get_inputs(cls):
        return [cls.io_def(name, onnx.TensorProto.STRING, [None])
                for name in ("input", "pattern", "rewrite")]

    @classmethod
    def get_outputs(cls):
        replaced = cls.io_def('output', onnx_proto.TensorProto.STRING, [None])
        return [replaced]
280
281
class BertTokenizerDecoder(CustomOp):
    """Signature of the BertTokenizerDecoder op.

    Inputs: ``ids`` (1-D int64) and ``position`` (2-D int64).
    Output: ``str`` (1-D string).
    """

    @classmethod
    def get_inputs(cls):
        return [
            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
            cls.io_def("position", onnx.TensorProto.INT64, [None, None])
        ]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Embed the ``vocab_file`` contents; pass other attributes through.

        ``vocab_file`` is a path to a UTF-8 text file whose contents are stored
        as the attribute value.
        """
        attrs_data = {}
        for k_, v_ in attrs.items():
            if k_ == 'vocab_file':
                with open(v_, "r", encoding='utf-8') as model_file:
                    # read() keeps the original line breaks; joining
                    # readlines() with '\n' doubled every newline.
                    attrs_data[k_] = model_file.read()
            else:
                attrs_data[k_] = v_
        return attrs_data
306
307
class SentencepieceTokenizer(CustomOp):
    """Signature of the SentencepieceTokenizer op.

    Besides the ``inputs`` string tensor, the op takes five control inputs
    (``nbest_size``, ``alpha``, ``add_bos``, ``add_eos``, ``reverse``) for
    which :meth:`input_default_values` supplies defaults.
    Outputs: ``tokens`` (1-D int32) and ``indices`` (1-D int64).
    """

    @classmethod
    def get_inputs(cls):
        return [
            cls.io_def('inputs', onnx_proto.TensorProto.STRING, [None]),
            cls.io_def('nbest_size', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('alpha', onnx_proto.TensorProto.FLOAT, [None]),
            cls.io_def('add_bos', onnx_proto.TensorProto.BOOL, [None]),
            cls.io_def('add_eos', onnx_proto.TensorProto.BOOL, [None]),
            cls.io_def('reverse', onnx_proto.TensorProto.BOOL, [None])
        ]

    # Since Python 3.7 dicts preserve insertion order, so the defaults below
    # stay in the same order as the corresponding inputs above.
    @classmethod
    def input_default_values(cls):
        """Return default values for the control inputs, keyed by input name."""
        return {
            'nbest_size': [0],
            # 'alpha' is declared FLOAT; a float literal keeps any tensor built
            # from this default (e.g. via numpy.asarray) floating-point rather
            # than integer-typed.
            'alpha': [0.0],
            'add_bos': [False],
            'add_eos': [False],
            'reverse': [False]
        }

    @classmethod
    def get_outputs(cls):
        return [
            cls.io_def('tokens', onnx_proto.TensorProto.INT32, [None]),
            cls.io_def('indices', onnx_proto.TensorProto.INT64, [None])
        ]
338
339
class SentencepieceDecoder(CustomOp):
    """Signature of the SentencepieceDecoder op: 1-D int64 ``ids`` to 1-D string ``str``."""

    @classmethod
    def get_inputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, [None])
        return [ids]

    @classmethod
    def get_outputs(cls):
        decoded = cls.io_def('str', onnx_proto.TensorProto.STRING, [None])
        return [decoded]
351
352
class Inverse(CustomOp):
    """Signature of the Inverse op: a 2-D float ``input`` to a 2-D float ``output``."""

    @classmethod
    def get_inputs(cls):
        matrix = cls.io_def('input', onnx_proto.TensorProto.FLOAT, [None, None])
        return [matrix]

    @classmethod
    def get_outputs(cls):
        inverted = cls.io_def('output', onnx_proto.TensorProto.FLOAT, [None, None])
        return [inverted]
366
367
class ImageReader(CustomOp):
    """Signature of the ImageReader op.

    Input: ``image_paths`` (1-D string).
    Output: ``nchw_bytes`` (4-D uint8).
    """

    @classmethod
    def get_inputs(cls):
        paths = cls.io_def('image_paths', onnx_proto.TensorProto.STRING, [None])
        return [paths]

    @classmethod
    def get_outputs(cls):
        four_d = [None] * 4
        pixels = cls.io_def('nchw_bytes', onnx_proto.TensorProto.UINT8, four_d)
        return [pixels]
381
382
class GaussianBlur(CustomOp):
    """Signature of the GaussianBlur op.

    Inputs: ``nhwc`` (4-D float), ``kernel_size`` (1-D int64) and
    ``sigma_xy`` (1-D double).
    Output: ``gb_nhwc`` (4-D float).
    """

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        # (name, element type, shape) for each input, in declaration order.
        specs = (
            ('nhwc', tp.FLOAT, [None] * 4),
            ('kernel_size', tp.INT64, [None]),
            ('sigma_xy', tp.DOUBLE, [None]),
        )
        return [cls.io_def(*spec) for spec in specs]

    @classmethod
    def get_outputs(cls):
        blurred = cls.io_def('gb_nhwc', onnx_proto.TensorProto.FLOAT, [None] * 4)
        return [blurred]
398
399
class ImageDecoder(CustomOp):
    """Signature of the ImageDecoder op.

    Input: ``raw_input_image`` (uint8, declared with an empty shape list).
    Output: ``decoded_image``, a rank-3 uint8 tensor whose last dimension is 3.
    """

    @classmethod
    def get_inputs(cls):
        encoded = cls.io_def('raw_input_image', onnx_proto.TensorProto.UINT8, [])
        return [encoded]

    @classmethod
    def get_outputs(cls):
        decoded = cls.io_def('decoded_image', onnx_proto.TensorProto.UINT8,
                             [None, None, 3])
        return [decoded]
413
414
class AudioDecoder(CustomOp):
    """Signature of the AudioDecoder op.

    Input: ``audio_stream`` (uint8, shape [1, N]).
    Output: ``floatPCM`` (float, shape [1, M]).
    """

    @classmethod
    def get_inputs(cls):
        stream = cls.io_def('audio_stream', onnx_proto.TensorProto.UINT8, [1, None])
        return [stream]

    @classmethod
    def get_outputs(cls):
        pcm = cls.io_def('floatPCM', onnx_proto.TensorProto.FLOAT, [1, None])
        return [pcm]
427
428
class StftNorm(CustomOp):
    """Signature of the StftNorm op.

    Inputs: ``pcm_wave`` (float, [1, N]), ``n_fft``, ``hop_length`` and
    ``frame_size`` (int64, declared with empty shape lists) and ``window``
    (1-D float).
    Output: ``stft_norm`` (float, [1, A, B]).
    """

    @classmethod
    def get_inputs(cls):
        tp = onnx_proto.TensorProto
        # (name, element type, shape) for each input, in declaration order.
        specs = (
            ('pcm_wave', tp.FLOAT, [1, None]),
            ('n_fft', tp.INT64, []),
            ('hop_length', tp.INT64, []),
            ('window', tp.FLOAT, [None]),
            ('frame_size', tp.INT64, []),
        )
        return [cls.io_def(name, elem, shape) for name, elem, shape in specs]

    @classmethod
    def get_outputs(cls):
        spectrum = cls.io_def('stft_norm', onnx_proto.TensorProto.FLOAT,
                              [1, None, None])
        return [spectrum]
445
446
# TODO: have a C++ impl.
def _argsort_op(x, dim):
    """Return the indices that sort ``x`` in descending order along ``dim``.

    Python implementation used for the ``ArgSort`` custom op.

    :param x: numeric array to sort.
    :param dim: axis to sort along (int or integer scalar, negative allowed).
    :return: int64 index array with the same shape as ``x``.
    """
    indices = numpy.argsort(x, dim)
    # Reverse along the same axis that was sorted. The previous
    # ``d[:, ::-1]`` always flipped axis 1, which gave wrong results for
    # dim=0 and raised IndexError on 1-D input.
    return numpy.flip(indices, axis=dim)
451
452
# Register _argsort_op as the Python custom op "ArgSort": a float tensor and
# an int64 axis in, int64 indices out.
Opdef.create(
    _argsort_op,
    op_type='ArgSort',
    outputs=[PyCustomOpDef.dt_int64],
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
)
457
458
class CustomOpConverter:
    """Marker base class for attribute converters.

    ``SingleOpGraph.build_graph`` recognizes an instance of this class passed
    as the first positional argument and calls it with the keyword attributes.
    """
461
462
class SingleOpGraph:
    """Build single-node ONNX graphs wrapping one CustomOp subclass."""

    @classmethod
    def get_next_id(cls):
        """Return the next value (1, 2, 3, ...) of a class-level counter.

        Used to give generated nodes and graphs unique names. The counter is
        not protected by a lock.
        """
        if not hasattr(cls, '_id_counter'):
            cls._id_counter = 0
        cls._id_counter += 1
        return cls._id_counter

    @classmethod
    def build_graph(cls, op_class, *args, **kwargs):
        """Create an ONNX graph containing a single node of ``op_class``.

        :param op_class: a CustomOp subclass, or its class name as a string.
        :param args: optionally a CustomOpConverter instance as the first
            positional argument (used only when ``cvt`` is not passed as a
            keyword); any other positional arguments are ignored.
        :param kwargs: operator attributes. An optional ``cvt`` keyword is
            removed and, when set, called as ``cvt(**kwargs)`` to produce the
            attributes, which then go through ``op_class.serialize_attr``.
        :return: the graph built by ``onnx.helper.make_graph``, whose inputs
            and outputs are the op's declared signature.
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        cvt = kwargs.pop('cvt', None)
        if cvt is None and len(args) > 0 and isinstance(args[0], CustomOpConverter):
            cvt = args[0]
            args = args[1:]

        new_kwargs = kwargs if cvt is None else cvt(**kwargs)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        attrs = op_class.serialize_attr(new_kwargs)
        # Node and graph each draw their own id, so their numeric suffixes
        # differ (e.g. node "X_1" inside graph "og_X_2").
        cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
                                     [o_.name for o_ in outputs],
                                     "{}_{}".format(op_type,
                                                    cls.get_next_id()),
                                     **attrs,
                                     domain=default_opset_domain())
        graph = onnx.helper.make_graph([cuop], "og_{}_{}".format(
            op_type, cls.get_next_id()), inputs, outputs)
        return graph

    @staticmethod
    def get_op_class(op_type):
        """Look up a CustomOp subclass defined in this module by class name.

        :param op_type: the class name, e.g. ``"GPT2Tokenizer"``.
        :return: the matching CustomOp subclass.
        :raises KeyError: if ``op_type`` does not name a concrete CustomOp
            subclass here. A bare ``globals()[op_type]`` would also have
            returned unrelated module globals such as ``numpy`` or ``onnx``.
        """
        op_class = globals().get(op_type)
        if (isinstance(op_class, type) and issubclass(op_class, CustomOp)
                and op_class is not CustomOp):
            return op_class
        raise KeyError("{} is not a known custom operator type".format(op_type))
501