// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "segment_sum.hpp"

#include <algorithm>

// Computes a segment sum along the first axis of `data`:
//   output[k, ...] = sum of data[i, ...] over every row i with segment_ids[i] == k.
//
// Parameters:
//   data        - input tensor of rank >= 1; its first dimension must equal the
//                 length of `segment_ids`.
//   segment_ids - 1-D, non-empty, sorted (non-decreasing), non-negative ids.
//                 Ids may be repeated and may skip values; a segment that
//                 receives no rows is left as zeros.
//   output      - allocated by this function with shape
//                 [segment_ids[last] + 1, data.shape[1:]...].
//
// Returns nullptr on success, otherwise an error status:
//   ORT_INVALID_GRAPH      - scalar inputs, non-1-D segment_ids, or mismatched first dims.
//   ORT_INVALID_ARGUMENT   - empty/negative/out-of-range segment_ids.
//   ORT_RUNTIME_EXCEPTION  - segment_ids not sorted.
// All ids are validated before any output is written, so an error never leaves
// partially-accumulated sums behind.
OrtStatusPtr segment_sum(const ortc::Tensor<float>& data,
                         const ortc::Tensor<int64_t>& segment_ids,
                         ortc::Tensor<float>& output) {
  const auto& dim_data = data.Shape();
  const auto& dim_seg = segment_ids.Shape();
  if (dim_data.size() == 0 || dim_seg.size() == 0)
    return OrtW::CreateStatus("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
  if (dim_seg.size() != 1)
    return OrtW::CreateStatus("segment_ids must be a 1-D tensor", ORT_INVALID_GRAPH);
  if (dim_data[0] != dim_seg[0])
    return OrtW::CreateStatus(MakeString(
                                  "First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
                                  " segment_ids shape: ", dim_seg).c_str(),
                              ORT_INVALID_GRAPH);

  const int64_t* p_segment_ids = segment_ids.Data();
  const float* p_data = data.Data();

  const int64_t num_rows = dim_seg[0];
  if (num_rows == 0) {
    return OrtW::CreateStatus("segment_ids must not be empty.", ORT_INVALID_ARGUMENT);
  }

  const int64_t last_seg = p_segment_ids[num_rows - 1];
  if (p_segment_ids[0] < 0 || last_seg < 0)
    return OrtW::CreateStatus("segment_ids must not contain negative values.", ORT_INVALID_ARGUMENT);

  // Validate every id up front. This must not depend on the size of a data row:
  // when rows are empty (e.g. data shape [N, 0]) the ids still have to be sorted.
  for (int64_t i = 0; i < num_rows; ++i) {
    const int64_t seg = p_segment_ids[i];
    // Only a decrease is an error; equal ids extend the current segment and
    // larger jumps leave the skipped segments as zeros.
    if (i > 0 && seg < p_segment_ids[i - 1]) {
      return OrtW::CreateStatus(MakeString("segment_ids must be increasing but found ",
                                           p_segment_ids[i - 1], " and ", seg, " at position ",
                                           i, ".").c_str(),
                                ORT_RUNTIME_EXCEPTION);
    }
    // Catches e.g. a first id larger than the last one, which the sortedness
    // check above cannot see at position 0.
    if (seg < 0 || seg > last_seg) {
      return OrtW::CreateStatus(MakeString("segment_ids value ", seg, " at position ",
                                           i, " is out of range [0, ", last_seg, "].").c_str(),
                                ORT_INVALID_ARGUMENT);
    }
  }

  std::vector<int64_t> dim_out = dim_data;
  dim_out[0] = last_seg + 1;

  float* p_output = output.Allocate(dim_out);
  const int64_t out_size = output.NumberOfElement();
  // std::fill_n is well-defined for a zero count even if p_output is null,
  // unlike memset(nullptr, 0, 0).
  std::fill_n(p_output, out_size, 0.0f);

  // Number of elements in one row (everything after the first dimension).
  const int64_t row_size = data.NumberOfElement() / num_rows;

  // The implementation is naive. It could be parallelized and
  // use SIMD instructions to be faster.
  for (int64_t i = 0; i < num_rows; ++i) {
    const float* src = p_data + i * row_size;
    float* dst = p_output + p_segment_ids[i] * row_size;
    for (int64_t j = 0; j < row_size; ++j) {
      dst[j] += src[j];
    }
  }

  return nullptr;
}
