# coding: utf-8
import unittest
import numpy as np
from onnx import helper, onnx_pb as onnx_proto
import onnxruntime as _ort
from onnxruntime_extensions import (
    OrtPyFunction,
    PyCustomOpDef,
    onnx_op,
    make_onnx_model,
    get_library_path as _get_library_path,
)


def _create_test_model_segment_sum(prefix, domain="ai.onnx.contrib"):
    """Build a minimal ONNX model around a ``<prefix>SegmentSum`` node.

    Graph: ``data`` -> Identity -> id1 and ``segment_ids`` -> Identity -> id2,
    then ``<prefix>SegmentSum(id1, id2) -> z`` in operator domain *domain*.

    :param prefix: prepended to the op type; ``""`` yields ``SegmentSum`` and
        ``"Py"`` yields ``PySegmentSum``.
    :param domain: operator domain of the SegmentSum node.
    :return: the model produced by ``make_onnx_model``.
    """
    tensor_proto = onnx_proto.TensorProto
    # NOTE(review): every value info is declared with an empty shape list while
    # the tests feed 2-D data; presumably the runtime does not enforce rank here.
    data_info = helper.make_tensor_value_info("data", tensor_proto.FLOAT, [])
    ids_info = helper.make_tensor_value_info("segment_ids", tensor_proto.INT64, [])
    out_info = helper.make_tensor_value_info("z", tensor_proto.FLOAT, [])

    op_type = f"{prefix}SegmentSum"
    graph_nodes = [
        helper.make_node("Identity", ["data"], ["id1"]),
        helper.make_node("Identity", ["segment_ids"], ["id2"]),
        helper.make_node(op_type, ["id1", "id2"], ["z"], domain=domain),
    ]
    return make_onnx_model(
        helper.make_graph(graph_nodes, "test0", [data_info, ids_info], [out_info])
    )


def _segment_sum_reference(data, segment_ids):
    """Pure-NumPy SegmentSum used as the body of the ``PySegmentSum`` custom op.

    Row ``i`` of *data* is added into output row ``segment_ids[i]``. Output rows
    whose id never appears stay zero.

    :param data: array whose leading axis is indexed by *segment_ids*.
    :param segment_ids: 1-D integer array, one id per row of *data*; need not be
        sorted.
    :return: array of shape ``(max(segment_ids) + 1,) + data.shape[1:]`` with
        ``data.dtype``. The leading dimension is 0 when *segment_ids* is empty.
    :raises ValueError: if any segment id is negative.
    """
    segment_ids = np.asarray(segment_ids)
    if segment_ids.size == 0:
        nb_seg = 0
    else:
        if segment_ids.min() < 0:
            # Without this check numpy negative indexing would silently add the
            # row into the last segment; the native kernel rejects negative ids
            # (see test_segment_sum_cc_negative_segment_ids).
            raise ValueError("segment_ids must not be negative")
        # max() rather than the last element so unsorted ids are sized correctly.
        nb_seg = int(segment_ids.max()) + 1
    res = np.zeros((nb_seg,) + data.shape[1:], dtype=data.dtype)
    # zip() stops at the shorter of the two inputs.
    for seg, row in zip(segment_ids, data):
        res[seg] += row
    return res


