microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d1c657486d908aaf4c494d9a02871cae7131401e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_ops.py

847 lines · mode: code

1# coding: utf-8
2import unittest
3import re
4from binascii import crc32
5import numpy as np
6from onnx import helper, onnx_pb as onnx_proto
7import onnxruntime as _ort
8from onnxruntime_customops import (
9 onnx_op, PyCustomOpDef,
10 get_library_path as _get_library_path,
11 hash_64)
12
# Bucket count handed to the StringToHashBucket / StringToHashBucketFast
# operators in the tests below; their hard-coded expected bucket ids were
# computed for exactly this value.
NUM_BUCKETS = 23
14
15
def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringUpper(Identity(input_1))``.

    :param prefix: prefix prepended to the ``StringUpper`` op_type
        (the tests use '' and 'Py').
    :param domain: domain of the StringUpper node; it is also the only
        opset imported by the model (version 1).
    :return: a ModelProto with string input ``input_1`` and string output
        ``customout``, both declared with shape ``[None, None]``.
    """
    nodes = [
        helper.make_node('Identity', ['input_1'], ['identity1']),
        helper.make_node('%sStringUpper' % prefix,
                         ['identity1'], ['customout'],
                         domain=domain),
    ]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None, None])
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, None])

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model
32
33
def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringJoin(text, sep, axis)``.

    Each graph input first passes through an Identity node. ``text`` and
    the output ``customout`` are string tensors with no declared shape;
    ``sep`` is a string tensor and ``axis`` an int64 tensor, both of
    shape ``[1]``. Only ``domain`` (version 1) is imported as an opset.
    """
    identity_pairs = (('text', 'identity1'),
                      ('sep', 'identity2'),
                      ('axis', 'identity3'))
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in identity_pairs]
    nodes.append(helper.make_node(
        '%sStringJoin' % prefix, ['identity1', 'identity2', 'identity3'],
        ['customout'], domain=domain))

    tensor_proto = onnx_proto.TensorProto
    graph_inputs = [
        helper.make_tensor_value_info('text', tensor_proto.STRING, None),
        helper.make_tensor_value_info('sep', tensor_proto.STRING, [1]),
        helper.make_tensor_value_info('axis', tensor_proto.INT64, [1]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info(
            'customout', tensor_proto.STRING, None),
    ]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
61
62
def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib'):
    """Build a model computing
    ``<prefix>StringRegexReplace(text, pattern, rewrite)``.

    ``text`` and the output ``customout`` are string tensors of shape
    ``[None, 1]``; ``pattern`` and ``rewrite`` are string tensors of
    shape ``[1]``. Every input first goes through an Identity node.
    """
    identity_pairs = (('text', 'id1'), ('pattern', 'id2'), ('rewrite', 'id3'))
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in identity_pairs]
    nodes.append(helper.make_node(
        '%sStringRegexReplace' % prefix, ['id1', 'id2', 'id3'],
        ['customout'], domain=domain))

    string_type = onnx_proto.TensorProto.STRING
    graph_inputs = [
        helper.make_tensor_value_info('text', string_type, [None, 1]),
        helper.make_tensor_value_info('pattern', string_type, [1]),
        helper.make_tensor_value_info('rewrite', string_type, [1]),
    ]
    graph_output = helper.make_tensor_value_info(
        'customout', string_type, [None, 1])

    graph = helper.make_graph(nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
90
91
def _create_test_model_string_to_hash(
        prefix, domain='ai.onnx.contrib', kind=None):
    """Build a model hashing the strings of ``text`` into ``num_buckets``.

    :param prefix: prefix prepended to the selected op_type.
    :param domain: domain of the hashing node and the only imported opset.
    :param kind: 'crc32' (StringToCRC32, uint32 tensors), 'hash_bucket'
        (StringToHashBucket, int64) or 'hash_bucket_fast'
        (StringToHashBucketFast, int64). The element type applies to both
        the ``num_buckets`` input and the ``customout`` output.
    :raises ValueError: for any other ``kind``.
    """
    supported = (
        ('crc32', 'StringToCRC32', onnx_proto.TensorProto.UINT32),
        ('hash_bucket', 'StringToHashBucket', onnx_proto.TensorProto.INT64),
        ('hash_bucket_fast', 'StringToHashBucketFast',
         onnx_proto.TensorProto.INT64),
    )
    for candidate, op_type, elem_type in supported:
        if kind == candidate:
            break
    else:
        raise ValueError('Unknown value %r.' % kind)

    hash_node = helper.make_node(
        '%s%s' % (prefix, op_type), ['id1', 'id2'],
        ['customout'], domain=domain)
    nodes = [
        helper.make_node('Identity', ['text'], ['id1']),
        helper.make_node('Identity', ['num_buckets'], ['id2']),
        hash_node,
    ]

    text_info = helper.make_tensor_value_info(
        'text', onnx_proto.TensorProto.STRING, [None, None])
    buckets_info = helper.make_tensor_value_info(
        'num_buckets', elem_type, [1])
    out_info = helper.make_tensor_value_info(
        'customout', elem_type, [None, None])

    graph = helper.make_graph(
        nodes, 'test0', [text_info, buckets_info], [out_info])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
130
131
def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``z = <prefix>StringEqual(x, y)``.

    ``x`` and ``y`` are string inputs and ``z`` a bool output, all three
    declared with shape ``[]``; each input passes through an Identity node.
    """
    string_type = onnx_proto.TensorProto.STRING
    identity_x = helper.make_node('Identity', ['x'], ['id1'])
    identity_y = helper.make_node('Identity', ['y'], ['id2'])
    equal_node = helper.make_node(
        '%sStringEqual' % prefix, ['id1', 'id2'], ['z'], domain=domain)

    inputs = [helper.make_tensor_value_info(name, string_type, [])
              for name in ('x', 'y')]
    output = helper.make_tensor_value_info(
        'z', onnx_proto.TensorProto.BOOL, [])

    graph = helper.make_graph(
        [identity_x, identity_y, equal_node], 'test0', inputs, [output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
151
152
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Build a model computing
    ``(indices, values, shape) = <prefix>StringSplit(input, delimiter,
    skip_empty)``.

    Inputs: ``input`` and ``delimiter`` (string) and ``skip_empty`` (bool).
    Outputs: ``indices`` (int64), ``values`` (string) and ``shape`` (int64).
    Every tensor is declared with shape ``[]``.
    """
    identity_pairs = (('input', 'id1'),
                      ('delimiter', 'id2'),
                      ('skip_empty', 'id3'))
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in identity_pairs]
    nodes.append(helper.make_node(
        '%sStringSplit' % prefix, ['id1', 'id2', 'id3'],
        ['indices', 'values', 'shape'], domain=domain))

    tp = onnx_proto.TensorProto
    graph_inputs = [
        helper.make_tensor_value_info(name, elem_type, [])
        for name, elem_type in (('input', tp.STRING),
                                ('delimiter', tp.STRING),
                                ('skip_empty', tp.BOOL))
    ]
    graph_outputs = [
        helper.make_tensor_value_info(name, elem_type, [])
        for name, elem_type in (('indices', tp.INT64),
                                ('values', tp.STRING),
                                ('shape', tp.INT64))
    ]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
