microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
edgchen1/fix_ci

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/segement_extraction.cc

60 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#include "segment_extraction.hpp"
5
6KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
7}
8
9void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
10 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
11 const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
12 OrtTensorDimensions input_dim(ort_, input);
13 if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
14 ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
15 }
16
17 std::vector<std::int64_t> segment_value;
18 std::vector<std::int64_t> segment_position;
19 for (std::int64_t i = 0; i < input_dim.Size(); i++) {
20 if (!p_data[i]) {
21 continue;
22 }
23
24 // push start position and value
25 if (i == 0 || p_data[i - 1] != p_data[i]) {
26 segment_value.push_back(p_data[i]);
27 segment_position.push_back(i);
28 }
29
30 // push end position
31 if (i == (input_dim.Size() - 1) || p_data[i + 1] != p_data[i]) {
32 segment_position.push_back(i + 1);
33 }
34 }
35
36 std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
37 std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
38 SetOutput(context, 0, segment_position_dim, segment_position);
39 SetOutput(context, 1, segment_value_dim, segment_value);
40}
41
42size_t CustomOpSegmentExtraction::GetInputTypeCount() const {
43 return 1;
44};
45
46size_t CustomOpSegmentExtraction::GetOutputTypeCount() const {
47 return 2;
48};
49
50ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*index*/) const {
51 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
52};
53
54const char* CustomOpSegmentExtraction::GetName() const {
55 return "SegmentExtraction";
56};
57
58ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
59 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
60};
61