microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
operators/math/segment_sum.cc
91 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #include "segment_sum.hpp" |
| 5 | |
| 6 | template <typename T> |
| 7 | void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) { |
| 8 | // Setup inputs |
| 9 | const OrtValue* data = ort_.KernelContext_GetInput(context, 0); |
| 10 | const T* p_data = ort_.GetTensorData<T>(data); |
| 11 | const OrtValue* segment_ids = ort_.KernelContext_GetInput(context, 1); |
| 12 | const int64_t* p_segment_ids = ort_.GetTensorData<int64_t>(segment_ids); |
| 13 | |
| 14 | // Setup output |
| 15 | OrtTensorDimensions dim_data(ort_, data); |
| 16 | OrtTensorDimensions dim_seg(ort_, segment_ids); |
| 17 | if (dim_data.size() == 0 || dim_seg.size() == 0) |
| 18 | ORT_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH); |
| 19 | if (dim_seg.size() != 1) |
| 20 | ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH); |
| 21 | if (dim_data[0] != dim_seg[0]) |
| 22 | ORT_CXX_API_THROW(MakeString( |
| 23 | "First dimensions of data and segment_ids should be the same, data shape: ", dim_data.GetDims(), |
| 24 | " segment_ids shape: ", dim_seg.GetDims()), ORT_INVALID_GRAPH); |
| 25 | |
| 26 | int64_t last_seg = p_segment_ids[dim_seg[0] - 1]; |
| 27 | OrtTensorDimensions dim_out = dim_data; |
| 28 | dim_out[0] = last_seg + 1; |
| 29 | |
| 30 | OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size()); |
| 31 | T* p_output = ort_.GetTensorMutableData<T>(v); |
| 32 | int64_t out_size = dim_out.Size(); |
| 33 | memset(p_output, 0, out_size * sizeof(T)); |
| 34 | |
| 35 | // The implementation is naive. It could be parallelized and |
| 36 | // use SIMD instructions to be faster. |
| 37 | int64_t in_stride = dim_data.Size(); |
| 38 | const T* begin = p_data; |
| 39 | const T* end = p_data + in_stride; |
| 40 | in_stride /= dim_data[0]; |
| 41 | T *p_out, *p_out_end; |
| 42 | const int64_t* p_seg = p_segment_ids; |
| 43 | for (; begin != end; ++p_seg) { |
| 44 | if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1)) |
| 45 | ORT_CXX_API_THROW(MakeString("segment_ids must be increasing but found ", |
| 46 | *(p_seg - 1), " and ", *p_seg, " at position ", |
| 47 | std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION); |
| 48 | p_out = p_output + *p_seg * in_stride; |
| 49 | p_out_end = p_out + in_stride; |
| 50 | for (; p_out != p_out_end; ++p_out, ++begin) |
| 51 | *p_out += *begin; |
| 52 | } |
| 53 | } |
| 54 | |
| 55 | KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) { |
| 56 | } |
| 57 | |
| 58 | void KernelSegmentSum::Compute(OrtKernelContext* context) { |
| 59 | KernelSegmentSum_Compute<float>(ort_, context); |
| 60 | } |
| 61 | |
| 62 | size_t CustomOpSegmentSum::GetInputTypeCount() const { |
| 63 | return 2; |
| 64 | }; |
| 65 | |
| 66 | size_t CustomOpSegmentSum::GetOutputTypeCount() const { |
| 67 | return 1; |
| 68 | }; |
| 69 | |
| 70 | ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) const { |
| 71 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; |
| 72 | }; |
| 73 | |
| 74 | void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const { |
| 75 | return new KernelSegmentSum(api); |
| 76 | }; |
| 77 | |
| 78 | const char* CustomOpSegmentSum::GetName() const { |
| 79 | return "SegmentSum"; |
| 80 | }; |
| 81 | |
| 82 | ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const { |
| 83 | switch (index) { |
| 84 | case 0: |
| 85 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; |
| 86 | case 1: |
| 87 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; |
| 88 | default: |
| 89 | ORT_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT); |
| 90 | } |
| 91 | }; |
| 92 | |