181
182
class TestPythonOpString(unittest.TestCase):
    """Tests for the string custom operators.

    Operators are exercised through their native kernel (op_type without a
    prefix) and/or through a Python kernel registered in ``setUpClass`` with
    ``onnx_op`` under the ``Py`` prefix. When TensorFlow is installed, some
    native kernels are additionally compared with ``tensorflow.raw_ops``.
    """

    # Python kernels kept by setUpClass so tests can also call them directly,
    # outside of an InferenceSession.
    _string_join = None
    _string_to_crc32 = None

    @staticmethod
    def _create_session(onnx_model):
        """Return an InferenceSession for *onnx_model* with the custom-op
        library returned by ``get_library_path()`` registered."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return _ort.InferenceSession(onnx_model.SerializeToString(), so)

    @staticmethod
    def _get_tf_raw_op(name):
        """Return ``tensorflow.raw_ops.<name>``, or None when TensorFlow (or
        that op) is not available so reference comparisons can be skipped."""
        try:
            from tensorflow import raw_ops
        except ImportError:
            return None
        return getattr(raw_ops, name, None)

    def _compare_split_with_tf(self, text, delimiter, skip, txout):
        """Assert that the StringSplit outputs *txout* equal TensorFlow's
        StringSplit on the same inputs; does nothing without TensorFlow."""
        tf_string_split = self._get_tf_raw_op('StringSplit')
        if tf_string_split is None:
            return
        tfres = tf_string_split(
            input=text, delimiter=delimiter, skip_empty=skip)
        # TensorFlow returns bytes; the custom op returns str.
        self.assertEqual(
            [_.decode() for _ in tfres[1].numpy().tolist()],
            txout[1].tolist())
        self.assertEqual(tfres[0].numpy().tolist(), txout[0].tolist())
        self.assertEqual(tfres[2].numpy().tolist(), txout[2].tolist())

    @classmethod
    def setUpClass(cls):
        """Register the ``Py*`` Python implementations of the operators."""

        @onnx_op(op_type="PyStringUpper",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_upper(x):
            # Upper-case every element, keeping the input shape.
            return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)

        @onnx_op(op_type="PyStringJoin",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_string])
        def string_join(x, sep, axis):
            # Join the strings of x along axis[0] with separator sep[0].
            if sep.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'sep'.".format(sep.shape))
            if axis.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'axis'.".format(axis.shape))
            sp = sep[0]
            ax = axis[0]
            if ax < 0 or ax >= len(x.shape):
                raise RuntimeError(
                    "axis must be in [%r,%r] but is %r" % (
                        0, len(x.shape) - 1, ax))
            if len(x.shape) == 1:
                return np.array([sp.join(x)])
            # Move the joined axis last while keeping the other axes in
            # order (a plain swap would permute them for ndim >= 3), then
            # join each row of the flattened (rows, len(axis)) view.
            x2 = np.moveaxis(x, ax, -1)
            res_shape = x2.shape[:-1]
            x2 = x2.reshape((-1, x2.shape[-1]))
            # dtype=object: reusing a fixed-width input dtype such as '<U2'
            # would silently truncate the (longer) joined strings.
            res = np.empty(x2.shape[0], dtype=object)
            for i, row in enumerate(x2):
                res[i] = sp.join(row)
            return res.reshape(res_shape)

        @onnx_op(op_type="PyStringRegexReplace",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_replace(x, pattern, rewrite):
            # Apply re.sub(pattern[0], rewrite[0], .) to every element.
            if pattern.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'pattern'.".format(pattern.shape))
            if rewrite.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'rewrite'.".format(rewrite.shape))
            reg = re.compile(pattern[0])
            res = np.array([reg.sub(rewrite[0], t) for t in x.ravel()])
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToCRC32",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_uint32],
                 outputs=[PyCustomOpDef.dt_uint32])
        def string_to_crc32(x, num_buckets):
            # CRC32 of the ISO-8859-15 encoding of each element, modulo
            # num_buckets[0]; returned as uint32 to match the declared output.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array(
                [crc32(s.encode('iso-8859-15')) % nb for s in x.ravel()],
                dtype=np.uint32)
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToHashBucket",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_int64])
        def string_to_hash_bucket(x, num_buckets):
            # Bucket every element with hash_64(s, num_buckets[0], True).
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array([hash_64(s, nb, True) for s in x.ravel()])
            return res.reshape(x.shape).astype(np.int64)

        @onnx_op(op_type="PyStringEqual",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_bool])
        def string_equal(x, y):
            # Element-wise equality with numpy broadcasting.
            return x == y

        @onnx_op(op_type="PyStringSplit",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_bool],
                 outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                          PyCustomOpDef.dt_int64])
        def string_split(data, delimiter, skip_empty):
            # Split each string of the 1-D tensor `data` on delimiter[0] and
            # return the pieces in sparse form: (row, position) indices,
            # the pieces themselves, and the dense shape
            # [len(data), max pieces per row]. Empty rows yield no piece.
            if delimiter.shape != (1, ):
                raise RuntimeError("delimiter must be a single element tensor.")
            if skip_empty.shape != (1, ):
                raise RuntimeError(
                    "skip_empty must be a single element tensor.")
            if len(data.shape) != 1:
                raise RuntimeError("input must be a one-dimensional tensor.")
            sep = delimiter[0]
            skip = skip_empty[0]
            texts = []
            indices = []
            max_split = 0
            for row, text in enumerate(data):
                if not text:
                    continue
                pieces = text.split(sep)
                if skip:
                    pieces = [t for t in pieces if t]
                texts.extend(pieces)
                max_split = max(max_split, len(pieces))
                indices.extend((row, i) for i in range(len(pieces)))
            return (np.array(indices, dtype=np.int64),
                    np.array(texts),
                    np.array([len(data), max_split], dtype=np.int64))

        cls._string_join = string_join
        cls._string_to_crc32 = string_to_crc32

    def test_check_types(self):
        # Every element type listed here must be exposed on PyCustomOpDef.
        def_list = set(dir(PyCustomOpDef))
        type_list = [
            # 'dt_bfloat16' is intentionally not checked.
            'dt_bool',
            'dt_complex128',
            'dt_complex64',
            'dt_double',
            'dt_float',
            'dt_float16',
            'dt_int16',
            'dt_int32',
            'dt_int64',
            'dt_int8',
            'dt_string',
            'dt_uint16',
            'dt_uint32',
            'dt_uint64',
            'dt_uint8']
        for t in type_list:
            self.assertIn(t, def_list)

    def test_string_upper_cc(self):
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), [["ABC"]])

    def test_string_upper_cc_accent(self):
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["R"], ["Abcé"], ["ABC"], ["A"]])
        txout = sess.run(None, {'input_1': input_1})
        # The native kernel leaves the non-ASCII 'é' unchanged.
        self.assertEqual(
            txout[0].tolist(), [["R"], ["ABCé"], ["ABC"], ["A"]])

    def test_string_upper_python(self):
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), [["ABC"]])

    def test_string_upper_python_accent(self):
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._create_session(onnx_model)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        # The Python kernel uses str.upper, which also upper-cases 'é'.
        self.assertEqual(txout[0].tolist(), [["ABCé".upper()]])

    def test_string_join_python(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        self.assertEqual(text.shape, (2, 3))
        sep = np.array([";"])
        cases = ((1, ["a;b;c", "aa;bb;"]),
                 (0, ['a;aa', 'b;bb', 'c;']))
        for ax, exp in cases:
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                # The kernel called directly must agree with the session.
                direct = TestPythonOpString._string_join(text, sep, axis)
                self.assertEqual(exp, direct.tolist())
                txout = sess.run(
                    None, {'text': text, 'sep': sep, 'axis': axis})
                self.assertEqual(exp, txout[0].tolist())

    def test_string_join_python_3d(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])]).reshape((2, 3, 1))
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        exp = [['a;b;c'], ['aa;bb;']]
        direct = TestPythonOpString._string_join(text, sep, axis)
        self.assertEqual(exp, direct.tolist())
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(exp, txout[0].tolist())

    def test_string_join_python_direct_3d_all_axes(self):
        # The Python kernel must match the native kernel's results (see
        # test_string_join_cc_3d) for every axis of a 3-D input.
        text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape(
            (2, 2, 2))
        sep = np.array([";"])
        expected = {
            2: [['a;b', 'c;d'], ['e;f', 'g;h']],
            1: [['a;c', 'b;d'], ['e;g', 'f;h']],
            0: [['a;e', 'b;f'], ['c;g', 'd;h']],
        }
        for ax, exp in expected.items():
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                res = TestPythonOpString._string_join(text, sep, axis)
                self.assertEqual(exp, res.tolist())

    def test_string_join_python_1d(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].shape, (1, ))
        self.assertEqual(txout[0].tolist(), ["a;b;cc"])

    def test_string_join_cc(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        sep = np.array([";"])
        cases = ((1, ["a;b;c", "aa;bb;"]),
                 (0, ['a;aa', 'b;bb', 'c;']))
        for ax, exp in cases:
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                txout = sess.run(
                    None, {'text': text, 'sep': sep, 'axis': axis})
                self.assertEqual(exp, txout[0].tolist())

    def test_string_join_cc_1d(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].tolist(), ["a;b;cc"])

    def test_string_join_cc_3d(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape(
            (2, 2, 2))
        sep = np.array([";"])
        cases = ((2, [['a;b', 'c;d'], ['e;f', 'g;h']]),
                 (1, [['a;c', 'b;d'], ['e;g', 'f;h']]),
                 (0, [['a;e', 'b;f'], ['c;g', 'd;h']]))
        for ax, exp in cases:
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                txout = sess.run(
                    None, {'text': text, 'sep': sep, 'axis': axis})
                self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_cc(self):
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_cc_x2(self):
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        # Two matches in a single string: every occurrence is rewritten.
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python(self):
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python_x2(self):
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._create_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_to_crc32_python(self):
        onnx_model = _create_test_model_string_to_hash('Py', kind='crc32')
        self.assertIn('op_type: "PyStringToCRC32"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array([["abc", "abcdé"], ["$$^l!%*ù", ""]])
        num_buckets = np.array([44], dtype=np.uint32)
        res = self._string_to_crc32(text, num_buckets)
        self.assertEqual(res.shape, text.shape)
        exp = np.array([[10, 38], [29, 0]], dtype=np.uint32)
        self.assertEqual(exp.tolist(), res.tolist())
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_cc(self):
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket')
        self.assertIn('op_type: "StringToHashBucket"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_op = self._get_tf_raw_op('StringToHashBucket')
        if tf_op is not None:
            tfres = tf_op(string_tensor=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[15, 11], [10, 21], [20, 21]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_fast_cc(self):
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket_fast')
        self.assertIn('op_type: "StringToHashBucketFast"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_op = self._get_tf_raw_op('StringToHashBucketFast')
        if tf_op is not None:
            tfres = tf_op(input=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_python(self):
        onnx_model = _create_test_model_string_to_hash(
            'Py', kind='hash_bucket')
        self.assertIn('op_type: "PyStringToHashBucket"', str(onnx_model))
        sess = self._create_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        # The Python kernel calls hash_64(..., True), so it matches the
        # native StringToHashBucketFast results.
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def enumerate_matrix_couples(self):
        """Yield ``(a, b)`` pairs of digit-string arrays for equality tests.

        ``a`` has shape ``(3,) * i`` for i in 1..4; ``b`` is either ``a``
        itself or a new array whose shape replaces one or two axes of
        ``a``'s shape with 1, so the pair is broadcast-compatible.
        A fixed seed keeps the generated data reproducible.
        """
        rng = np.random.RandomState(0)

        def rand_digits(shape):
            # Builtin ``str``: the ``np.str`` alias was removed in NumPy 1.24.
            return (rng.rand(*shape) * 10).astype(np.int32).astype(str)

        for i in range(1, 5):
            shape = (3,) * i
            a = rand_digits(shape)
            yield a, a
            for j in range(i):
                shape2 = list(shape)
                shape2[j] = 1
                yield a, rand_digits(shape2)
                for k in range(j + 1, i):
                    shape3 = list(shape2)
                    shape3[k] = 1
                    yield a, rand_digits(shape3)

    def test_string_equal_python(self):
        onnx_model = _create_test_model_string_equal('Py')
        self.assertIn('op_type: "PyStringEqual"', str(onnx_model))
        sess = self._create_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def test_string_equal_cc(self):
        onnx_model = _create_test_model_string_equal('')
        self.assertIn('op_type: "StringEqual"', str(onnx_model))
        sess = self._create_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def test_string_split_python(self):
        onnx_model = _create_test_model_string_split('Py')
        self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])
                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})

                if skip:
                    exp_indices = [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2],
                                   [3, 0]]
                    exp_text = ['a', 'b', 'aa', 'b', 'c', 'dddddd']
                else:
                    exp_indices = [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                                   [2, 2], [3, 0]]
                    exp_text = ['a', '', 'b', 'aa', 'b', 'c', 'dddddd']
                exp_shape = [4, 3]
                self.assertEqual(exp_indices, txout[0].tolist())
                self.assertEqual(exp_text, txout[1].tolist())
                self.assertEqual(exp_shape, txout[2].tolist())

    def test_string_split_cc(self):
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])
                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                # Compare with TensorFlow using the delimiter fed to ORT.
                self._compare_split_with_tf(
                    text, str(delimiter[0]), skip, txout)

                if skip:
                    exp_indices = [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2],
                                   [3, 0]]
                    exp_text = ['a', 'b', 'aa', 'b', 'c', 'dddddd']
                else:
                    exp_indices = [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                                   [2, 2], [3, 0]]
                    exp_text = ['a', '', 'b', 'aa', 'b', 'c', 'dddddd']
                exp_shape = [4, 3]
                self.assertEqual(exp_indices, txout[0].tolist())
                self.assertEqual(exp_text, txout[1].tolist())
                self.assertEqual(exp_shape, txout[2].tolist())

    def test_string_split_cc_sep2(self):
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a*b", "a,*b", "aa,b,,c", 'z', "dddddd,", "**"])
        # Each character of a multi-character delimiter is a split point.
        delimiter = np.array([",*"])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])
                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                self._compare_split_with_tf(
                    text, str(delimiter[0]), skip, txout)

                if skip:
                    exp_indices = [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0],
                                   [2, 1], [2, 2], [3, 0], [4, 0]]
                    exp_text = ['a', 'b', 'a', 'b', 'aa', 'b', 'c', 'z',
                                'dddddd']
                    exp_shape = [6, 3]
                else:
                    exp_indices = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2],
                                   [2, 0], [2, 1], [2, 2], [2, 3], [3, 0],
                                   [4, 0], [4, 1], [5, 0], [5, 1], [5, 2]]
                    exp_text = ['a', 'b', 'a', '', 'b', 'aa', 'b', '', 'c',
                                'z', 'dddddd', '', '', '', '']
                    exp_shape = [6, 4]
                self.assertEqual(exp_text, txout[1].tolist())
                self.assertEqual(exp_indices, txout[0].tolist())
                self.assertEqual(exp_shape, txout[2].tolist())

    def test_string_split_cc_sep0(self):
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = self._create_session(onnx_model)
        text = np.array(["a*b", "a,*b"])
        # An empty delimiter splits every string into single characters.
        delimiter = np.array([""])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])
                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': skip_empty})
                self._compare_split_with_tf(
                    text, str(delimiter[0]), skip, txout)

                exp_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1],
                               [1, 2], [1, 3]]
                exp_text = ['a', '*', 'b', 'a', ',', '*', 'b']
                exp_shape = [2, 4]
                self.assertEqual(exp_text, txout[1].tolist())
                self.assertEqual(exp_indices, txout[0].tolist())
                self.assertEqual(exp_shape, txout[2].tolist())
844
845
if __name__ == "__main__":
    # Only when executed as a script: let unittest discover and run every
    # TestCase defined in this module.
    unittest.main()
848