microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
4a0f8929494fa301baa6c59f617cce7872a7c4c8

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_ops.py

845 lines · mode: code

1# coding: utf-8
2import unittest
3import re
4from binascii import crc32
5import numpy as np
6from onnx import helper, onnx_pb as onnx_proto
7import onnxruntime as _ort
8from onnxruntime_customops import (
9 onnx_op, PyCustomOpDef,
10 get_library_path as _get_library_path,
11 hash_64)
12
# Bucket count fed to the StringToHashBucket / StringToHashBucketFast tests;
# the hard-coded expected bucket ids in those tests depend on this value.
NUM_BUCKETS: int = 23
14
15
def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringUpper(Identity(input_1))``.

    :param prefix: prepended to the op type; the tests use '' (``*_cc``
        tests) and 'Py' (the ``PyStringUpper`` op registered in
        ``TestPythonOpString.setUpClass``).
    :param domain: domain of the StringUpper node and of the opset import.
    :return: an ``onnx.ModelProto`` with a string input ``input_1`` and a
        string output ``customout``, both declared with shape
        ``[None, None]``.
    """
    upper_op = '%sStringUpper' % prefix
    graph_nodes = [
        helper.make_node('Identity', ['input_1'], ['identity1']),
        helper.make_node(upper_op, ['identity1'], ['customout'],
                         domain=domain),
    ]
    string_type = onnx_proto.TensorProto.STRING
    graph_input = helper.make_tensor_value_info(
        'input_1', string_type, [None, None])
    graph_output = helper.make_tensor_value_info(
        'customout', string_type, [None, None])
    graph = helper.make_graph(
        graph_nodes, 'test0', [graph_input], [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
32
33
def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringJoin(text, sep, axis)``.

    Each graph input goes through its own ``Identity`` node
    (``identity1``..``identity3``) before reaching the join node.

    :param prefix: prepended to the op type ('' or 'Py' in the tests).
    :param domain: domain of the join node and of the opset import.
    :return: an ``onnx.ModelProto`` with inputs ``text`` (string, shape
        unspecified), ``sep`` (string, ``[1]``), ``axis`` (int64, ``[1]``)
        and a string output ``customout`` of unspecified shape.
    """
    tensor_proto = onnx_proto.TensorProto
    sources = ['text', 'sep', 'axis']
    renamed = ['identity%d' % k for k in (1, 2, 3)]
    graph_nodes = [helper.make_node('Identity', [src], [dst])
                   for src, dst in zip(sources, renamed)]
    graph_nodes.append(helper.make_node(
        '%sStringJoin' % prefix, renamed, ['customout'], domain=domain))

    graph_inputs = [
        helper.make_tensor_value_info('text', tensor_proto.STRING, None),
        helper.make_tensor_value_info('sep', tensor_proto.STRING, [1]),
        helper.make_tensor_value_info('axis', tensor_proto.INT64, [1]),
    ]
    graph_output = helper.make_tensor_value_info(
        'customout', tensor_proto.STRING, None)
    graph = helper.make_graph(
        graph_nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
61
62
def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringRegexReplace``.

    The inputs ``text``, ``pattern`` and ``rewrite`` each pass through an
    ``Identity`` node (``id1``..``id3``) before the replace node, whose
    result is the graph output ``customout``.

    :param prefix: prepended to the op type ('' or 'Py' in the tests).
    :param domain: domain of the replace node and of the opset import.
    :return: an ``onnx.ModelProto``; ``text`` and ``customout`` are string
        tensors declared as ``[None, 1]``, ``pattern`` and ``rewrite``
        string tensors declared as ``[1]``.
    """
    string_type = onnx_proto.TensorProto.STRING
    wiring = [('text', 'id1'), ('pattern', 'id2'), ('rewrite', 'id3')]
    graph_nodes = [helper.make_node('Identity', [src], [dst])
                   for src, dst in wiring]
    graph_nodes.append(helper.make_node(
        '%sStringRegexReplace' % prefix, [dst for _, dst in wiring],
        ['customout'], domain=domain))

    declared_shapes = {'text': [None, 1], 'pattern': [1], 'rewrite': [1]}
    graph_inputs = [
        helper.make_tensor_value_info(src, string_type, declared_shapes[src])
        for src, _ in wiring]
    graph_output = helper.make_tensor_value_info(
        'customout', string_type, [None, 1])
    graph = helper.make_graph(
        graph_nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
90
91
def _create_test_model_string_to_hash(
        prefix, domain='ai.onnx.contrib', kind=None):
    """Build a model computing a string-hashing custom op.

    :param prefix: prepended to the op type ('' or 'Py' in the tests).
    :param domain: domain of the hashing node and of the opset import.
    :param kind: 'crc32' (``StringToCRC32``, uint32 ``num_buckets`` and
        output), 'hash_bucket' (``StringToHashBucket``, int64) or
        'hash_bucket_fast' (``StringToHashBucketFast``, int64).
    :return: an ``onnx.ModelProto`` with inputs ``text`` (string,
        ``[None, None]``) and ``num_buckets`` (``[1]``) and the output
        ``customout`` (``[None, None]``).
    :raises ValueError: if ``kind`` is none of the values above.
    """
    # (kind, op type, TensorProto element type of num_buckets and output)
    known_kinds = (
        ('crc32', 'StringToCRC32', 'UINT32'),
        ('hash_bucket', 'StringToHashBucket', 'INT64'),
        ('hash_bucket_fast', 'StringToHashBucketFast', 'INT64'),
    )
    for candidate, op_type, type_name in known_kinds:
        if kind == candidate:
            break
    else:
        raise ValueError('Unknown value %r.' % kind)
    elem_type = getattr(onnx_proto.TensorProto, type_name)

    graph_nodes = [
        helper.make_node('Identity', ['text'], ['id1']),
        helper.make_node('Identity', ['num_buckets'], ['id2']),
        helper.make_node('%s%s' % (prefix, op_type), ['id1', 'id2'],
                         ['customout'], domain=domain),
    ]
    graph_inputs = [
        helper.make_tensor_value_info(
            'text', onnx_proto.TensorProto.STRING, [None, None]),
        helper.make_tensor_value_info('num_buckets', elem_type, [1]),
    ]
    graph_output = helper.make_tensor_value_info(
        'customout', elem_type, [None, None])
    graph = helper.make_graph(
        graph_nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
130
131
def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``z = <prefix>StringEqual(x, y)``.

    ``x`` and ``y`` each pass through an ``Identity`` node (``id1``,
    ``id2``) before the comparison node.

    :param prefix: prepended to the op type ('' or 'Py' in the tests).
    :param domain: domain of the comparison node and of the opset import.
    :return: an ``onnx.ModelProto`` with string inputs ``x``/``y`` and a
        bool output ``z``, all declared with shape ``[]``.
    """
    tensor_proto = onnx_proto.TensorProto
    graph_nodes = [
        helper.make_node('Identity', [src], [dst])
        for src, dst in (('x', 'id1'), ('y', 'id2'))
    ]
    graph_nodes.append(helper.make_node(
        '%sStringEqual' % prefix, ['id1', 'id2'], ['z'], domain=domain))

    graph_inputs = [
        helper.make_tensor_value_info(name, tensor_proto.STRING, [])
        for name in ('x', 'y')
    ]
    graph_output = helper.make_tensor_value_info('z', tensor_proto.BOOL, [])
    graph = helper.make_graph(
        graph_nodes, 'test0', graph_inputs, [graph_output])
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
151
152
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Build a model computing ``<prefix>StringSplit``.

    The inputs ``input`` (string), ``delimiter`` (string) and
    ``skip_empty`` (bool) each pass through an ``Identity`` node
    (``id1``..``id3``); the split node produces the graph outputs
    ``indices`` (int64), ``values`` (string) and ``shape`` (int64).
    Every input and output is declared with shape ``[]``.

    :param prefix: prepended to the op type ('' or 'Py' in the tests).
    :param domain: domain of the split node and of the opset import.
    :return: an ``onnx.ModelProto``.
    """
    tensor_proto = onnx_proto.TensorProto
    sources = ['input', 'delimiter', 'skip_empty']
    renamed = ['id1', 'id2', 'id3']
    graph_nodes = [helper.make_node('Identity', [src], [dst])
                   for src, dst in zip(sources, renamed)]
    graph_nodes.append(helper.make_node(
        '%sStringSplit' % prefix, renamed,
        ['indices', 'values', 'shape'], domain=domain))

    input_types = [tensor_proto.STRING, tensor_proto.STRING,
                   tensor_proto.BOOL]
    graph_inputs = [helper.make_tensor_value_info(name, elem_type, [])
                    for name, elem_type in zip(sources, input_types)]
    output_specs = [('indices', tensor_proto.INT64),
                    ('values', tensor_proto.STRING),
                    ('shape', tensor_proto.INT64)]
    graph_outputs = [helper.make_tensor_value_info(name, elem_type, [])
                     for name, elem_type in output_specs]
    graph = helper.make_graph(
        graph_nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
181
182
class TestPythonOpString(unittest.TestCase):
    """Tests of the string custom operators run through onnxruntime.

    Every session registers the custom-op library returned by
    ``get_library_path()``. Op types without a prefix (``*_cc`` tests) are
    expected to come from that library; op types prefixed with ``Py``
    (``*_python`` tests) are the Python implementations registered in
    :meth:`setUpClass`.
    """

    # Python op objects kept so tests can also call them directly;
    # assigned in setUpClass.
    _string_join = None
    _string_to_crc32 = None

    # StringSplit cases shared by the native and Python tests, as
    # (input strings, delimiter, {skip_empty: (indices, values, shape)}).
    _SPLIT_COMMA = (
        ["a,,b", "", "aa,b,c", "dddddd"], ",",
        {True: ([[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]],
                ['a', 'b', 'aa', 'b', 'c', 'dddddd'],
                [4, 3]),
         False: ([[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]],
                 ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'],
                 [4, 3])})
    # Multi-character delimiter: every character is a split point.
    _SPLIT_CHARSET = (
        ["a*b", "a,*b", "aa,b,,c", 'z', "dddddd,", "**"], ",*",
        {True: ([[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1],
                 [2, 2], [3, 0], [4, 0]],
                ['a', 'b', 'a', 'b', 'aa', 'b', 'c', 'z', 'dddddd'],
                [6, 3]),
         False: ([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0],
                  [2, 1], [2, 2], [2, 3], [3, 0], [4, 0], [4, 1],
                  [5, 0], [5, 1], [5, 2]],
                 ['a', 'b', 'a', '', 'b', 'aa', 'b', '', 'c',
                  'z', 'dddddd', '', '', '', ''],
                 [6, 4])})
    # Empty delimiter: strings are split into single characters.
    _SPLIT_NO_DELIMITER = (
        ["a*b", "a,*b"], "",
        dict.fromkeys(
            (True, False),
            ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]],
             ['a', '*', 'b', 'a', ',', '*', 'b'],
             [2, 4])))

    @staticmethod
    def _make_session(onnx_model):
        """Return an InferenceSession for ``onnx_model`` with the custom-op
        library registered."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return _ort.InferenceSession(onnx_model.SerializeToString(), so)

    @staticmethod
    def _tf_raw_op(name):
        """Return ``tensorflow.raw_ops.<name>``, or None when TensorFlow is
        not installed (the TensorFlow cross-checks are then skipped)."""
        try:
            from tensorflow import raw_ops
        except ImportError:
            return None
        return getattr(raw_ops, name)

    @classmethod
    def setUpClass(cls):
        """Register the Python (``Py*``) string operators with onnx_op."""

        @onnx_op(op_type="PyStringUpper",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_upper(x):
            # Upper-case every element, keeping the input shape.
            return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)

        @onnx_op(op_type="PyStringJoin",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_string])
        def string_join(x, sep, axis):
            # Join the strings of x along axis[0] with sep[0]; the joined
            # axis is removed and the other axes keep their order (the
            # semantics test_string_join_cc_3d expects from the native op).
            if sep.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'sep'.".format(sep.shape))
            if axis.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'axis'.".format(axis.shape))
            sp = sep[0]
            ax = axis[0]
            if ax < 0 or ax >= len(x.shape):
                raise RuntimeError(
                    "axis must be in [%r,%r] but is %r" % (
                        0, len(x.shape) - 1, ax))
            if len(x.shape) == 1:
                return np.array([sp.join(x)])
            # moveaxis (not a swap with the last axis) so the remaining
            # axes stay in their original order for rank >= 3.
            moved = np.moveaxis(x, ax, -1)
            res_shape = moved.shape[:-1]
            rows = moved.reshape((-1, moved.shape[-1]))
            # dtype=object: a fixed-width unicode dtype copied from x
            # (e.g. '<U2' on a direct call with numpy strings) would
            # silently truncate the joined strings.
            res = np.empty(rows.shape[0], dtype=object)
            for i, row in enumerate(rows):
                res[i] = sp.join(row)
            return res.reshape(res_shape)

        @onnx_op(op_type="PyStringRegexReplace",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_replace(x, pattern, rewrite):
            # Apply re.sub(pattern[0], rewrite[0], .) to every element.
            if pattern.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'pattern'.".format(pattern.shape))
            if rewrite.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'rewrite'.".format(rewrite.shape))
            reg = re.compile(pattern[0])
            res = np.array([reg.sub(rewrite[0], t) for t in x.ravel()])
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToCRC32",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_uint32],
                 outputs=[PyCustomOpDef.dt_uint32])
        def string_to_crc32(x, num_buckets):
            # CRC32 of each ISO-8859-15 encoded string, modulo num_buckets[0].
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array(
                [crc32(s.encode('iso-8859-15')) % nb for s in x.ravel()])
            # Explicit cast so the result matches the declared uint32
            # output (also for empty input, where np.array([]) is float).
            return res.reshape(x.shape).astype(np.uint32)

        @onnx_op(op_type="PyStringToHashBucket",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_int64])
        def string_to_hash_bucket(x, num_buckets):
            # NOTE(review): hash_64(..., True) — the expected buckets in
            # test_string_to_hash_bucket_python equal those of the native
            # StringToHashBucketFast test; confirm the flag's meaning.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array([hash_64(s, nb, True) for s in x.ravel()])
            return res.reshape(x.shape).astype(np.int64)

        @onnx_op(op_type="PyStringEqual",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_bool])
        def string_equal(x, y):
            # Element-wise equality with numpy broadcasting.
            return x == y

        @onnx_op(op_type="PyStringSplit",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_bool],
                 outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                          PyCustomOpDef.dt_int64])
        def string_split(text_input, delimiter, skip_empty):
            # Sparse split of a 1-D string tensor with the semantics the
            # native StringSplit tests expect: every character of the
            # delimiter is a split point, an empty delimiter splits into
            # single characters, and empty input strings yield no token.
            # Returns (indices [n, 2], values [n], dense shape [2]).
            if delimiter.shape != (1, ):
                raise RuntimeError(
                    "delimiter must be a single element tensor.")
            if skip_empty.shape != (1, ):
                raise RuntimeError(
                    "skip_empty must be a single element tensor.")
            if len(text_input.shape) != 1:
                raise RuntimeError("input must be a one dimension tensor.")
            sep = delimiter[0]
            skip = skip_empty[0]
            splitter = re.compile('[%s]' % re.escape(sep)) if sep else None
            texts = []
            indices = []
            max_split = 0
            for row, text in enumerate(text_input):
                if not text:
                    continue
                tokens = splitter.split(text) if splitter else list(text)
                if skip:
                    tokens = [t for t in tokens if t]
                texts.extend(tokens)
                max_split = max(max_split, len(tokens))
                indices.extend((row, i) for i in range(len(tokens)))
            # Explicit reshape/dtype keep the output types right even when
            # no token is produced (np.array([]) is a 1-D float array).
            return (np.array(indices, dtype=np.int64).reshape((-1, 2)),
                    np.array(texts, dtype=np.str_),
                    np.array([len(text_input), max_split], dtype=np.int64))

        cls._string_join = string_join
        cls._string_to_crc32 = string_to_crc32

    def test_check_types(self):
        def_list = set(dir(PyCustomOpDef))
        type_list = [
            # 'dt_bfloat16',
            'dt_bool',
            'dt_complex128',
            'dt_complex64',
            'dt_double',
            'dt_float',
            'dt_float16',
            'dt_int16',
            'dt_int32',
            'dt_int64',
            'dt_int8',
            'dt_string',
            'dt_uint16',
            'dt_uint32',
            'dt_uint64',
            'dt_uint8']
        for t in type_list:
            self.assertIn(t, def_list)

    def test_string_upper_cc(self):
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._make_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_cc_accent(self):
        onnx_model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(onnx_model))
        sess = self._make_session(onnx_model)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABCé"]]).tolist())

    def test_string_upper_python(self):
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._make_session(onnx_model)
        input_1 = np.array([["Abc"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_python_accent(self):
        onnx_model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(onnx_model))
        sess = self._make_session(onnx_model)
        input_1 = np.array([["Abcé"]])
        txout = sess.run(None, {'input_1': input_1})
        self.assertEqual(txout[0].tolist(),
                         np.array([["ABCé".upper()]]).tolist())

    def test_string_join_python(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        self.assertEqual(text.shape, (2, 3))
        sep = np.array([";"])
        for ax, expected in [(1, ["a;b;c", "aa;bb;"]),
                             (0, ['a;aa', 'b;bb', 'c;'])]:
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                # A direct call on a fixed-width numpy string array must
                # not truncate the joined strings.
                direct = TestPythonOpString._string_join(text, sep, axis)
                self.assertEqual(direct.tolist(), expected)
                txout = sess.run(
                    None, {'text': text, 'sep': sep, 'axis': axis})
                self.assertEqual(txout[0].tolist(), expected)

    def test_string_join_python_3d(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])]).reshape((2, 3, 1))
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        expected = [['a;b;c'], ['aa;bb;']]
        direct = TestPythonOpString._string_join(text, sep, axis)
        self.assertEqual(direct.tolist(), expected)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].tolist(), expected)

    def test_string_join_python_3d_all_axes(self):
        # Same cases as test_string_join_cc_3d: the Python op must remove
        # the joined axis and keep the remaining axes in order.
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape((
            2, 2, 2))
        sep = np.array([";"])
        cases = [
            (2, [['a;b', 'c;d'], ['e;f', 'g;h']]),
            (1, [['a;c', 'b;d'], ['e;g', 'f;h']]),
            (0, [['a;e', 'b;f'], ['c;g', 'd;h']]),
        ]
        for ax, expected in cases:
            with self.subTest(axis=ax):
                axis = np.array([ax], dtype=np.int64)
                direct = TestPythonOpString._string_join(text, sep, axis)
                self.assertEqual(direct.tolist(), expected)
                txout = sess.run(
                    None, {'text': text, 'sep': sep, 'axis': axis})
                self.assertEqual(txout[0].tolist(), expected)

    def test_string_join_python_1d(self):
        onnx_model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(txout[0].shape, (1, ))
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())

    def test_string_join_cc_1d(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array(["a", "b", "cc"])
        sep = np.array([";"])
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc_3d(self):
        onnx_model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape((
            2, 2, 2))
        sep = np.array([";"])
        axis = np.array([2], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;b', 'c;d'], ['e;f', 'g;h']]).tolist())
        axis = np.array([1], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;c', 'b;d'], ['e;g', 'f;h']]).tolist())
        axis = np.array([0], dtype=np.int64)
        txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            txout[0].tolist(),
            np.array([['a;e', 'b;f'], ['c;g', 'd;h']]).tolist())

    def test_string_replace_cc(self):
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._make_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_cc_x2(self):
        onnx_model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(onnx_model))
        sess = self._make_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject* py_\1(void) {'])
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject* py_myfunc(void) {'],
               ['static PyObject* py_dummy(void) {' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python(self):
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._make_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():']])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{']]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_replace_python_x2(self):
        onnx_model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(onnx_model))
        sess = self._make_session(onnx_model)
        pattern = np.array([r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):'])
        rewrite = np.array([r'static PyObject*\npy_\1(void)\n{'])
        text = np.array([['def myfunc():'], ['def dummy():' * 2]])
        txout = sess.run(
            None, {'text': text, 'pattern': pattern, 'rewrite': rewrite})
        exp = [['static PyObject*\npy_myfunc(void)\n{'],
               ['static PyObject*\npy_dummy(void)\n{' * 2]]
        self.assertEqual(exp, txout[0].tolist())

    def test_string_to_crc32_python(self):
        onnx_model = _create_test_model_string_to_hash('Py', kind='crc32')
        self.assertIn('op_type: "PyStringToCRC32"', str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array([["abc", "abcdé"], ["$$^l!%*ù", ""]])
        num_buckets = np.array([44], dtype=np.uint32)
        res = self._string_to_crc32(text, num_buckets)
        self.assertEqual(res.shape, text.shape)
        exp = np.array([[10, 38], [29, 0]], dtype=np.uint32)
        self.assertEqual(exp.tolist(), res.tolist())
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_cc(self):
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket')
        self.assertIn('op_type: "StringToHashBucket"', str(onnx_model))
        sess = self._make_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_op = self._tf_raw_op('StringToHashBucket')
        if tf_op is not None:
            tfres = tf_op(string_tensor=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[15, 11], [10, 21], [20, 21]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_fast_cc(self):
        onnx_model = _create_test_model_string_to_hash(
            '', kind='hash_bucket_fast')
        self.assertIn('op_type: "StringToHashBucketFast"', str(onnx_model))
        sess = self._make_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        tf_op = self._tf_raw_op('StringToHashBucketFast')
        if tf_op is not None:
            tfres = tf_op(input=text, num_buckets=num_buckets[0])
            self.assertEqual(tfres.shape, txout[0].shape)
            self.assertEqual(tfres.numpy().tolist(), txout[0].tolist())
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def test_string_to_hash_bucket_python(self):
        onnx_model = _create_test_model_string_to_hash(
            'Py', kind='hash_bucket')
        self.assertIn('op_type: "PyStringToHashBucket"', str(onnx_model))
        sess = self._make_session(onnx_model)
        raw = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(raw).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        # NOTE(review): these are the buckets expected from the native
        # StringToHashBucketFast op (see test_string_to_hash_bucket_fast_cc),
        # not StringToHashBucket; the Python op calls hash_64(..., True).
        exp = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        txout = sess.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())

    def enumerate_matrix_couples(self):
        """Yield (a, b) pairs of random digit-string arrays for ranks 1 to 4.

        ``b`` is either ``a`` itself or a fresh array with the shape of
        ``a`` where one or two axes are reduced to 1, so ``a == b``
        broadcasts; used by the StringEqual tests.
        """
        for i in range(1, 5):
            shape = (3,) * i
            # ``str`` rather than the ``np.str`` alias, which was
            # deprecated in NumPy 1.20 and removed in NumPy 1.24.
            a = (np.random.rand(*shape) * 10).astype(np.int32).astype(str)
            yield a, a
            for j in range(i):
                shape2 = list(shape)
                shape2[j] = 1
                b = (np.random.rand(*shape2) * 10).astype(
                    np.int32).astype(str)
                yield a, b
                for k in range(j + 1, i):
                    shape3 = list(shape2)
                    shape3[k] = 1
                    b = (np.random.rand(*shape3) * 10).astype(
                        np.int32).astype(str)
                    yield a, b

    def test_string_equal_python(self):
        onnx_model = _create_test_model_string_equal('Py')
        self.assertIn('op_type: "PyStringEqual"', str(onnx_model))
        sess = self._make_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def test_string_equal_cc(self):
        onnx_model = _create_test_model_string_equal('')
        self.assertIn('op_type: "StringEqual"', str(onnx_model))
        sess = self._make_session(onnx_model)

        for x, y in self.enumerate_matrix_couples():
            txout = sess.run(None, {'x': x, 'y': y})
            self.assertEqual(txout[0].tolist(), (x == y).tolist())
            txout = sess.run(None, {'x': y, 'y': x})
            self.assertEqual(txout[0].tolist(), (y == x).tolist())

    def _check_string_split(self, prefix, case, tf_check=False):
        """Run ``<prefix>StringSplit`` on ``case`` for both skip_empty values.

        :param prefix: op-type prefix ('' or 'Py').
        :param case: ``(inputs, delimiter, expected)`` where ``expected``
            maps each skip_empty value to the expected
            ``(indices, values, shape)`` as Python lists.
        :param tf_check: also compare with ``tensorflow.raw_ops.StringSplit``
            (given the same delimiter) when TensorFlow is installed.
        """
        inputs, sep, expected = case
        onnx_model = _create_test_model_string_split(prefix)
        self.assertIn('op_type: "%sStringSplit"' % prefix, str(onnx_model))
        sess = self._make_session(onnx_model)
        text = np.array(inputs)
        delimiter = np.array([sep])
        tf_string_split = self._tf_raw_op('StringSplit') if tf_check else None

        for skip in [True, False]:
            with self.subTest(skip=skip):
                txout = sess.run(
                    None, {'input': text, 'delimiter': delimiter,
                           'skip_empty': np.array([skip])})

                if tf_string_split is not None:
                    tfres = tf_string_split(
                        input=text, delimiter=sep, skip_empty=skip)
                    self.assertEqual(
                        [s.decode() for s in tfres[1].numpy().tolist()],
                        txout[1].tolist())
                    self.assertEqual(
                        tfres[0].numpy().tolist(), txout[0].tolist())
                    self.assertEqual(
                        tfres[2].numpy().tolist(), txout[2].tolist())

                exp_indices, exp_values, exp_shape = expected[skip]
                self.assertEqual(exp_indices, txout[0].tolist())
                self.assertEqual(exp_values, txout[1].tolist())
                self.assertEqual(exp_shape, txout[2].tolist())

    def test_string_split_python(self):
        self._check_string_split('Py', self._SPLIT_COMMA)

    def test_string_split_python_sep2(self):
        self._check_string_split('Py', self._SPLIT_CHARSET)

    def test_string_split_python_sep0(self):
        self._check_string_split('Py', self._SPLIT_NO_DELIMITER)

    def test_string_split_cc(self):
        self._check_string_split('', self._SPLIT_COMMA, tf_check=True)

    def test_string_split_cc_sep2(self):
        self._check_string_split('', self._SPLIT_CHARSET, tf_check=True)

    def test_string_split_cc_sep0(self):
        self._check_string_split('', self._SPLIT_NO_DELIMITER, tf_check=True)
842
843
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes above.
    unittest.main()