microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
c7b2f864c65a0269e909738bc6b704f1d1db33a0

Branches

Tags

  • No tags available.
0 Branches 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/test_string_ops.py

845 lines · mode preview

# coding: utf-8
import unittest
import re
from binascii import crc32
import numpy as np
from onnx import helper, onnx_pb as onnx_proto
import onnxruntime as _ort
from ortcustomops import (
    onnx_op, PyCustomOpDef,
    get_library_path as _get_library_path,
    hash_64)

# Bucket count fed as ``num_buckets`` to the StringToHashBucket /
# StringToHashBucketFast tests below.
NUM_BUCKETS = 23


def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
    """Build ``input_1 -> Identity -> <prefix>StringUpper -> customout``.

    :param prefix: prepended to the op type; ``''`` selects ``StringUpper``
        from the custom-ops library, ``'Py'`` selects ``PyStringUpper``.
    :param domain: domain of the StringUpper node; it is also the only
        opset declared on the model (version 1).
    :return: ``ModelProto`` mapping the 2-D string input ``input_1`` to the
        2-D string output ``customout`` (dimensions left unknown).
    """
    # Build the node list directly rather than through slice assignment.
    nodes = [
        helper.make_node('Identity', ['input_1'], ['identity1']),
        helper.make_node('%sStringUpper' % prefix,
                         ['identity1'], ['customout'],
                         domain=domain),
    ]

    input0 = helper.make_tensor_value_info(
        'input_1', onnx_proto.TensorProto.STRING, [None, None])
    output0 = helper.make_tensor_value_info(
        'customout', onnx_proto.TensorProto.STRING, [None, None])

    graph = helper.make_graph(nodes, 'test0', [input0], [output0])
    model = helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
    return model


def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
    """Return a model that feeds ``text``, ``sep``, ``axis`` to StringJoin.

    Each graph input passes through its own ``Identity`` node before reaching
    the ``<prefix>StringJoin`` node of ``domain``. ``text`` and the output
    ``customout`` have no declared shape; ``sep`` (string) and ``axis``
    (int64) are declared with shape ``[1]``.
    """
    renames = (('text', 'identity1'), ('sep', 'identity2'),
               ('axis', 'identity3'))
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in renames]
    nodes.append(helper.make_node(
        '%sStringJoin' % prefix, [dst for _, dst in renames],
        ['customout'], domain=domain))

    str_t = onnx_proto.TensorProto.STRING
    graph_inputs = [
        helper.make_tensor_value_info('text', str_t, None),
        helper.make_tensor_value_info('sep', str_t, [1]),
        helper.make_tensor_value_info(
            'axis', onnx_proto.TensorProto.INT64, [1]),
    ]
    graph_outputs = [helper.make_tensor_value_info('customout', str_t, None)]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    opset = helper.make_operatorsetid(domain, 1)
    return helper.make_model(graph, opset_imports=[opset])


