microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
78d8dd5705d87040aac4c4479982292862ffee05

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/ocos.cc

102lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3#include <sstream>
4#include "ocos.h"
5
6bool BaseKernel::HasAttribute(const char* name) const {
7 if (info_ == nullptr) {
8 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
9 }
10 size_t size;
11 std::string out;
12 // Crashes here.
13 OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
14 auto r = api_.GetErrorCode(status);
15 bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
16 if (has) {
17 api_.ReleaseStatus(status);
18 return has;
19 }
20 const char* error = api_.GetErrorMessage(status);
21 if (strstr(error, "No attribute") == error) {
22 api_.ReleaseStatus(status);
23 return false;
24 }
25 api_.ReleaseStatus(status);
26 return true;
27}
28
29OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
30 if (status == nullptr) {
31 return ORT_OK;
32 }
33 auto error_code = api_.GetErrorCode(status);
34 api_.ReleaseStatus(status);
35 return error_code;
36}
37
38void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
39 OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
40 int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
41 for (int i = 0; i < data.size(); i++) {
42 data_ptr[i] = data[i];
43 }
44}
45
46template <>
47bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
48 if (info_ == nullptr) {
49 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
50 }
51
52 size_t size = 0;
53 OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
54
55 // The status should be a nullptr when querying for the size.
56 if (status != nullptr) {
57 api_.ReleaseStatus(status);
58 return false;
59 }
60
61 value.resize(size);
62 status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
63 if (GetErrorCodeAndRelease(status) != ORT_OK) {
64 return false;
65 }
66 value.resize(size - 1);
67
68 return true;
69}
70
71template <>
72bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
73 if (info_ == nullptr) {
74 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
75 }
76
77 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
78}
79
80template <>
81bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
82 if (info_ == nullptr) {
83 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
84 }
85
86 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
87}
88
89template <>
90bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
91 if (info_ == nullptr) {
92 ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
93 }
94
95 int64_t origin_value = 0;
96 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
97 return false;
98 }
99
100 value = origin_value == 1;
101 return true;
102}
103