microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/segment_sum.cc

91lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "segment_sum.hpp"
5
6template <typename T>
7void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
8 // Setup inputs
9 const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
10 const T* p_data = ort_.GetTensorData<T>(data);
11 const OrtValue* segment_ids = ort_.KernelContext_GetInput(context, 1);
12 const int64_t* p_segment_ids = ort_.GetTensorData<int64_t>(segment_ids);
13
14 // Setup output
15 OrtTensorDimensions dim_data(ort_, data);
16 OrtTensorDimensions dim_seg(ort_, segment_ids);
17 if (dim_data.size() == 0 || dim_seg.size() == 0)
18 ORT_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
19 if (dim_seg.size() != 1)
20 ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
21 if (dim_data[0] != dim_seg[0])
22 ORT_CXX_API_THROW(MakeString(
23 "First dimensions of data and segment_ids should be the same, data shape: ", dim_data.GetDims(),
24 " segment_ids shape: ", dim_seg.GetDims()), ORT_INVALID_GRAPH);
25
26 int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
27 OrtTensorDimensions dim_out = dim_data;
28 dim_out[0] = last_seg + 1;
29
30 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
31 T* p_output = ort_.GetTensorMutableData<T>(v);
32 int64_t out_size = dim_out.Size();
33 memset(p_output, 0, out_size * sizeof(T));
34
35 // The implementation is naive. It could be parallelized and
36 // use SIMD instructions to be faster.
37 int64_t in_stride = dim_data.Size();
38 const T* begin = p_data;
39 const T* end = p_data + in_stride;
40 in_stride /= dim_data[0];
41 T *p_out, *p_out_end;
42 const int64_t* p_seg = p_segment_ids;
43 for (; begin != end; ++p_seg) {
44 if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
45 ORT_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
46 *(p_seg - 1), " and ", *p_seg, " at position ",
47 std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
48 p_out = p_output + *p_seg * in_stride;
49 p_out_end = p_out + in_stride;
50 for (; p_out != p_out_end; ++p_out, ++begin)
51 *p_out += *begin;
52 }
53}
54
55KernelSegmentSum::KernelSegmentSum(OrtApi api) : BaseKernel(api) {
56}
57
58void KernelSegmentSum::Compute(OrtKernelContext* context) {
59 KernelSegmentSum_Compute<float>(ort_, context);
60}
61
62size_t CustomOpSegmentSum::GetInputTypeCount() const {
63 return 2;
64};
65
66size_t CustomOpSegmentSum::GetOutputTypeCount() const {
67 return 1;
68};
69
70ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) const {
71 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
72};
73
74void* CustomOpSegmentSum::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
75 return new KernelSegmentSum(api);
76};
77
78const char* CustomOpSegmentSum::GetName() const {
79 return "SegmentSum";
80};
81
82ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
83 switch (index) {
84 case 0:
85 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
86 case 1:
87 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
88 default:
89 ORT_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
90 }
91};
92