microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
27132ced71e5e35e3ee706398a316010a5ada1d9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/ocos.cc

76 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3#include <sstream>
4#include "ocos.h"
5#include "narrow.h"
6
7OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
8 if (status == nullptr) {
9 return ORT_OK;
10 }
11 auto error_code = api_.GetErrorCode(status);
12 api_.ReleaseStatus(status);
13 return error_code;
14}
15
16void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
17 const std::vector<int64_t>& data) {
18 OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
19 int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
20 for (size_t i = 0; i < data.size(); i++) {
21 data_ptr[i] = data[i];
22 }
23}
24
25template <>
26bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) const noexcept {
27 size_t size = 0;
28 OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
29
30 // The status should be a nullptr when querying for the size.
31 if (status != nullptr) {
32 api_.ReleaseStatus(status);
33 return false;
34 }
35
36 value.resize(size);
37 status = api_.KernelInfoGetAttribute_string(&info_, name, &value[0], &size);
38 if (GetErrorCodeAndRelease(status) != ORT_OK) {
39 return false;
40 }
41 value.resize(size - 1);
42
43 return true;
44}
45
46template <>
47bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) const noexcept {
48 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &value)) == ORT_OK;
49}
50
51template <>
52bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept {
53 return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
54}
55
56template <>
57bool BaseKernel::TryToGetAttribute(const char* name, int& value) const noexcept {
58 int64_t origin_value = 0;
59 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
60 return false;
61 }
62
63 value = ort_extensions::narrow<int>(origin_value);
64 return true;
65}
66
67template <>
68bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
69 int64_t origin_value = 0;
70 if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
71 return false;
72 }
73
74 value = origin_value == 1;
75 return true;
76}
77