microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/segement_extraction.cc

61 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "segment_extraction.hpp"

#include <cstdint>
#include <vector>

6KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info)
7 : BaseKernel(api, info) {
8}
9
10void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
11 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
12 const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
13 OrtTensorDimensions input_dim(ort_, input);
14 if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
15 ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
16 }
17
18 std::vector<std::int64_t> segment_value;
19 std::vector<std::int64_t> segment_position;
20 for (std::int64_t i = 0; i < input_dim.Size(); i++) {
21 if (!p_data[i]) {
22 continue;
23 }
24
25 // push start position and value
26 if (i == 0 || p_data[i - 1] != p_data[i]) {
27 segment_value.push_back(p_data[i]);
28 segment_position.push_back(i);
29 }
30
31 // push end position
32 if (i == (input_dim.Size() - 1) || p_data[i + 1] != p_data[i]) {
33 segment_position.push_back(i + 1);
34 }
35 }
36
37 std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
38 std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
39 SetOutput(context, 0, segment_position_dim, segment_position);
40 SetOutput(context, 1, segment_value_dim, segment_value);
41}
42
43size_t CustomOpSegmentExtraction::GetInputTypeCount() const {
44 return 1;
45};
46
47size_t CustomOpSegmentExtraction::GetOutputTypeCount() const {
48 return 2;
49};
50
51ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*index*/) const {
52 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
53};
54
55const char* CustomOpSegmentExtraction::GetName() const {
56 return "SegmentExtraction";
57};
58
59ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
60 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
61};
62