// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include "onnxruntime_c_api.h"
//
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.
// TODO: Remove this file once all custom OPs are migrated to the new API
//
namespace OrtW {
// Convenience facade over the ONNX Runtime C API function table (`OrtApi`).
// Member functions forward to the matching `OrtApi` entry; any `OrtStatus*` that entry
// returns is handed to ThrowOnError() instead of being returned to the caller.
struct CustomOpApi {
  // Keeps a reference to `api` (no copy is made), so `api` has to outlive this object.
  CustomOpApi(const OrtApi& api) : api_(api) {}

  // ---- Kernel attributes ------------------------------------------------------------
  // Fetches the attribute called `name` from `info`.
  // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
  template <typename T>
  T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;

  // ---- Tensor type / shape ----------------------------------------------------------
  // Returns the type-and-shape info object for `value`. Code in this file hands the
  // result back through ReleaseTensorTypeAndShapeInfo() once it is done with it.
  OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
  // Element count as reported by OrtApi::GetTensorShapeElementCount.
  size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
  // Element data type as reported by OrtApi::GetTensorElementType.
  ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
  // Number of dimensions as reported by OrtApi::GetDimensionsCount.
  size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
  // Writes the dimensions of `info` into the caller-provided buffer `dim_values`,
  // whose capacity is `dim_values_length`.
  void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
                     size_t dim_values_length) const;
  // Forwards `dim_values` / `dim_count` to OrtApi::SetDimensions for `info`.
  void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
  // Convenience: GetDimensionsCount() + GetDimensions() packaged into a vector.
  std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
  // Forwards `input` to OrtApi::ReleaseTensorTypeAndShapeInfo.
  void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;

  // ---- Tensor data ------------------------------------------------------------------
  // Pointer to the tensor buffer of `value`, typed as T*.
  template <typename T>
  T* GetTensorMutableData(_Inout_ OrtValue* value) const;
  // Read-only counterpart of GetTensorMutableData().
  template <typename T>
  const T* GetTensorData(_Inout_ const OrtValue* value) const;

