microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions. Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
9edf572de9b3e5eb261ca06060e6b2e4ab4012df

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/ocos.cc

85 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>
#include <limits>
#include <sstream>
#include "ocos.h"
#include "narrow.h"
6
7OrtxStatus::operator OrtStatus*() const noexcept {
8 if (IsOk()) {
9 return nullptr;
10 }
11
12 OrtStatus* status = OrtW::CreateStatus(Message(), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
13 return status;
14}
15
16OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
17 if (status == nullptr) {
18 return ORT_OK;
19 }
20 auto error_code = api_.GetErrorCode(status);
21 api_.ReleaseStatus(status);
22 return error_code;
23}
24
25void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
26 const std::vector<int64_t>& data) {
27 OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
28 int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
29 for (size_t i = 0; i < data.size(); i++) {
30 data_ptr[i] = data[i];
31 }
32}
33
34template <>
35bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) const noexcept {
36 size_t size = 0;
37 OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
38
39 // The status should be a nullptr when querying for the size.
40 if (status != nullptr) {
41 api_.ReleaseStatus(status);
42 return false;
43 }
44
45 value.resize(size);
46 status = api_.KernelInfoGetAttribute_string(&info_, name, &value[0], &size);
47 if (GetErrorCodeAndRelease(status) != ORT_OK) {
48 return false;
49 }
50 value.resize(size - 1);
51
52 return true;
53}
54
55template <>
56bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) const noexcept {
57 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &value)) == ORT_OK;
58}
59
60template <>
61bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept {
62 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
63}
64
65template <>
66bool BaseKernel::TryToGetAttribute(const char* name, int& value) const noexcept {
67 int64_t origin_value = 0;
68 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
69 return false;
70 }
71
72 value = ort_extensions::narrow<int>(origin_value);
73 return true;
74}
75
76template <>
77bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
78 int64_t origin_value = 0;
79 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
80 return false;
81 }
82
83 value = origin_value == 1;
84 return true;
85}
86