microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
243lines · modecode
| 1 | diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt |
| 2 | index ac9c62fb6..9893f703e 100644 |
| 3 | --- a/cmake/CMakeLists.txt |
| 4 | +++ b/cmake/CMakeLists.txt |
| 5 | @@ -966,6 +966,14 @@ if (onnxruntime_USE_TVM) |
| 6 | list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm nnvm_compiler) |
| 7 | endif() |
| 8 | |
| 9 | +# ONNXRuntime-CustomOps |
| 10 | +set(OCOS_ENABLE_CTEST OFF CACHE INTERNAL "") |
| 11 | +set(OCOS_ENABLE_STATIC_LIB ON CACHE INTERNAL "") |
| 12 | +set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "") |
| 13 | +add_subdirectory(external/onnxruntime-extensions EXCLUDE_FROM_ALL) |
| 14 | +target_include_directories(ortcustomops_static PRIVATE ${RE2_INCLUDE_DIR} external/json/include) |
| 15 | +target_include_directories(ortcustomops PUBLIC external/onnxruntime-extensions/shared) |
| 16 | + |
| 17 | if (APPLE OR CMAKE_SYSTEM_NAME STREQUAL "Android") |
| 18 | #onnx/onnx/proto_utils.h:34:16: error: 'SetTotalBytesLimit' is deprecated: Please use the single |
| 19 | #parameter version of SetTotalBytesLimit(). The second parameter is ignored. |
| 20 | diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake |
| 21 | index df7eebf5a..2c511005a 100644 |
| 22 | --- a/cmake/onnxruntime_session.cmake |
| 23 | +++ b/cmake/onnxruntime_session.cmake |
| 24 | @@ -16,7 +16,7 @@ if(onnxruntime_ENABLE_INSTRUMENT) |
| 25 | target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT) |
| 26 | endif() |
| 27 | target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) |
| 28 | -target_link_libraries(onnxruntime_session PRIVATE nlohmann_json::nlohmann_json) |
| 29 | +target_link_libraries(onnxruntime_session PRIVATE nlohmann_json::nlohmann_json ortcustomops) |
| 30 | add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES}) |
| 31 | set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") |
| 32 | if (onnxruntime_USE_CUDA) |
| 33 | diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h |
| 34 | index b28c44613..b932d55ce 100644 |
| 35 | --- a/include/onnxruntime/core/session/onnxruntime_c_api.h |
| 36 | +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h |
| 37 | @@ -1277,6 +1277,11 @@ struct OrtApi { |
| 38 | */ |
| 39 | ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, |
| 40 | _Out_ int64_t* out, _Inout_ size_t* size); |
| 41 | + |
| 42 | + /** |
| 43 | + * Enable custom operators in ORT CustomOps: https://github.com/microsoft/onnxruntime-extensions.git |
| 44 | + */ |
| 45 | + ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); |
| 46 | }; |
| 47 | |
| 48 | /* |
| 49 | diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h |
| 50 | index d85ecd776..d0e9fe6a3 100644 |
| 51 | --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h |
| 52 | +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h |
| 53 | @@ -308,6 +308,7 @@ struct SessionOptions : Base<OrtSessionOptions> { |
| 54 | |
| 55 | SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); |
| 56 | SessionOptions& DisableProfiling(); |
| 57 | + SessionOptions& EnableOrtCustomOps(); |
| 58 | |
| 59 | SessionOptions& EnableMemPattern(); |
| 60 | SessionOptions& DisableMemPattern(); |
| 61 | diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h |
| 62 | index 64199ac6c..8ee32e4a1 100644 |
| 63 | --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h |
| 64 | +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h |
| 65 | @@ -435,6 +435,11 @@ inline SessionOptions& SessionOptions::DisableProfiling() { |
| 66 | return *this; |
| 67 | } |
| 68 | |
| 69 | +inline SessionOptions& SessionOptions::EnableOrtCustomOps() { |
| 70 | + ThrowOnError(GetApi().EnableOrtCustomOps(p_)); |
| 71 | + return *this; |
| 72 | +} |
| 73 | + |
| 74 | inline SessionOptions& SessionOptions::EnableMemPattern() { |
| 75 | ThrowOnError(GetApi().EnableMemPattern(p_)); |
| 76 | return *this; |
| 77 | diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc |
| 78 | index 12174163e..1fb57d39c 100644 |
| 79 | --- a/onnxruntime/core/session/onnxruntime_c_api.cc |
| 80 | +++ b/onnxruntime/core/session/onnxruntime_c_api.cc |
| 81 | @@ -35,6 +35,9 @@ |
| 82 | #include "abi_session_options_impl.h" |
| 83 | #include "core/framework/TensorSeq.h" |
| 84 | #include "core/platform/ort_mutex.h" |
| 85 | + |
| 86 | +#include "ortcustomops.h" |
| 87 | + |
| 88 | #ifdef USE_CUDA |
| 89 | #include "core/providers/cuda/cuda_provider_factory.h" |
| 90 | #endif |
| 91 | @@ -403,6 +406,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions |
| 92 | API_IMPL_END |
| 93 | } |
| 94 | |
| 95 | +ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* options) { |
| 96 | + API_IMPL_BEGIN |
| 97 | + |
| 98 | + return RegisterCustomOps(options, OrtGetApiBase()); |
| 99 | + API_IMPL_END |
| 100 | +} |
| 101 | + |
| 102 | namespace { |
| 103 | // provider either model_path, or modal_data + model_data_length. |
| 104 | static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, |
| 105 | @@ -2123,6 +2133,7 @@ static constexpr OrtApi ort_api_1_to_8 = { |
| 106 | // Version 8 - In development, feel free to add/remove/rearrange here |
| 107 | &OrtApis::KernelInfoGetAttributeArray_float, |
| 108 | &OrtApis::KernelInfoGetAttributeArray_int64, |
| 109 | + &OrtApis::EnableOrtCustomOps, |
| 110 | }; |
| 111 | |
| 112 | // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) |
| 113 | diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h |
| 114 | index f19b1d729..0dccf457a 100644 |
| 115 | --- a/onnxruntime/core/session/ort_apis.h |
| 116 | +++ b/onnxruntime/core/session/ort_apis.h |
| 117 | @@ -263,4 +263,6 @@ ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id); |
| 118 | ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id); |
| 119 | ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size); |
| 120 | ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size); |
| 121 | + |
| 122 | +ORT_API_STATUS_IMPL(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); |
| 123 | } // namespace OrtApis |
| 124 | diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc |
| 125 | index 021636a16..5e34519a2 100644 |
| 126 | --- a/onnxruntime/test/shared_lib/test_inference.cc |
| 127 | +++ b/onnxruntime/test/shared_lib/test_inference.cc |
| 128 | @@ -172,6 +172,8 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f |
| 129 | static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx"); |
| 130 | static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx"); |
| 131 | static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx"); |
| 132 | +static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI = TSTR("testdata/custom_op_string_lower.onnx"); |
| 133 | +static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI_2 = TSTR("testdata/custom_op_negpos.onnx"); |
| 134 | |
| 135 | #ifdef ENABLE_LANGUAGE_INTEROP_OPS |
| 136 | static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx"); |
| 137 | @@ -265,6 +267,91 @@ TEST(CApiTest, custom_op_handler) { |
| 138 | #endif |
| 139 | } |
| 140 | |
| 141 | +TEST(CApiTest, test_enable_ort_customops_negpos) { |
| 142 | + |
| 143 | + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); |
| 144 | + auto allocator = onnxruntime::make_unique<MockedOrtAllocator>(); |
| 145 | + |
| 146 | + // Create Inputs |
| 147 | + std::vector<Ort::Value> ort_inputs; |
| 148 | + std::vector<float> input_data = {-1.1f, 2.2f, 4.4f, -5.5f}; |
| 149 | + std::vector<int64_t> input_dims = {2, 2}; |
| 150 | + ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()), input_data.size(), input_dims.data(), input_dims.size())); |
| 151 | + |
| 152 | + // Create Session with ORT CustomOps |
| 153 | + Ort::SessionOptions session_options; |
| 154 | + session_options.EnableOrtCustomOps(); |
| 155 | + Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI_2, session_options); |
| 156 | + |
| 157 | + // Create Input and Output Names |
| 158 | + std::vector<const char*> input_names = {"X"}; |
| 159 | + const char* output_names[] = {"out0", "out1"}; |
| 160 | + |
| 161 | + // Run Session |
| 162 | + std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); |
| 163 | + |
| 164 | + // Validate Results |
| 165 | + ASSERT_EQ(ort_outputs.size(), 2u); |
| 166 | + |
| 167 | + std::vector<int64_t> out_dims = {2, 2}; |
| 168 | + std::vector<float> values_out0 = {-1.1f, 0.0f, 0.0f, -5.5f}; |
| 169 | + auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo(); |
| 170 | + ASSERT_EQ(type_info.GetShape(), out_dims); |
| 171 | + size_t total_len = type_info.GetElementCount(); |
| 172 | + ASSERT_EQ(values_out0.size(), total_len); |
| 173 | + |
| 174 | + float* f = ort_outputs[0].GetTensorMutableData<float>(); |
| 175 | + for (size_t i = 0; i != total_len; ++i) { |
| 176 | + ASSERT_EQ(values_out0[i], f[i]); |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +TEST(CApiTest, test_enable_ort_customops_stringlower) { |
| 181 | + |
| 182 | + auto allocator = onnxruntime::make_unique<MockedOrtAllocator>(); |
| 183 | + |
| 184 | + // Create Inputs |
| 185 | + std::vector<Ort::Value> ort_inputs; |
| 186 | + std::string input_data{"HI, This is ENGINEER from Microsoft."}; |
| 187 | + const char* const input_strings[] = {input_data.c_str()}; |
| 188 | + std::vector<int64_t> input_dims = {1, 1}; |
| 189 | + |
| 190 | + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator.get(), input_dims.data(), input_dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); |
| 191 | + input_tensor.FillStringTensor(input_strings, 1U); |
| 192 | + ort_inputs.push_back(std::move(input_tensor)); |
| 193 | + |
| 194 | + // Create Session with ORT CustomOps |
| 195 | + Ort::SessionOptions session_options; |
| 196 | + session_options.EnableOrtCustomOps(); |
| 197 | + Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI, session_options); |
| 198 | + |
| 199 | + // Create Input and Output Names |
| 200 | + std::vector<const char*> input_names = {"input_1"}; |
| 201 | + const char* output_names[] = {"customout"}; |
| 202 | + |
| 203 | + // Run Session |
| 204 | + std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); |
| 205 | + |
| 206 | + // Validate Results |
| 207 | + ASSERT_EQ(ort_outputs.size(), 1u); |
| 208 | + |
| 209 | + std::vector<int64_t> out_dims = {1, 1}; |
| 210 | + auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo(); |
| 211 | + ASSERT_EQ(type_info.GetShape(), out_dims); |
| 212 | + ASSERT_EQ(type_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); |
| 213 | + |
| 214 | + std::string output_data{"hi, this is engineer from microsoft."}; |
| 215 | + auto expected_string = output_data.c_str(); |
| 216 | + size_t expected_string_len = strlen(expected_string); |
| 217 | + auto data_length = ort_outputs[0].GetStringTensorDataLength(); |
| 218 | + ASSERT_EQ(expected_string_len, data_length); |
| 219 | + |
| 220 | + std::string result(data_length, '\0'); |
| 221 | + std::vector<size_t> offsets(type_info.GetElementCount()); |
| 222 | + ort_outputs[0].GetStringTensorContent((void*)result.data(), data_length, offsets.data(), offsets.size()); |
| 223 | + ASSERT_STREQ(result.c_str(), expected_string); |
| 224 | +} |
| 225 | + |
| 226 | //test custom op which accepts float and double as inputs |
| 227 | TEST(CApiTest, varied_input_custom_op_handler) { |
| 228 | std::vector<Input> inputs(2); |
| 229 | diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc |
| 230 | index 523a2fc6f..970ca1552 100644 |
| 231 | --- a/onnxruntime/wasm/api.cc |
| 232 | +++ b/onnxruntime/wasm/api.cc |
| 233 | @@ -22,6 +22,10 @@ Ort::Session* OrtCreateSession(void* data, size_t data_length) { |
| 234 | Ort::SessionOptions session_options; |
| 235 | session_options.SetLogId("onnxruntime"); |
| 236 | |
| 237 | + // Enable ORT CustomOps |
| 238 | + // TODO: add condition check here to enable ORT CustomOps |
| 239 | + session_options.EnableOrtCustomOps(); |
| 240 | + |
| 241 | #if !defined(__EMSCRIPTEN_PTHREADS__) |
| 242 | // must disable thread pool when WebAssembly multi-threads support is disabled. |
| 243 | session_options.SetIntraOpNumThreads(1); |
| 244 | |