  // ---- Kernel context ---------------------------------------------------------------
  size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
  const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
  size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
  // Obtains output `index` of `context`; `dim_values` / `dim_count` are forwarded unchanged.
  OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
                                    size_t dim_count) const;

  // ---- Helpers ----------------------------------------------------------------------
  // Delegates to the free function OrtW::ThrowOnError with the wrapped API table.
  void ThrowOnError(OrtStatus* status) const {
    OrtW::ThrowOnError(api_, status);
  }

  // Access to the underlying C API table.
  const OrtApi& GetOrtApi() const { return api_; }

 private:
  const OrtApi& api_;  // not owned
};
//
// Custom OP API Inlines
//
// float specialization: reads attribute `name` of `info` through
// OrtApi::KernelInfoGetAttribute_float and passes the returned status to ThrowOnError().
template <>
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
  float value = 0.0f;
  OrtStatus* status = api_.KernelInfoGetAttribute_float(info, name, &value);
  ThrowOnError(status);
  return value;
}
// int64_t specialization: reads attribute `name` of `info` through
// OrtApi::KernelInfoGetAttribute_int64 and passes the returned status to ThrowOnError().
template <>
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
  int64_t value = 0;
  OrtStatus* status = api_.KernelInfoGetAttribute_int64(info, name, &value);
  ThrowOnError(status);
  return value;
}
// std::string specialization: reads the string attribute `name` of `info`.
//
// Works in two passes over OrtApi::KernelInfoGetAttribute_string:
//   1. with a null buffer, to learn the required size (terminating '\0' included);
//   2. with a buffer of that size, to copy the characters.
// Every status is routed through ThrowOnError(). The returned string does not contain
// the terminating '\0'. If the size query fails (and ThrowOnError() returns), or the
// reported size is 0, an empty string is returned.
template <>
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
  std::string out;
  size_t size = 0;

  // Feed nullptr for the data buffer to query the true size of the string attribute
  OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
  if (status != nullptr) {
    ThrowOnError(status);
    return out;
  }

  // A reported size of 0 leaves nothing to copy; returning here also keeps the
  // `size - 1` below from wrapping around to SIZE_MAX.
  if (size == 0) {
    return out;
  }

  out.resize(size);
  ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));

  // remove the terminating character '\0' (guarded: the second call may rewrite `size`)
  out.resize(size > 0 ? size - 1 : 0);
  return out;
}
// std::vector<float> specialization: reads the float-array attribute `name` of `info`.
// The first OrtApi::KernelInfoGetAttributeArray_float call (null buffer) reports the
// element count; the second one fills a vector of exactly that size.
template <>
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
  std::vector<float> values;
  size_t count = 0;

  // Feed nullptr for the data buffer to query the true size of the attribute
  OrtStatus* probe = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &count);
  if (probe != nullptr) {
    ThrowOnError(probe);
    return values;
  }

  values.resize(count);
  ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, values.data(), &count));
  return values;
}
// std::vector<int64_t> specialization: reads the int64-array attribute `name` of `info`.
// The first OrtApi::KernelInfoGetAttributeArray_int64 call (null buffer) reports the
// element count; the second one fills a vector of exactly that size.
template <>
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
  std::vector<int64_t> values;
  size_t count = 0;

  // Feed nullptr for the data buffer to query the true size of the attribute
  OrtStatus* probe = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &count);
  if (probe != nullptr) {
    ThrowOnError(probe);
    return values;
  }

  values.resize(count);
  ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, values.data(), &count));
  return values;
}
// Returns the type-and-shape info object that OrtApi::GetTensorTypeAndShape produces
// for `value`; the call's status goes through ThrowOnError().
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
  OrtTensorTypeAndShapeInfo* shape_info = nullptr;
  OrtStatus* status = api_.GetTensorTypeAndShape(value, &shape_info);
  ThrowOnError(status);
  return shape_info;
}
// Returns the element count OrtApi::GetTensorShapeElementCount reports for `info`;
// the call's status goes through ThrowOnError().
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
  size_t element_count = 0;
  OrtStatus* status = api_.GetTensorShapeElementCount(info, &element_count);
  ThrowOnError(status);
  return element_count;
}
// Returns the element data type OrtApi::GetTensorElementType reports for `info`;
// the call's status goes through ThrowOnError().
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
  ONNXTensorElementDataType element_type{};
  OrtStatus* status = api_.GetTensorElementType(info, &element_type);
  ThrowOnError(status);
  return element_type;
}
// Returns the number of dimensions OrtApi::GetDimensionsCount reports for `info`;
// the call's status goes through ThrowOnError().
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
  size_t dim_count = 0;
  OrtStatus* status = api_.GetDimensionsCount(info, &dim_count);
  ThrowOnError(status);
  return dim_count;
}
// Copies the dimensions of `info` into the caller-provided buffer `dim_values`
// (capacity `dim_values_length`) via OrtApi::GetDimensions; the call's status goes
// through ThrowOnError().
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info,
                                       _Out_ int64_t* dim_values,
                                       size_t dim_values_length) const {
  OrtStatus* status = api_.GetDimensions(info, dim_values, dim_values_length);
  ThrowOnError(status);
}
// Forwards `dim_values` / `dim_count` to OrtApi::SetDimensions for `info`;
// the call's status goes through ThrowOnError().
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info,
                                       _In_ const int64_t* dim_values,
                                       size_t dim_count) const {
  OrtStatus* status = api_.SetDimensions(info, dim_values, dim_count);
  ThrowOnError(status);
}
// Returns the data pointer of tensor `value`, typed as T*.
//
// The pointer is obtained from OrtApi::GetTensorMutableData, whose status goes through
// ThrowOnError(). No check is made here that T matches the tensor's element type.
//
// The C API writes its result through a `void**`. It is received into a real `void*`
// and converted with static_cast, instead of reinterpret_cast-ing the address of a
// `T*` to `void**` (which would store to a `T*` object through a `void*` lvalue).
template <typename T>
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
  void* raw = nullptr;
  ThrowOnError(api_.GetTensorMutableData(value, &raw));
  return static_cast<T*>(raw);
}
// Read-only access to the data of tensor `value`.
// Implemented on top of GetTensorMutableData(): constness of `value` is cast away for
// the call, and the resulting pointer is returned as pointer-to-const.
template <typename T>
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
  OrtValue* writable = const_cast<OrtValue*>(value);
  const T* data = GetTensorMutableData<T>(writable);
  return data;
}
// Returns the dimensions described by `info` as a vector: the vector is sized with
// GetDimensionsCount() and then filled by GetDimensions().
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
  const size_t dim_count = GetDimensionsCount(info);
  std::vector<int64_t> dims(dim_count);
  GetDimensions(info, dims.data(), dims.size());
  return dims;
}
// Hands `input` to OrtApi::ReleaseTensorTypeAndShapeInfo. That entry returns no
// status, so there is nothing to pass to ThrowOnError() here.
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
  const OrtApi& api = api_;
  api.ReleaseTensorTypeAndShapeInfo(input);
}
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
size_t out;
ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
return out;
}
// Returns input number `index` of `context` as obtained from
// OrtApi::KernelContext_GetInput; the call's status goes through ThrowOnError().
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
  const OrtValue* input = nullptr;
  OrtStatus* status = api_.KernelContext_GetInput(context, index, &input);
  ThrowOnError(status);
  return input;
}
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
size_t out;
ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
return out;
}
// Returns output number `index` of `context` as obtained from
// OrtApi::KernelContext_GetOutput. `dim_values` / `dim_count` are forwarded unchanged;
// the call's status goes through ThrowOnError().
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
                                                      _In_ size_t index,
                                                      _In_ const int64_t* dim_values,
                                                      size_t dim_count) const {
  OrtValue* output = nullptr;
  OrtStatus* status = api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &output);
  ThrowOnError(status);
  return output;
}
} // namespace of OrtW
// Common base for legacy custom-op kernels.
//
// Holds references to the ORT API table and to the kernel's OrtKernelInfo, plus an
// OrtW::CustomOpApi wrapper built on the same API table. Only references are stored,
// so both `api` and `info` must stay alive for as long as this object is used.
struct BaseKernel {
  // Mem-initializers are listed in member declaration order (api_, ort_, info_), which
  // is the order in which members are actually initialized: `ort_` is built from the
  // already-initialized `api_`.
  BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
      : api_(api), ort_(api_), info_(info) {
  }

