microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/ocos.cc

101 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
#include <algorithm>
#include <cstring>
#include <sstream>
#include "ocos.h"
5
6bool BaseKernel::HasAttribute(const char* name) const {
7 if (info_ == nullptr) {
8 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
9 }
10 size_t size;
11 std::string out;
12 // Crashes here.
13 OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
14 auto r = api_.GetErrorCode(status);
15 bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
16 if (has) {
17 api_.ReleaseStatus(status);
18 return has;
19 }
20 const char* error = api_.GetErrorMessage(status);
21 if (strstr(error, "No attribute") == error) {
22 api_.ReleaseStatus(status);
23 return false;
24 }
25 api_.ReleaseStatus(status);
26 return true;
27}
28
29OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
30 if (status == nullptr) {
31 return ORT_OK;
32 }
33 auto error_code = api_.GetErrorCode(status);
34 api_.ReleaseStatus(status);
35 return error_code;
36}
37
38void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
39 OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
40 int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
41 for (int i = 0; i < data.size(); i++) {
42 data_ptr[i] = data[i];
43 }
44}
45
46template <>
47bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
48 if (info_ == nullptr) {
49 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
50 }
51
52 size_t size = 0;
53 OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
54
55 // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
56 if (GetErrorCodeAndRelease(status) != ORT_INVALID_ARGUMENT) {
57 return false;
58 }
59
60 value.resize(size);
61 status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
62 if (GetErrorCodeAndRelease(status) != ORT_OK) {
63 return false;
64 }
65 value.resize(size - 1);
66
67 return true;
68}
69
70template <>
71bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
72 if (info_ == nullptr) {
73 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
74 }
75
76 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
77}
78
79template <>
80bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
81 if (info_ == nullptr) {
82 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
83 }
84
85 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
86}
87
88template <>
89bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
90 if (info_ == nullptr) {
91 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
92 }
93
94 int64_t origin_value = 0;
95 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
96 return false;
97 }
98
99 value = origin_value == 1;
100 return true;
101}
102