microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.7.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/segment_sum.cc

89 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#include "segment_sum.hpp"
5
6template <typename T>
7void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
8 // Setup inputs
9 const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
10 const T* p_data = ort_.GetTensorData<T>(data);
11 const OrtValue* segment_ids = ort_.KernelContext_GetInput(context, 1);
12 const int64_t* p_segment_ids = ort_.GetTensorData<int64_t>(segment_ids);
13
14 // Setup output
15 OrtTensorDimensions dim_data(ort_, data);
16 OrtTensorDimensions dim_seg(ort_, segment_ids);
17 if (dim_data.size() == 0 || dim_seg.size() == 0)
18 ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
19 if (dim_seg.size() != 1)
20 ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
21 if (dim_data[0] != dim_seg[0])
22 ORTX_CXX_API_THROW(MakeString(
23 "First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
24 " segment_ids shape: ", dim_seg),
25 ORT_INVALID_GRAPH);
26
27 int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
28 OrtTensorDimensions dim_out = dim_data;
29 dim_out[0] = last_seg + 1;
30
31 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
32 T* p_output = ort_.GetTensorMutableData<T>(v);
33 int64_t out_size = dim_out.Size();
34 memset(p_output, 0, static_cast<size_t>(out_size * sizeof(T)));
35
36 // The implementation is naive. It could be parallelized and
37 // use SIMD instructions to be faster.
38 int64_t in_stride = dim_data.Size();
39 const T* begin = p_data;
40 const T* end = p_data + in_stride;
41 in_stride /= dim_data[0];
42 T *p_out, *p_out_end;
43 const int64_t* p_seg = p_segment_ids;
44 for (; begin != end; ++p_seg) {
45 if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
46 ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
47 *(p_seg - 1), " and ", *p_seg, " at position ",
48 std::distance(p_segment_ids, p_seg), "."),
49 ORT_RUNTIME_EXCEPTION);
50 p_out = p_output + *p_seg * in_stride;
51 p_out_end = p_out + in_stride;
52 for (; p_out != p_out_end; ++p_out, ++begin)
53 *p_out += *begin;
54 }
55}
56
57KernelSegmentSum::KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
58}
59
60void KernelSegmentSum::Compute(OrtKernelContext* context) {
61 KernelSegmentSum_Compute<float>(ort_, context);
62}
63
64size_t CustomOpSegmentSum::GetInputTypeCount() const {
65 return 2;
66};
67
68size_t CustomOpSegmentSum::GetOutputTypeCount() const {
69 return 1;
70};
71
72ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) const {
73 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
74};
75
76const char* CustomOpSegmentSum::GetName() const {
77 return "SegmentSum";
78};
79
80ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
81 switch (index) {
82 case 0:
83 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
84 case 1:
85 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
86 default:
87 ORTX_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
88 }
89};
90