microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/ocos.cc

84 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
#include <algorithm>
#include <cstring>
#include <sstream>
#include "ocos.h"
5
6bool BaseKernel::HasAttribute(const char* name) const noexcept {
7 size_t size;
8 std::string out;
9 // Crashes here.
10 OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
11 auto r = api_.GetErrorCode(status);
12 bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
13 if (has) {
14 api_.ReleaseStatus(status);
15 return has;
16 }
17 const char* error = api_.GetErrorMessage(status);
18 if (strstr(error, "No attribute") == error) {
19 api_.ReleaseStatus(status);
20 return false;
21 }
22 api_.ReleaseStatus(status);
23 return true;
24}
25
26OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
27 if (status == nullptr) {
28 return ORT_OK;
29 }
30 auto error_code = api_.GetErrorCode(status);
31 api_.ReleaseStatus(status);
32 return error_code;
33}
34
35void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
36 const std::vector<int64_t>& data) {
37 OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
38 int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
39 for (size_t i = 0; i < data.size(); i++) {
40 data_ptr[i] = data[i];
41 }
42}
43
44template <>
45bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) const noexcept {
46 size_t size = 0;
47 OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
48
49 // The status should be a nullptr when querying for the size.
50 if (status != nullptr) {
51 api_.ReleaseStatus(status);
52 return false;
53 }
54
55 value.resize(size);
56 status = api_.KernelInfoGetAttribute_string(&info_, name, &value[0], &size);
57 if (GetErrorCodeAndRelease(status) != ORT_OK) {
58 return false;
59 }
60 value.resize(size - 1);
61
62 return true;
63}
64
65template <>
66bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) const noexcept {
67 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &value)) == ORT_OK;
68}
69
70template <>
71bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept {
72 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
73}
74
75template <>
76bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
77 int64_t origin_value = 0;
78 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
79 return false;
80 }
81
82 value = origin_value == 1;
83 return true;
84}
85