microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b5ba84a185c5c85c436e744c40ac9a8cbd9d3f1f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_ops.py

869 lines · mode: code

1# coding: utf-8
2import unittest
3import re
4from binascii import crc32
5import numpy as np
6from onnx import helper, onnx_pb as onnx_proto
7import onnxruntime as _ort
8from onnxruntime_customops import (
9 onnx_op, PyCustomOpDef,
10 get_library_path as _get_library_path,
11 hash_64)
12
# Bucket count fed to the StringToHashBucket* tests; the bucket ids those
# tests expect are hard-coded and only hold for this exact value.
NUM_BUCKETS = 23
14
15
def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringUpper(Identity(input_1))``.

    :param prefix: prepended to the op type (the tests use '' for the
        native kernel and 'Py' for the Python kernel).
    :param domain: domain of the custom node and of the only opset import.
    :return: ONNX model with a string input ``input_1`` and a string
        output ``customout``, both declared with shape [None, None].
    """
    identity = helper.make_node('Identity', ['input_1'], ['identity1'])
    upper = helper.make_node(
        '%sStringUpper' % prefix, ['identity1'], ['customout'],
        domain=domain)

    graph_input = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None, None])
    graph_output = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, None])

    graph = helper.make_graph(
        [identity, upper], 'test0', [graph_input], [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
32
33
def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringJoin(text, sep, axis)``.

    Every graph input is routed through its own Identity node before it
    reaches the custom node. ``text`` and the output ``customout`` have no
    declared shape; ``sep`` is a 1-element string tensor and ``axis`` a
    1-element int64 tensor.
    """
    input_specs = (
        ('text', onnx_proto.TensorProto.STRING, None),
        ('sep', onnx_proto.TensorProto.STRING, [1]),
        ('axis', onnx_proto.TensorProto.INT64, [1]),
    )
    aliases = ['identity%d' % (pos + 1) for pos in range(len(input_specs))]

    nodes = [helper.make_node('Identity', [name], [alias])
             for (name, _, _), alias in zip(input_specs, aliases)]
    nodes.append(helper.make_node(
        '%sStringJoin' % prefix, aliases, ['customout'], domain=domain))

    graph_inputs = [helper.make_tensor_value_info(name, elem_type, shape)
                    for name, elem_type, shape in input_specs]
    graph_output = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, None)

    graph = helper.make_graph(nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
61
62
def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib', global_replace=True):
    """Build a model computing ``<prefix>StringRegexReplace(text, pattern, rewrite)``.

    When *global_replace* is falsy, the custom node gets the attribute
    ``global_replace=0``. Otherwise no attribute is set.
    ``text`` and ``customout`` are declared with shape [None, 1];
    ``pattern`` and ``rewrite`` are 1-element string tensors.
    """
    string_type = onnx_proto.TensorProto.STRING
    nodes = [helper.make_node('Identity', [source], [alias])
             for source, alias in (('text', 'id1'),
                                   ('pattern', 'id2'),
                                   ('rewrite', 'id3'))]

    extra_attributes = {} if global_replace else {'global_replace': 0}
    nodes.append(helper.make_node(
        '%sStringRegexReplace' % prefix, ['id1', 'id2', 'id3'],
        ['customout'], domain=domain, **extra_attributes))

    graph_inputs = [
        helper.make_tensor_value_info('text', string_type, [None, 1]),
        helper.make_tensor_value_info('pattern', string_type, [1]),
        helper.make_tensor_value_info('rewrite', string_type, [1]),
    ]
    graph_output = helper.make_tensor_value_info(
        'customout', string_type, [None, 1])

    graph = helper.make_graph(nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
97
98
99def _create_test_model_string_to_hash(
100 prefix, domain='ai.onnx.contrib', kind=None):
101 if kind == 'crc32':
102 op_type = 'StringToCRC32'
103 out_type = onnx_proto.TensorProto.UINT32
104 in_type = out_type
105 elif kind == 'hash_bucket':
106 op_type = 'StringToHashBucket'
107 out_type = onnx_proto.TensorProto.INT64
108 in_type = out_type
109 elif kind == 'hash_bucket_fast':
110 op_type = 'StringToHashBucketFast'
111 out_type = onnx_proto.TensorProto.INT64
112 in_type = out_type
113 else:
114 raise ValueError('Unknown value %r.' % kind)
115 nodes = []
116 nodes.append(
117 helper.make_node('Identity', ['text'], ['id1']))
118 nodes.append(
119 helper.make_node('Identity', ['num_buckets'], ['id2']))
120 nodes.append(
121 helper.make_node(
122 '%s%s' % (prefix, op_type), ['id1', 'id2'],
123 ['customout'], domain=domain))
124
125 input0 = helper.make_tensor_value_info(
126 'text', onnx_proto.TensorProto.STRING, [None, None])
127 input1 = helper.make_tensor_value_info(
128 'num_buckets', in_type, [1])
129 output0 = helper.make_tensor_value_info(
130 'customout', out_type, [None, None])
131
132 graph = helper.make_graph(
133 nodes, 'test0', [input0, input1], [output0])
134 model = helper.make_model(
135 graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
136 return model
137
138
def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``z = <prefix>StringEqual(x, y)``.

    ``x`` and ``y`` are string inputs and ``z`` a bool output, all declared
    with shape ``[]``. Each input passes through an Identity node first.
    """
    nodes = [helper.make_node('Identity', [name], [alias])
             for name, alias in zip(('x', 'y'), ('id1', 'id2'))]
    nodes += [helper.make_node(
        '%sStringEqual' % prefix, ['id1', 'id2'], ['z'], domain=domain)]

    value_info = helper.make_tensor_value_info
    graph_inputs = [value_info('x', onnx_proto.TensorProto.STRING, []),
                    value_info('y', onnx_proto.TensorProto.STRING, [])]
    graph_outputs = [value_info('z', onnx_proto.TensorProto.BOOL, [])]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
158
159
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringSplit(input, delimiter, skip_empty)``.

    The custom node has three outputs: ``indices`` (int64), ``values``
    (string) and ``shape`` (int64). Every graph input and output is
    declared with shape ``[]``, and each input passes through an Identity
    node first.
    """
    tensor_proto = onnx_proto.TensorProto
    input_specs = (('input', 'id1', tensor_proto.STRING),
                   ('delimiter', 'id2', tensor_proto.STRING),
                   ('skip_empty', 'id3', tensor_proto.BOOL))
    output_specs = (('indices', tensor_proto.INT64),
                    ('values', tensor_proto.STRING),
                    ('shape', tensor_proto.INT64))

    nodes = [helper.make_node('Identity', [name], [alias])
             for name, alias, _ in input_specs]
    nodes.append(helper.make_node(
        '%sStringSplit' % prefix,
        [alias for _, alias, _ in input_specs],
        [name for name, _ in output_specs],
        domain=domain))

    graph_inputs = [helper.make_tensor_value_info(name, elem_type, [])
                    for name, _, elem_type in input_specs]
    graph_outputs = [helper.make_tensor_value_info(name, elem_type, [])
                     for name, elem_type in output_specs]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
188
189
class TestPythonOpString(unittest.TestCase):
    """Tests for the string custom operators.

    Most operators are exercised through their native kernel (op type
    without prefix) and through a Python kernel declared in
    ``setUpClass`` (op type prefixed with ``Py``).
    """

    # Direct handles on two Python kernels so tests can also call them
    # outside onnxruntime; assigned in setUpClass.
    _string_join = None
    _string_to_crc32 = None

    @classmethod
    def setUpClass(cls):
        """Declare the ``Py*`` Python kernels with ``onnx_op``."""

        @onnx_op(op_type="PyStringUpper",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_upper(x):
            # Upper-case every element, keeping the input shape.
            return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)

        @onnx_op(op_type="PyStringJoin",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_string])
        def string_join(x, sep, axis):
            # Join the strings of `x` along `axis` with `sep`; the result
            # has the shape of `x` with `axis` removed (a 1-element array
            # for 1-D input).
            if sep.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'sep'.".format(sep.shape))
            if axis.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'axis'.".format(axis.shape))
            sp = sep[0]
            ax = axis[0]
            if ax < 0 or ax >= len(x.shape):
                raise RuntimeError(
                    "axis must be in [%r,%r] but is %r" % (
                        0, len(x.shape) - 1, ax))
            if len(x.shape) == 1:
                return np.array([sp.join(x)])
            # moveaxis keeps the other axes in their original order;
            # swapping `ax` with the last axis would permute them for
            # rank >= 3.
            x2 = np.moveaxis(x, ax, -1)
            res_shape = x2.shape[:-1]
            x2 = x2.reshape((-1, x2.shape[-1]))
            # object dtype: reusing a fixed-width unicode dtype from `x`
            # would truncate the (longer) joined strings.
            res = np.empty(x2.shape[0], dtype=object)
            for i, row in enumerate(x2):
                res[i] = sp.join(row)
            return res.reshape(res_shape)

        @onnx_op(op_type="PyStringRegexReplace",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_replace(x, pattern, rewrite):
            # Replace every match of `pattern` by `rewrite` (re.sub
            # semantics) in each element of `x`.
            if pattern.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'pattern'.".format(pattern.shape))
            if rewrite.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'rewrite'.".format(rewrite.shape))
            reg = re.compile(pattern[0])
            res = np.array([reg.sub(rewrite[0], t) for t in x.ravel()])
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToCRC32",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_uint32],
                 outputs=[PyCustomOpDef.dt_uint32])
        def string_to_crc32(x, num_buckets):
            # CRC32 of each string, encoded as ISO-8859-15, modulo the
            # bucket count.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array(
                [crc32(s.encode('iso-8859-15')) % nb for s in x.ravel()])
            # Match the declared dt_uint32 output regardless of the dtype
            # NumPy promoted the modulo result to.
            return res.reshape(x.shape).astype(np.uint32)

        @onnx_op(op_type="PyStringToHashBucket",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_int64])
        def string_to_hash_bucket(x, num_buckets):
            # Bucket id of each string computed by hash_64, as int64.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array([hash_64(s, nb, True) for s in x.ravel()])
            return res.reshape(x.shape).astype(np.int64)

        @onnx_op(op_type="PyStringEqual",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_bool])
        def string_equal(x, y):
            # Element-wise comparison with NumPy broadcasting.
            return x == y

        @onnx_op(op_type="PyStringSplit",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_bool],
                 outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                          PyCustomOpDef.dt_int64])
        def string_split(inputs, delimiter, skip_empty):
            # Split each string of the 1-D `inputs` with str.split and
            # return (indices, values, shape): indices are (row, position)
            # pairs, shape is [number of rows, longest split].
            if delimiter.shape != (1, ):
                raise RuntimeError(
                    "delimiter must be a single element tensor.")
            if skip_empty.shape != (1, ):
                raise RuntimeError(
                    "skip_empty must be a single element tensor.")
            if len(inputs.shape) != 1:
                raise RuntimeError("input must be a one dimension tensor.")
            delimiter = delimiter[0]
            skip_empty = skip_empty[0]
            texts = []
            indices = []
            max_split = 0
            for row, text in enumerate(inputs):
                if not text:
                    # An empty string contributes no token at all.
                    continue
                res = text.split(delimiter)
                if skip_empty:
                    res = [t for t in res if t]
                texts.extend(res)
                max_split = max(max_split, len(res))
                indices.extend((row, i) for i in range(len(res)))
            return (np.array(indices, dtype=np.int64),
                    np.array(texts),
                    np.array([len(inputs), max_split], dtype=np.int64))

        cls._string_join = string_join
        cls._string_to_crc32 = string_to_crc32

    @staticmethod
    def _create_session(onnx_model):
        """Return an InferenceSession running *onnx_model* with the
        custom-ops library registered."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return _ort.InferenceSession(onnx_model.SerializeToString(), so)

    @staticmethod
    def _tf_raw_op(name):
        """Return ``tensorflow.raw_ops.<name>``, or None when TensorFlow (or
        that op) is unavailable, in which case reference checks are skipped."""
        try:
            from tensorflow import raw_ops
        except ImportError:
            return None
        return getattr(raw_ops, name, None)

    def _assert_matches_tf_string_split(self, txout, text, delimiter, skip):
        """Compare StringSplit outputs *txout* with
        ``tensorflow.raw_ops.StringSplit`` when TensorFlow is installed."""
        tf_string_split = self._tf_raw_op('StringSplit')
        if tf_string_split is None:
            return
        tfres = tf_string_split(
            input=text, delimiter=delimiter, skip_empty=skip)
        self.assertEqual(
            [v.decode() for v in tfres[1].numpy().tolist()],
            txout[1].tolist())
        self.assertEqual(tfres[0].numpy().tolist(), txout[0].tolist())
        self.assertEqual(tfres[2].numpy().tolist(), txout[2].tolist())

    def test_check_types(self):
        """PyCustomOpDef exposes a ``dt_*`` constant for each listed type."""
        def_list = set(dir(PyCustomOpDef))
        type_list = [
            # 'dt_bfloat16',
            'dt_bool',
            'dt_complex128',
            'dt_complex64',
            'dt_double',
            'dt_float',
            'dt_float16',
            'dt_int16',
            'dt_int32',
            'dt_int64',
            'dt_int8',
            'dt_string',
            'dt_uint16',
            'dt_uint32',
            'dt_uint64',
            'dt_uint8']
        for t in type_list:
            self.assertIn(t, def_list)

    def test_string_upper_cc(self):
        """Native StringUpper on ASCII text."""
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_cc_accent(self):
        """Native StringUpper leaves the accented character as is."""
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["R"], ["Abcé"], ["ABC"], ["A"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(
            txout[0].tolist(),
            np.array([["R"], ["ABCé"], ["ABC"], ["A"]]).tolist())

    def test_string_upper_python(self):
        """Python StringUpper on ASCII text."""
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_python_accent(self):
        """Python StringUpper follows str.upper on accented text."""
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(),
                         np.array([["ABCé".upper()]]).tolist())

    def test_string_join_python(self):
        """Python StringJoin on a 2-D input, along both axes."""
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        self.assertEqual(text.shape, (2, 3))
        sep = np.array([";"])

        axis = np.array([1], dtype=np.int64)
        expected = np.array(["a;b;c", "aa;bb;"]).tolist()
        self.assertEqual(
            TestPythonOpString._string_join(text, sep, axis).tolist(),
            expected)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].tolist(), expected)

        axis = np.array([0], dtype=np.int64)
        expected = np.array(['a;aa', 'b;bb', 'c;']).tolist()
        self.assertEqual(
            TestPythonOpString._string_join(text, sep, axis).tolist(),
            expected)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].tolist(), expected)

    def test_string_join_python_3d(self):
        """Python StringJoin on a 3-D input along the middle axis."""
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])]).reshape((2, 3, 1))
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        expected = np.array([['a;b;c'], ['aa;bb;']]).tolist()
        self.assertEqual(
            TestPythonOpString._string_join(text, sep, axis).tolist(),
            expected)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].tolist(), expected)

    def test_string_join_python_1d(self):
        """Python StringJoin on a 1-D input returns a 1-element tensor."""
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].shape, (1, ))
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc(self):
        """Native StringJoin on a 2-D input, along both axes."""
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())

    def test_string_join_cc_1d(self):
        """Native StringJoin on a 1-D input."""
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc_3d(self):
        """Native StringJoin on a 2x2x2 input along each axis."""
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape((
            2, 2, 2))
        sep = np.array([";"])
        axis = np.array([2], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;b', 'c;d'], ['e;f', 'g;h']]).tolist())
        axis = np.array([1], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;c', 'b;d'], ['e;g', 'f;h']]).tolist())
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;e', 'b;f'], ['c;g', 'd;h']]).tolist())

    def test_string_replace_cc(self):
        """Native StringRegexReplace with a capture-group rewrite."""
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_cc_first(self):
        """Native StringRegexReplace with global_replace=0 rewrites only
        the first match."""
        onnx_model = _create_test_model_string_replace(
            '', global_replace=False)
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():def myfunc():'],
                         ['def dummy():def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {def myfunc():'],
               ['static PyObject* py_dummy(void) {def dummy():']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_cc_x2(self):
        """Native StringRegexReplace rewrites every match by default."""
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python(self):
        """Python StringRegexReplace with newlines in the rewrite."""
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python_x2(self):
        """Python StringRegexReplace rewrites every match."""
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_to_crc32_python(self):
        """Python StringToCRC32, called directly and through onnxruntime."""
        onnx_model = _create_test_model_string_to_hash('Py', kind='crc32')
        self.assertIn('op_type: "PyStringToCRC32"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array([["abc", "abcdé"], ["$$^l!%*ù", ""]])
        num_buckets = np.array([44], dtype=np.uint32)
        res = self._string_to_crc32(text, num_buckets)
        self.assertEqual(res.shape, text.shape)
        exp = np.array([[10, 38], [29, 0]], dtype=np.uint32)
        self.assertEqual(exp.tolist(), res.tolist())
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_cc(self):
        """Native StringToHashBucket, checked against TensorFlow when
        available."""
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket')
        self.assertIn('op_type: "StringToHashBucket"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_hash = self._tf_raw_op('StringToHashBucket')
        if tf_hash is not None:
            tfres = tf_hash(string_tensor=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[15, 11], [10, 21], [20, 21]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_fast_cc(self):
        """Native StringToHashBucketFast, checked against TensorFlow when
        available."""
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket_fast')
        self.assertIn('op_type: "StringToHashBucketFast"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_hash_fast = self._tf_raw_op('StringToHashBucketFast')
        if tf_hash_fast is not None:
            tfres = tf_hash_fast(input=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_python(self):
        """Python StringToHashBucket through onnxruntime."""
        onnx_model = _create_test_model_string_to_hash(
            'Py', kind='hash_bucket')
        self.assertIn('op_type: "PyStringToHashBucket"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def enumerate_matrix_couples(self):
        """Yield (a, b) pairs of random digit-string arrays whose shapes
        broadcast: a has shape (3,) * i for i in 1..4, and b is either a
        itself or a fresh array with one or two of those axes set to 1."""
        def random_digits(shape):
            # The builtin str replaces np.str, which NumPy 1.24 removed.
            return (np.random.rand(*shape) * 10).astype(
                np.int32).astype(str)

        for i in range(1, 5):
            shape = (3,) * i
            a = random_digits(shape)
            yield a, a
            for j in range(i):
                shape2 = list(shape)
                shape2[j] = 1
                yield a, random_digits(shape2)
                for k in range(j + 1, i):
                    shape3 = list(shape2)
                    shape3[k] = 1
                    yield a, random_digits(shape3)

    def test_string_equal_python(self):
        """Python StringEqual matches NumPy ``==`` with broadcasting."""
        onnx_model = _create_test_model_string_equal('Py')
        self.assertIn('op_type: "PyStringEqual"', str(onnx_model))
        sess = self._create_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def test_string_equal_cc(self):
        """Native StringEqual matches NumPy ``==`` with broadcasting."""
        onnx_model = _create_test_model_string_equal('')
        self.assertIn('op_type: "StringEqual"', str(onnx_model))
        sess = self._create_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def test_string_split_python(self):
        """Python StringSplit with a one-character delimiter."""
        onnx_model = _create_test_model_string_split('Py')
        self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])

                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})

                if skip:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
                    exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
                else:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                         [2, 2], [3, 0]])
                    exp_text = np.array(
                        ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
                exp_shape = np.array([4, 3])
                self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                self.assertEqual(exp_text.tolist(), txout[1].tolist())
                self.assertEqual(exp_shape.tolist(), txout[2].tolist())

    def test_string_split_cc(self):
        """Native StringSplit with a one-character delimiter."""
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])

                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                # Use the same delimiter as the one fed to the model.
                self._assert_matches_tf_string_split(txout, text, ",", skip)

                if skip:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
                    exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
                else:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                         [2, 2], [3, 0]])
                    exp_text = np.array(
                        ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
                exp_shape = np.array([4, 3])
                self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                self.assertEqual(exp_text.tolist(), txout[1].tolist())
                self.assertEqual(exp_shape.tolist(), txout[2].tolist())

    def test_string_split_cc_sep2(self):
        """Native StringSplit with a two-character delimiter: each
        character acts as a separator."""
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a*b", "a,*b", "aa,b,,c", 'z', "dddddd,", "**"])
        delimiter = np.array([",*"])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])

                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                self._assert_matches_tf_string_split(txout, text, ",*", skip)

                if skip:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1],
                         [2, 2], [3, 0], [4, 0]])
                    exp_text = np.array(
                        ['a', 'b', 'a', 'b', 'aa', 'b', 'c', 'z', 'dddddd'])
                    exp_shape = np.array([6, 3])
                else:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0],
                         [2, 1], [2, 2], [2, 3], [3, 0], [4, 0], [4, 1],
                         [5, 0], [5, 1], [5, 2]])
                    exp_text = np.array(
                        ['a', 'b', 'a', '', 'b', 'aa', 'b', '', 'c',
                         'z', 'dddddd', '', '', '', ''])
                    exp_shape = np.array([6, 4])
                self.assertEqual(exp_text.tolist(), txout[1].tolist())
                self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                self.assertEqual(exp_shape.tolist(), txout[2].tolist())

    def test_string_split_cc_sep0(self):
        """Native StringSplit with an empty delimiter splits into
        characters."""
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a*b", "a,*b"])
        delimiter = np.array([""])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])

                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                self._assert_matches_tf_string_split(txout, text, "", skip)

                exp_indices = np.array(
                    [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]])
                exp_text = np.array(['a', '*', 'b', 'a', ',', '*', 'b'])
                exp_shape = np.array([2, 4])
                self.assertEqual(exp_text.tolist(), txout[1].tolist())
                self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                self.assertEqual(exp_shape.tolist(), txout[2].tolist())
866
867
if __name__ == "__main__":
    # Executed as a script (not imported): run the whole suite.
    unittest.main()
870