  // Tries to read attribute `name` into `value` and reports success through the return
  // value. Only declared here; the per-type definitions live outside this block.
  template <class T>
  bool TryToGetAttribute(const char* name, T& value) const noexcept;

  // Returns attribute `name`, or `default_value` when TryToGetAttribute() leaves the
  // pre-seeded result untouched. The bool result of TryToGetAttribute() is deliberately
  // not inspected: `result` already holds the default.
  template <class T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
    T result = default_value;
    TryToGetAttribute(name, result);
    return result;
  }

  // Only declared here (defined outside this block). NOTE(review): presumably writes
  // `data` to output `output_idx` of `ctx` with shape `dim` - confirm against the definition.
  void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
                 const std::vector<int64_t>& data);

 protected:
  // Only declared here (defined outside this block). NOTE(review): by its name, extracts
  // the error code from `status` and releases the status - confirm against the definition.
  OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;

  const OrtApi& api_;          // not owned
  OrtW::CustomOpApi ort_;      // wrapper over api_
  const OrtKernelInfo& info_;  // not owned
};
// A tensor shape stored as a std::vector<int64_t> (one entry per dimension) with a few
// shape-related helpers.
struct OrtTensorDimensions : std::vector<int64_t> {
  // Creates an empty shape (which IsScalar() treats as a scalar).
  OrtTensorDimensions() = default;

  // Fills the shape from tensor `value`, using `ort` to query it.
  OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
    // Scope guard: hands the shape-info object back on every exit path, including
    // when GetTensorShape() leaves via an exception (e.g. a failed vector allocation).
    // If GetTensorTypeAndShape() itself fails by exception, the guard is never
    // constructed and nothing is released.
    struct ShapeInfoGuard {
      const OrtW::CustomOpApi& api;
      OrtTensorTypeAndShapeInfo* info;
      ~ShapeInfoGuard() { api.ReleaseTensorTypeAndShapeInfo(info); }
    };
    ShapeInfoGuard guard{ort, ort.GetTensorTypeAndShape(value)};
    std::vector<int64_t>::operator=(ort.GetTensorShape(guard.info));
  }

  // Product of all dimensions; 1 for an empty shape. No overflow checking is done.
  int64_t Size() const {
    int64_t s = 1;
    for (const int64_t dim : *this) {
      s *= dim;
    }
    return s;
  }

  // True when the shape has no dimensions.
  bool IsScalar() const {
    return empty();
  }

  // True when the shape has exactly one dimension.
  bool IsVector() const {
    return size() == 1;
  }
};
// Public; mirrored from https://github.com/microsoft/onnxruntime-extensions
// includes/onnxruntime_cpp_api_legacy.hpp
// 251 lines