class TestMathOpString(unittest.TestCase):
    """Tests for the SegmentSum custom op (native and Python) and the Inverse op."""

    @classmethod
    def setUpClass(cls):
        # Register the Python implementation under "PySegmentSum"; models built
        # with prefix "Py" (test_segment_sum_python) use it.
        @onnx_op(
            op_type="PySegmentSum",
            inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
            outputs=[PyCustomOpDef.dt_float],
        )
        def segment_sum(data, segment_ids):
            return _segment_sum_reference(data, segment_ids)

    @staticmethod
    def _make_session(onnx_model):
        """Return a CPU InferenceSession for *onnx_model* with the extensions library registered."""
        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        return _ort.InferenceSession(
            onnx_model.SerializeToString(), so, providers=["CPUExecutionProvider"]
        )

    def _check_basic_segment_sum(self, sess):
        """Run the shared 3-row example through *sess* and assert the summed output.

        :return: ``(data, segment_ids, output)`` so callers can cross-check further.
        """
        data = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32)
        segment_ids = np.array([0, 0, 1], dtype=np.int64)
        # Rows 0 and 1 go to segment 0; row 2 goes to segment 1.
        exp = np.array([[5, 5, 5, 5], [5, 6, 7, 8]], dtype=np.float32)
        txout = sess.run(None, {"data": data, "segment_ids": segment_ids})
        self.assertEqual(exp.shape, txout[0].shape)
        self.assertEqual(exp.tolist(), txout[0].tolist())
        return data, segment_ids, txout[0]

    def test_segment_sum_cc(self):
        onnx_model = _create_test_model_segment_sum("")
        self.assertIn('op_type: "SegmentSum"', str(onnx_model))
        self._check_basic_segment_sum(self._make_session(onnx_model))

    def test_segment_sum_cc_negative_segment_ids(self):
        sess = self._make_session(_create_test_model_segment_sum(""))
        data = np.array([[1.0], [2.0]], dtype=np.float32)
        segment_ids = np.array([-1, 0], dtype=np.int64)
        with self.assertRaises(Exception) as ctx:
            sess.run(None, {"data": data, "segment_ids": segment_ids})
        self.assertIn("negative", str(ctx.exception).lower())

    def test_segment_sum_python(self):
        onnx_model = _create_test_model_segment_sum("Py")
        self.assertIn('op_type: "PySegmentSum"', str(onnx_model))
        data, segment_ids, out = self._check_basic_segment_sum(
            self._make_session(onnx_model)
        )

        # Optional cross-check against TensorFlow; skipped when it is not installed.
        try:
            from tensorflow.raw_ops import SegmentSum
        except ImportError:
            return
        tfres = SegmentSum(data=data, segment_ids=segment_ids)
        self.assertEqual(tfres.shape, out.shape)
        self.assertEqual(tfres.numpy().tolist(), out.tolist())

    def test_inverse(self):
        # Seeded, and shifted by 5*I so every row is strictly diagonally dominant
        # (off-diagonal row sum < 4 < diagonal). An unseeded rand(5, 5) can be
        # near-singular and make the float32 tolerance check below flaky.
        rng = np.random.RandomState(0)
        mat = (rng.rand(5, 5) + 5 * np.eye(5)).astype(np.float32)
        inv_mat = np.linalg.inv(mat)
        ort_inv = OrtPyFunction.from_customop('Inverse')
        act_mat = ort_inv(mat)
        np.testing.assert_allclose(inv_mat, act_mat, rtol=1.e-2, atol=1.e-3)


class TestRaggedTensorToDenseValidation(unittest.TestCase):
    """Input-validation tests for the RaggedTensorToDense custom op (offsets checks)."""

    @classmethod
    def setUpClass(cls):
        # One-node graph: RaggedTensorToDense(unused, values, default_value, offsets) -> out.
        # _run always feeds an empty tensor to "unused", so only values/offsets vary.
        nodes = [
            helper.make_node(
                "RaggedTensorToDense",
                inputs=["unused", "values", "default_value", "offsets"],
                outputs=["out"],
                domain="ai.onnx.contrib",
            ),
        ]
        inputs = [
            helper.make_tensor_value_info("unused", onnx_proto.TensorProto.INT64, [0]),
            helper.make_tensor_value_info("values", onnx_proto.TensorProto.INT64, [None]),
            helper.make_tensor_value_info("default_value", onnx_proto.TensorProto.INT64, [1]),
            helper.make_tensor_value_info("offsets", onnx_proto.TensorProto.INT64, [None]),
        ]
        output = helper.make_tensor_value_info("out", onnx_proto.TensorProto.INT64, [None])
        graph = helper.make_graph(nodes, "test_ragged", inputs, [output])
        cls.model = make_onnx_model(graph)

        so = _ort.SessionOptions()
        so.register_custom_ops_library(_get_library_path())
        cls.so = so
        # Every test uses the same model and only changes the feeds. Build the
        # session once instead of re-serializing and re-loading it on each _run.
        cls.sess = _ort.InferenceSession(
            cls.model.SerializeToString(), so, providers=["CPUExecutionProvider"]
        )

    def _run(self, values, offsets):
        """Run the op with *values* and *offsets* (as int64) and -1 as "default_value".

        :return: the list of session outputs.
        """
        return self.sess.run(None, {
            "unused": np.array([], dtype=np.int64),
            "values": np.array(values, dtype=np.int64),
            "default_value": np.array([-1], dtype=np.int64),
            "offsets": np.array(offsets, dtype=np.int64),
        })

    def test_empty_offsets(self):
        with self.assertRaises(Exception) as ctx:
            self._run([1, 2, 3], [])
        self.assertIn("must not be empty", str(ctx.exception))

    def test_out_of_range_offset(self):
        # 10 exceeds len(values) == 3.
        with self.assertRaises(Exception) as ctx:
            self._run([1, 2, 3], [0, 10])
        self.assertIn("out of valid range", str(ctx.exception))

    def test_negative_offset(self):
        with self.assertRaises(Exception) as ctx:
            self._run([1, 2, 3], [0, -1, 3])
        self.assertIn("out of valid range", str(ctx.exception))

    def test_non_monotonic_offsets(self):
        with self.assertRaises(Exception) as ctx:
            self._run([1, 2, 3], [0, 3, 1])
        self.assertIn("monotonic non-decreasing", str(ctx.exception))

    def test_valid_offsets(self):
        # Three offsets -> two output rows.
        result = self._run([10, 20, 30], [0, 2, 3])
        self.assertEqual(result[0].shape[0], 2)


if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed as a script.
    unittest.main(verbosity=1)