def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib'):
    """Return a model running ``<prefix>StringRegexReplace`` on ``text``.

    ``text`` (shape ``[None, 1]``), ``pattern`` and ``rewrite`` (both shape
    ``[1]``) each go through an ``Identity`` node and then into the regex
    replace node; its result is the ``customout`` output (shape ``[None, 1]``).
    """
    links = (('text', 'id1'), ('pattern', 'id2'), ('rewrite', 'id3'))
    nodes = [helper.make_node('Identity', [src], [dst]) for src, dst in links]
    nodes.append(helper.make_node(
        '%sStringRegexReplace' % prefix, [dst for _, dst in links],
        ['customout'], domain=domain))

    str_t = onnx_proto.TensorProto.STRING
    graph_inputs = [
        helper.make_tensor_value_info('text', str_t, [None, 1]),
        helper.make_tensor_value_info('pattern', str_t, [1]),
        helper.make_tensor_value_info('rewrite', str_t, [1]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('customout', str_t, [None, 1])]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    opset = helper.make_operatorsetid(domain, 1)
    return helper.make_model(graph, opset_imports=[opset])


def _create_test_model_string_to_hash(
        prefix, domain='ai.onnx.contrib', kind=None):
    """Return a model hashing the 2-D string input ``text`` into buckets.

    :param kind: ``'crc32'`` (StringToCRC32, uint32 ``num_buckets`` and
        output), ``'hash_bucket'`` (StringToHashBucket, int64) or
        ``'hash_bucket_fast'`` (StringToHashBucketFast, int64).
    :raises ValueError: for any other ``kind``.
    """
    if kind == 'crc32':
        op_type, elem_type = 'StringToCRC32', onnx_proto.TensorProto.UINT32
    elif kind == 'hash_bucket':
        op_type, elem_type = ('StringToHashBucket',
                              onnx_proto.TensorProto.INT64)
    elif kind == 'hash_bucket_fast':
        op_type, elem_type = ('StringToHashBucketFast',
                              onnx_proto.TensorProto.INT64)
    else:
        raise ValueError('Unknown value %r.' % kind)

    # ``num_buckets`` and the output share one element type for every kind.
    nodes = [
        helper.make_node('Identity', ['text'], ['id1']),
        helper.make_node('Identity', ['num_buckets'], ['id2']),
        helper.make_node(
            '%s%s' % (prefix, op_type), ['id1', 'id2'],
            ['customout'], domain=domain),
    ]

    graph_inputs = [
        helper.make_tensor_value_info(
            'text', onnx_proto.TensorProto.STRING, [None, None]),
        helper.make_tensor_value_info('num_buckets', elem_type, [1]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('customout', elem_type, [None, None])]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    opset = helper.make_operatorsetid(domain, 1)
    return helper.make_model(graph, opset_imports=[opset])


def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
    """Return a model computing ``z = <prefix>StringEqual(x, y)``.

    ``x`` and ``y`` are string inputs, ``z`` a bool output; all three are
    declared with an empty shape list ``[]``.
    NOTE(review): ONNX may read ``[]`` as rank 0 while the tests feed N-D
    arrays -- confirm this is intended.
    """
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in (('x', 'id1'), ('y', 'id2'))]
    nodes.append(helper.make_node(
        '%sStringEqual' % prefix, ['id1', 'id2'], ['z'], domain=domain))

    str_t = onnx_proto.TensorProto.STRING
    graph_inputs = [helper.make_tensor_value_info(name, str_t, [])
                    for name in ('x', 'y')]
    graph_outputs = [
        helper.make_tensor_value_info('z', onnx_proto.TensorProto.BOOL, [])]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    opset = helper.make_operatorsetid(domain, 1)
    return helper.make_model(graph, opset_imports=[opset])


def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Return a model running ``<prefix>StringSplit`` on three inputs.

    Inputs ``input`` (string), ``delimiter`` (string) and ``skip_empty``
    (bool) each pass through an ``Identity`` node. The split node yields
    ``indices`` (int64), ``values`` (string) and ``shape`` (int64). Every
    value info is declared with an empty shape list ``[]``.
    """
    links = (('input', 'id1'), ('delimiter', 'id2'), ('skip_empty', 'id3'))
    nodes = [helper.make_node('Identity', [src], [dst]) for src, dst in links]
    nodes.append(helper.make_node(
        '%sStringSplit' % prefix, [dst for _, dst in links],
        ['indices', 'values', 'shape'], domain=domain))

    tp = onnx_proto.TensorProto
    graph_inputs = [
        helper.make_tensor_value_info(name, elem, [])
        for name, elem in (('input', tp.STRING), ('delimiter', tp.STRING),
                           ('skip_empty', tp.BOOL))]
    graph_outputs = [
        helper.make_tensor_value_info(name, elem, [])
        for name, elem in (('indices', tp.INT64), ('values', tp.STRING),
                           ('shape', tp.INT64))]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    opset = helper.make_operatorsetid(domain, 1)
    return helper.make_model(graph, opset_imports=[opset])


class TestPythonOpString(unittest.TestCase):

    # Python op implementations stored by setUpClass so that tests can also
    # call them directly, outside an InferenceSession.
    _string_join = None
    _string_to_crc32 = None

    @classmethod
    def setUpClass(cls):
        # Register the Python (``Py*``) custom-op implementations under test.
        # Each nested function is registered through ``onnx_op`` with its op
        # type and input/output element types. ``string_join`` and
        # ``string_to_crc32`` are also stored on the class so tests can call
        # them directly.

        @onnx_op(op_type="PyStringUpper",
                 inputs=[PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_upper(x):
            # Upper-case every element, preserving the input shape.
            return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)

        @onnx_op(op_type="PyStringJoin",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_string])
        def string_join(x, sep, axis):
            # Join the strings of ``x`` along ``axis[0]`` with ``sep[0]``.
            # The joined axis is removed from the output shape; a 1-D input
            # yields a one-element 1-D array.
            if sep.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'sep'.".format(sep.shape))
            if axis.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'axis'.".format(axis.shape))
            sp = sep[0]
            ax = int(axis[0])
            if ax < 0 or ax >= len(x.shape):
                # Valid axes are 0 .. ndim - 1; negative axes are rejected.
                raise RuntimeError(
                    "axis must be in [%r,%r] but is %r" % (
                        0, len(x.shape) - 1, ax))
            if len(x.shape) == 1:
                return np.array([sp.join(x)])
            # Move the joined axis to the end while keeping the other axes
            # in their original order. Swapping it with the last axis would
            # permute the result for ndim >= 3 (e.g. joining along axis 0).
            x2 = np.moveaxis(x, ax, -1)
            res_shape = x2.shape[:-1]
            # Explicit row count so an empty joined axis still reshapes.
            x2 = x2.reshape((int(np.prod(res_shape)), x2.shape[-1]))
            # object dtype: a fixed-width dtype copied from ``x`` (e.g. '<U2')
            # would silently truncate the longer joined strings.
            res = np.empty(x2.shape[0], dtype=object)
            for i, row in enumerate(x2):
                res[i] = sp.join(row)
            return res.reshape(res_shape)

        @onnx_op(op_type="PyStringRegexReplace",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_string])
        def string_replace(x, pattern, rewrite):
            # Apply ``re.sub(pattern[0], rewrite[0], .)`` to every element.
            if pattern.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'pattern'.".format(pattern.shape))
            if rewrite.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'rewrite'.".format(rewrite.shape))
            reg = re.compile(pattern[0])
            res = np.array([reg.sub(rewrite[0], t) for t in x.ravel()])
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToCRC32",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_uint32],
                 outputs=[PyCustomOpDef.dt_uint32])
        def string_to_crc32(x, num_buckets):
            # crc32 of each ISO-8859-15 encoded element, modulo
            # ``num_buckets[0]``.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            # Explicit uint32 matches the declared output type (and keeps an
            # integer dtype for empty input).
            res = np.array(
                [crc32(s.encode('iso-8859-15')) % nb for s in x.ravel()],
                dtype=np.uint32)
            return res.reshape(x.shape)

        @onnx_op(op_type="PyStringToHashBucket",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int64],
                 outputs=[PyCustomOpDef.dt_int64])
        def string_to_hash_bucket(x, num_buckets):
            # Bucket each element with ``hash_64(s, num_buckets[0], True)``.
            if num_buckets.shape != (1, ):
                raise RuntimeError(
                    "Unexpected shape {} for 'num_buckets'.".format(
                        num_buckets.shape))
            nb = num_buckets[0]
            res = np.array([hash_64(s, nb, True) for s in x.ravel()])
            return res.reshape(x.shape).astype(np.int64)

        @onnx_op(op_type="PyStringEqual",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
                 outputs=[PyCustomOpDef.dt_bool])
        def string_equal(x, y):
            # Element-wise (numpy-broadcast) equality.
            return x == y

        @onnx_op(op_type="PyStringSplit",
                 inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                         PyCustomOpDef.dt_bool],
                 outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                          PyCustomOpDef.dt_int64])
        def string_split(input, delimiter, skip_empty):
            # Split each element of the 1-D ``input`` on ``delimiter[0]`` and
            # return a sparse triple: (row, position) indices of shape (n, 2),
            # the n pieces, and the dense shape [len(input), max pieces].
            # Empty input strings contribute no pieces.
            if delimiter.shape != (1, ):
                raise RuntimeError(
                    "delimiter must be a single element tensor.")
            if skip_empty.shape != (1, ):
                raise RuntimeError(
                    "skip_empty must be a single element tensor.")
            if len(input.shape) != 1:
                raise RuntimeError("input must be a one dimension tensor.")
            delimiter = delimiter[0]
            skip_empty = skip_empty[0]
            texts = []
            indices = []
            max_split = 0
            for row, text in enumerate(input):
                if not text:
                    continue
                res = text.split(delimiter)
                if skip_empty:
                    res = [t for t in res if t]
                texts.extend(res)
                max_split = max(max_split, len(res))
                indices.extend((row, i) for i in range(len(res)))
            # reshape/dtype keep (0, 2) int64 indices and a string array
            # when nothing was produced.
            return (np.array(indices, dtype=np.int64).reshape((-1, 2)),
                    np.array(texts, dtype=np.str_),
                    np.array([len(input), max_split], dtype=np.int64))

        cls._string_join = string_join
        cls._string_to_crc32 = string_to_crc32

    def test_check_types(self):
        # PyCustomOpDef must expose a dt_* constant for each element type
        # below ('dt_bfloat16' is not checked).
        available = set(dir(PyCustomOpDef))
        expected = (
            'dt_bool', 'dt_complex128', 'dt_complex64', 'dt_double',
            'dt_float', 'dt_float16', 'dt_int16', 'dt_int32', 'dt_int64',
            'dt_int8', 'dt_string', 'dt_uint16', 'dt_uint32', 'dt_uint64',
            'dt_uint8',
        )
        for name in expected:
            self.assertIn(name, available)

    def test_string_upper_cc(self):
        # Library StringUpper on a 1x1 ASCII input.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        result = session.run(None, {'input_1': np.array([["Abc"]])})
        self.assertEqual(result[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_cc_accent(self):
        # Library StringUpper: the expected output keeps the accented
        # character unchanged.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_upper('')
        self.assertIn('op_type: "StringUpper"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        result = session.run(None, {'input_1': np.array([["Abcé"]])})
        self.assertEqual(result[0].tolist(), np.array([["ABCé"]]).tolist())

    def test_string_upper_python(self):
        # Python PyStringUpper on a 1x1 ASCII input.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        result = session.run(None, {'input_1': np.array([["Abc"]])})
        self.assertEqual(result[0].tolist(), np.array([["ABC"]]).tolist())

    def test_string_upper_python_accent(self):
        # Python PyStringUpper follows str.upper(), accents included.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_upper('Py')
        self.assertIn('op_type: "PyStringUpper"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        result = session.run(None, {'input_1': np.array([["Abcé"]])})
        expected = np.array([["ABCé".upper()]]).tolist()
        self.assertEqual(result[0].tolist(), expected)

    def test_string_join_python(self):
        # Python PyStringJoin on a 2x3 matrix, joined along axis 1 then 0.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        self.assertEqual(text.shape, (2, 3))
        sep = np.array([";"])
        cases = (
            (1, ["a;b;c", "aa;bb;"]),
            (0, ['a;aa', 'b;bb', 'c;']),
        )
        for ax, expected in cases:
            axis = np.array([ax], dtype=np.int64)
            # Direct call of the implementation (result not checked).
            TestPythonOpString._string_join(text, sep, axis)
            feeds = {'text': text, 'sep': sep, 'axis': axis}
            outputs = session.run(None, feeds)
            self.assertEqual(outputs[0].tolist(), np.array(expected).tolist())

    def test_string_join_python_3d(self):
        # Python PyStringJoin on a (2, 3, 1) tensor joined along axis 1.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        rows = [np.array([["a", "b", "c"]]), np.array([["aa", "bb", ""]])]
        text = np.vstack(rows).reshape((2, 3, 1))
        sep = np.array([";"])
        axis = np.array([1], dtype=np.int64)
        # Direct call of the implementation (result not checked).
        TestPythonOpString._string_join(text, sep, axis)
        outputs = session.run(None, {'text': text, 'sep': sep, 'axis': axis})
        expected = np.array([['a;b;c'], ['aa;bb;']]).tolist()
        self.assertEqual(outputs[0].tolist(), expected)

    def test_string_join_python_1d(self):
        # Python PyStringJoin on a 1-D input returns a single-element array.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('Py')
        self.assertIn('op_type: "PyStringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array(["a", "b", "cc"]),
            'sep': np.array([";"]),
            'axis': np.array([0], dtype=np.int64),
        }
        joined = session.run(None, feeds)[0]
        self.assertEqual(joined.shape, (1, ))
        self.assertEqual(joined.tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc(self):
        # Library StringJoin on a 2x3 matrix, joined along axis 1 then 0.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        text = np.vstack([np.array([["a", "b", "c"]]),
                          np.array([["aa", "bb", ""]])])
        sep = np.array([";"])
        cases = (
            (1, ["a;b;c", "aa;bb;"]),
            (0, ['a;aa', 'b;bb', 'c;']),
        )
        for ax, expected in cases:
            axis = np.array([ax], dtype=np.int64)
            feeds = {'text': text, 'sep': sep, 'axis': axis}
            outputs = session.run(None, feeds)
            self.assertEqual(outputs[0].tolist(), np.array(expected).tolist())

    def test_string_join_cc_1d(self):
        # Library StringJoin on a 1-D input.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array(["a", "b", "cc"]),
            'sep': np.array([";"]),
            'axis': np.array([0], dtype=np.int64),
        }
        joined = session.run(None, feeds)[0]
        self.assertEqual(joined.tolist(), np.array(["a;b;cc"]).tolist())

    def test_string_join_cc_3d(self):
        # Library StringJoin on a 2x2x2 tensor, joined along axes 2, 1, 0.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_join('')
        self.assertIn('op_type: "StringJoin"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
        text = np.array(letters).reshape((2, 2, 2))
        sep = np.array([";"])
        cases = (
            (2, [['a;b', 'c;d'], ['e;f', 'g;h']]),
            (1, [['a;c', 'b;d'], ['e;g', 'f;h']]),
            (0, [['a;e', 'b;f'], ['c;g', 'd;h']]),
        )
        for ax, expected in cases:
            axis = np.array([ax], dtype=np.int64)
            feeds = {'text': text, 'sep': sep, 'axis': axis}
            outputs = session.run(None, feeds)
            self.assertEqual(outputs[0].tolist(), np.array(expected).tolist())

    def test_string_replace_cc(self):
        # Library StringRegexReplace with a capture-group back-reference.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array([['def myfunc():'], ['def dummy():']]),
            'pattern': np.array(
                [r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):']),
            'rewrite': np.array([r'static PyObject* py_\1(void) {']),
        }
        replaced = session.run(None, feeds)[0]
        expected = [['static PyObject* py_myfunc(void) {'],
                    ['static PyObject* py_dummy(void) {']]
        self.assertEqual(expected, replaced.tolist())

    def test_string_replace_cc_x2(self):
        # Library StringRegexReplace replaces every match, not just the first.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_replace('')
        self.assertIn('op_type: "StringRegexReplace"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array([['def myfunc():'], ['def dummy():' * 2]]),
            'pattern': np.array(
                [r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):']),
            'rewrite': np.array([r'static PyObject* py_\1(void) {']),
        }
        replaced = session.run(None, feeds)[0]
        expected = [['static PyObject* py_myfunc(void) {'],
                    ['static PyObject* py_dummy(void) {' * 2]]
        self.assertEqual(expected, replaced.tolist())

    def test_string_replace_python(self):
        # Python PyStringRegexReplace with a back-reference and newlines in
        # the rewrite template.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array([['def myfunc():'], ['def dummy():']]),
            'pattern': np.array(
                [r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):']),
            'rewrite': np.array([r'static PyObject*\npy_\1(void)\n{']),
        }
        replaced = session.run(None, feeds)[0]
        expected = [['static PyObject*\npy_myfunc(void)\n{'],
                    ['static PyObject*\npy_dummy(void)\n{']]
        self.assertEqual(expected, replaced.tolist())

    def test_string_replace_python_x2(self):
        # Python PyStringRegexReplace replaces every match in a string.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_replace('Py')
        self.assertIn('op_type: "PyStringRegexReplace"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        feeds = {
            'text': np.array([['def myfunc():'], ['def dummy():' * 2]]),
            'pattern': np.array(
                [r'def\s+([a-zA-Z_][a-zA-Z_0-9]*)\s*\(\s*\):']),
            'rewrite': np.array([r'static PyObject*\npy_\1(void)\n{']),
        }
        replaced = session.run(None, feeds)[0]
        expected = [['static PyObject*\npy_myfunc(void)\n{'],
                    ['static PyObject*\npy_dummy(void)\n{' * 2]]
        self.assertEqual(expected, replaced.tolist())

    def test_string_to_crc32_python(self):
        # Python PyStringToCRC32, both called directly and through ORT,
        # against the same expected buckets.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_to_hash('Py', kind='crc32')
        self.assertIn('op_type: "PyStringToCRC32"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        text = np.array([["abc", "abcdé"], ["$$^l!%*ù", ""]])
        num_buckets = np.array([44], dtype=np.uint32)
        direct = self._string_to_crc32(text, num_buckets)
        self.assertEqual(direct.shape, text.shape)
        expected = np.array([[10, 38], [29, 0]], dtype=np.uint32).tolist()
        self.assertEqual(expected, direct.tolist())
        outputs = session.run(
            None, {'text': text, 'num_buckets': num_buckets})
        self.assertEqual(expected, outputs[0].tolist())

    def test_string_to_hash_bucket_cc(self):
        # Library StringToHashBucket against fixed buckets and, when
        # TensorFlow is importable, against tf.raw_ops.StringToHashBucket.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_to_hash('', kind='hash_bucket')
        self.assertIn('op_type: "StringToHashBucket"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        words = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(words).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        actual = session.run(
            None, {'text': text, 'num_buckets': num_buckets})[0]
        try:
            from tensorflow.raw_ops import StringToHashBucket
        except ImportError:
            pass
        else:
            reference = StringToHashBucket(
                string_tensor=text, num_buckets=num_buckets[0])
            self.assertEqual(reference.shape, actual.shape)
            self.assertEqual(reference.numpy().tolist(), actual.tolist())
        expected = np.array([[15, 11], [10, 21], [20, 21]], dtype=np.int64)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected.tolist(), actual.tolist())

    def test_string_to_hash_bucket_fast_cc(self):
        # Library StringToHashBucketFast against fixed buckets and, when
        # TensorFlow is importable, against tf.raw_ops.StringToHashBucketFast.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_to_hash(
            '', kind='hash_bucket_fast')
        self.assertIn('op_type: "StringToHashBucketFast"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        words = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        text = np.array(words).reshape((3, 2))
        num_buckets = np.array([NUM_BUCKETS], dtype=np.int64)
        actual = session.run(
            None, {'text': text, 'num_buckets': num_buckets})[0]
        try:
            from tensorflow.raw_ops import StringToHashBucketFast
        except ImportError:
            pass
        else:
            reference = StringToHashBucketFast(
                input=text, num_buckets=num_buckets[0])
            self.assertEqual(reference.shape, actual.shape)
            self.assertEqual(reference.numpy().tolist(), actual.tolist())
        expected = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected.tolist(), actual.tolist())

    def test_string_to_hash_bucket_python(self):
        # Python PyStringToHashBucket against fixed expected buckets.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_to_hash('Py', kind='hash_bucket')
        self.assertIn('op_type: "PyStringToHashBucket"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        words = ["abc", "abcdé", "$$^l!%*ù", "", "a", "A"]
        feeds = {
            'text': np.array(words).reshape((3, 2)),
            'num_buckets': np.array([NUM_BUCKETS], dtype=np.int64),
        }
        expected = np.array([[9, 17], [4, 21], [14, 12]], dtype=np.int64)
        actual = session.run(None, feeds)[0]
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected.tolist(), actual.tolist())

    def enumerate_matrix_couples(self):
        for i in range(1, 5):
            shape = (3,) * i
            a = (np.random.rand(*shape) * 10).astype(np.int32).astype(np.str)
            yield a, a
            for j in range(i):
                shape2 = list(shape)
                shape2[j] = 1
                b = (np.random.rand(*shape2) * 10).astype(
                    np.int32).astype(np.str)
                yield a, b
                for k in range(j+1, i):
                    shape3 = list(shape2)
                    shape3[k] = 1
                    b = (np.random.rand(*shape3) * 10).astype(
                        np.int32).astype(np.str)
                    yield a, b

    def test_string_equal_python(self):
        # Python PyStringEqual must match numpy's broadcasting == for every
        # shape pair, in both argument orders.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_equal('Py')
        self.assertIn('op_type: "PyStringEqual"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)

        for left, right in self.enumerate_matrix_couples():
            for lhs, rhs in ((left, right), (right, left)):
                outputs = session.run(None, {'x': lhs, 'y': rhs})
                self.assertEqual(outputs[0].tolist(), (lhs == rhs).tolist())

    def test_string_equal_cc(self):
        # Library StringEqual must match numpy's broadcasting == for every
        # shape pair, in both argument orders.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_equal('')
        self.assertIn('op_type: "StringEqual"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)

        for left, right in self.enumerate_matrix_couples():
            for lhs, rhs in ((left, right), (right, left)):
                outputs = session.run(None, {'x': lhs, 'y': rhs})
                self.assertEqual(outputs[0].tolist(), (lhs == rhs).tolist())

    def test_string_split_python(self):
        # Python PyStringSplit on a 1-D batch, with and without skip_empty.
        # Row 1 is an empty string and contributes no pieces either way.
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_split('Py')
        self.assertIn('op_type: "PyStringSplit"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        batch = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        for skip in [True, False]:
            with self.subTest(skip=skip):
                feeds = {'input': batch, 'delimiter': delimiter,
                         'skip_empty': np.array([skip])}
                indices, values, dense_shape = session.run(None, feeds)

                if skip:
                    exp_indices = [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2],
                                   [3, 0]]
                    exp_text = ['a', 'b', 'aa', 'b', 'c', 'dddddd']
                else:
                    exp_indices = [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                                   [2, 2], [3, 0]]
                    exp_text = ['a', '', 'b', 'aa', 'b', 'c', 'dddddd']
                self.assertEqual(exp_indices, indices.tolist())
                self.assertEqual(exp_text, values.tolist())
                self.assertEqual([4, 3], dense_shape.tolist())

    def test_string_split_cc(self):
        # Library StringSplit on a 1-D batch, with and without skip_empty.
        # When TensorFlow is installed, tf.raw_ops.StringSplit fed the same
        # inputs is used as an additional reference.
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        onnx_model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(onnx_model))
        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
        # ``text_input`` rather than ``input`` to avoid shadowing the builtin.
        text_input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
        delimiter = np.array([","])

        # Resolve the optional TensorFlow reference once, not per sub-test.
        try:
            from tensorflow.raw_ops import StringSplit
        except ImportError:
            StringSplit = None

        for skip in [True, False]:
            with self.subTest(skip=skip):
                skip_empty = np.array([skip])

                txout = sess.run(
                    None, {'input': text_input, 'delimiter': delimiter,
                           'skip_empty': skip_empty})

                if StringSplit is not None:
                    # Feed TF the same delimiter as the ONNX graph.
                    tfres = StringSplit(
                        input=text_input, delimiter=str(delimiter[0]),
                        skip_empty=skip)
                    self.assertEqual(
                        [_.decode() for _ in tfres[1].numpy().tolist()],
                        txout[1].tolist())
                    self.assertEqual(
                        tfres[0].numpy().tolist(), txout[0].tolist())
                    self.assertEqual(
                        tfres[2].numpy().tolist(), txout[2].tolist())

                if skip:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
                    exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
                else:
                    exp_indices = np.array(
                        [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                         [2, 2], [3, 0]])
                    exp_text = np.array(
                        ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
                exp_shape = np.array([4, 3])
                self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                self.assertEqual(exp_text.tolist(), txout[1].tolist())
                self.assertEqual(exp_shape.tolist(), txout[2].tolist())

    def test_string_split_cc_sep2(self):
        """Check the native StringSplit kernel with the two-character delimiter ",*".

        The batch contains delimiter runs ("aa,b,,c"), a trailing delimiter
        ("dddddd,") and a string made only of delimiters ("**"). Outputs are
        checked for both ``skip_empty`` settings. When TensorFlow is
        importable, they are compared against ``tf.raw_ops.StringSplit``;
        they are always compared against a table of expected sparse outputs.
        """
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        texts = np.array(["a*b", "a,*b", "aa,b,,c", 'z', "dddddd,", "**"])
        delims = np.array([",*"])

        # Optional reference implementation, looked up once.
        try:
            from tensorflow.raw_ops import StringSplit as tf_split
        except ImportError:
            tf_split = None

        # skip_empty flag -> (expected indices, expected values, expected shape)
        expected = {
            True: (
                [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1],
                 [2, 2], [3, 0], [4, 0]],
                ['a', 'b', 'a', 'b', 'aa', 'b', 'c', 'z', 'dddddd'],
                [6, 3],
            ),
            False: (
                [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0],
                 [2, 1], [2, 2], [2, 3], [3, 0], [4, 0], [4, 1],
                 [5, 0], [5, 1], [5, 2]],
                ['a', 'b', 'a', '', 'b', 'aa', 'b', '', 'c',
                 'z', 'dddddd', '', '', '', ''],
                [6, 4],
            ),
        }

        for skip in (True, False):
            with self.subTest(skip=skip):
                feeds = {
                    'input': texts,
                    'delimiter': delims,
                    'skip_empty': np.array([skip]),
                }
                outputs = session.run(None, feeds)

                if tf_split is not None:
                    reference = tf_split(
                        input=texts, delimiter=",*", skip_empty=skip)
                    ref_values = [
                        item.decode()
                        for item in reference[1].numpy().tolist()]
                    self.assertEqual(ref_values, outputs[1].tolist())
                    self.assertEqual(
                        reference[0].numpy().tolist(), outputs[0].tolist())
                    self.assertEqual(
                        reference[2].numpy().tolist(), outputs[2].tolist())

                want_indices, want_values, want_shape = expected[skip]
                self.assertEqual(want_values, outputs[1].tolist())
                self.assertEqual(want_indices, outputs[0].tolist())
                self.assertEqual(want_shape, outputs[2].tolist())

    def test_string_split_cc_sep0(self):
        """Check the native StringSplit kernel with an empty delimiter.

        With an empty delimiter, each input string is expected to be split
        into its individual characters. This holds for both ``skip_empty``
        settings, so a single expected result is used. When TensorFlow is
        importable, the output is also compared against
        ``tf.raw_ops.StringSplit``.
        """
        options = _ort.SessionOptions()
        options.register_custom_ops_library(_get_library_path())
        model = _create_test_model_string_split('')
        self.assertIn('op_type: "StringSplit"', str(model))
        session = _ort.InferenceSession(model.SerializeToString(), options)
        texts = np.array(["a*b", "a,*b"])
        delims = np.array([""])

        # Optional reference implementation, looked up once.
        try:
            from tensorflow.raw_ops import StringSplit as tf_split
        except ImportError:
            tf_split = None

        want_indices = [[0, 0], [0, 1], [0, 2],
                        [1, 0], [1, 1], [1, 2], [1, 3]]
        want_values = ['a', '*', 'b', 'a', ',', '*', 'b']
        want_shape = [2, 4]

        for skip in (True, False):
            with self.subTest(skip=skip):
                feeds = {
                    'input': texts,
                    'delimiter': delims,
                    'skip_empty': np.array([skip]),
                }
                outputs = session.run(None, feeds)

                if tf_split is not None:
                    reference = tf_split(
                        input=texts, delimiter="", skip_empty=skip)
                    ref_values = [
                        item.decode()
                        for item in reference[1].numpy().tolist()]
                    self.assertEqual(ref_values, outputs[1].tolist())
                    self.assertEqual(
                        reference[0].numpy().tolist(), outputs[0].tolist())
                    self.assertEqual(
                        reference[2].numpy().tolist(), outputs[2].tolist())

                self.assertEqual(want_values, outputs[1].tolist())
                self.assertEqual(want_indices, outputs[0].tolist())
                self.assertEqual(want_shape, outputs[2].tolist())


if __name__ == "__main__":
    # Script entry point: run this module's tests through unittest's
    # command-line runner.
    unittest.main()