microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code Commits Issues Pull requests Actions Insights Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/shared_test/test_ortops.cc

300 lines · mode code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "onnxruntime_cxx_api.h"
5
6#include <filesystem>
7#include "gtest/gtest.h"
8#include "ocos.h"
9#include "string_utils.h"
10#include "string_tensor.h"
11#include "test_kernel.hpp"
12
13
// Returns the platform-specific file name of the custom-ops shared library
// that the tests register with the session options.
const char* GetLibraryPath() {
#if defined(_WIN32)
  constexpr const char* kLibraryName = "ortcustomops.dll";
#elif defined(__APPLE__)
  constexpr const char* kLibraryName = "libortcustomops.dylib";
#else
  constexpr const char* kLibraryName = "./libortcustomops.so";
#endif
  return kLibraryName;
}
23
24struct KernelOne : BaseKernel {
25 KernelOne(OrtApi api) : BaseKernel(api) {
26 }
27
28 void Compute(OrtKernelContext* context) {
29 // Setup inputs
30 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
31 const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
32 const float* X = ort_.GetTensorData<float>(input_X);
33 const float* Y = ort_.GetTensorData<float>(input_Y);
34
35 // Setup output
36 OrtTensorDimensions dimensions(ort_, input_X);
37
38 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
39 float* out = ort_.GetTensorMutableData<float>(output);
40
41 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
42 int64_t size = ort_.GetTensorShapeElementCount(output_info);
43 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
44
45 // Do computation
46 for (int64_t i = 0; i < size; i++) {
47 out[i] = X[i] + Y[i];
48 }
49 }
50};
51
52struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
53 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
54 return new KernelOne(api);
55 };
56 const char* GetName() const {
57 return "CustomOpOne";
58 };
59 size_t GetInputTypeCount() const {
60 return 2;
61 };
62 ONNXTensorElementDataType GetInputType(size_t index) const {
63 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
64 };
65 size_t GetOutputTypeCount() const {
66 return 1;
67 };
68 ONNXTensorElementDataType GetOutputType(size_t index) const {
69 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
70 };
71};
72
73struct KernelTwo : BaseKernel {
74 KernelTwo(OrtApi api) : BaseKernel(api) {
75 }
76 void Compute(OrtKernelContext* context) {
77 // Setup inputs
78 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
79 const float* X = ort_.GetTensorData<float>(input_X);
80
81 // Setup output
82 OrtTensorDimensions dimensions(ort_, input_X);
83
84 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
85 int32_t* out = ort_.GetTensorMutableData<int32_t>(output);
86
87 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
88 int64_t size = ort_.GetTensorShapeElementCount(output_info);
89 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
90
91 // Do computation
92 for (int64_t i = 0; i < size; i++) {
93 out[i] = (int32_t)(round(X[i]));
94 }
95 }
96};
97
98struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
99 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
100 return new KernelTwo(api);
101 };
102 const char* GetName() const {
103 return "CustomOpTwo";
104 };
105 size_t GetInputTypeCount() const {
106 return 1;
107 };
108 ONNXTensorElementDataType GetInputType(size_t index) const {
109 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
110 };
111 size_t GetOutputTypeCount() const {
112 return 1;
113 };
114 ONNXTensorElementDataType GetOutputType(size_t index) const {
115 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
116 };
117};
118
119template <typename T>
120void _emplace_back(Ort::MemoryInfo& memory_info, std::vector<Ort::Value>& ort_inputs, const std::vector<T>& values, const std::vector<int64_t>& dims) {
121 ort_inputs.emplace_back(Ort::Value::CreateTensor<T>(
122 memory_info, const_cast<T*>(values.data()), values.size(), dims.data(), dims.size()));
123}
124
125template <typename T>
126void _assert_eq(Ort::Value& output_tensor, const std::vector<T>& expected, size_t total_len) {
127 ASSERT_EQ(expected.size(), total_len);
128 T* f = output_tensor.GetTensorMutableData<T>();
129 for (size_t i = 0; i != total_len; ++i) {
130 ASSERT_EQ(expected[i], f[i]);
131 }
132}
133
134void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output) {
135 Ort::CustomOpApi ort(api);
136 OrtTensorDimensions dimensions(ort, value);
137 size_t len = static_cast<size_t>(dimensions.Size());
138 size_t data_len;
139 Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
140 output.resize(len);
141 std::vector<char> result(data_len + len + 1, '\0');
142 std::vector<size_t> offsets(len);
143 Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
144 output.resize(len);
145 for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
146 if (i < len - 1)
147 result[offsets[i + (int64_t)1]] = '\0';
148 output[i] = result.data() + offsets[i];
149 }
150}
151
152void RunSession(Ort::Session& session_object,
153 const std::vector<TestValue>& inputs,
154 const std::vector<TestValue>& outputs) {
155 std::vector<Ort::Value> ort_inputs;
156 std::vector<const char*> input_names;
157 std::vector<const char*> output_names;
158
159 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
160 Ort::AllocatorWithDefaultOptions allocator;
161
162 for (size_t i = 0; i < inputs.size(); i++) {
163 input_names.emplace_back(inputs[i].name);
164 switch (inputs[i].element_type) {
165 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
166 _emplace_back(memory_info, ort_inputs, inputs[i].values_float, inputs[i].dims);
167 break;
168 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
169 _emplace_back(memory_info, ort_inputs, inputs[i].values_int32, inputs[i].dims);
170 break;
171 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
172 Ort::Value& ort_value = ort_inputs.emplace_back(
173 Ort::Value::CreateTensor(allocator, inputs[i].dims.data(), inputs[i].dims.size(),
174 ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING));
175 for (size_t i_str = 0; i_str < inputs[i].values_string.size(); ++i_str) {
176 ort_value.FillStringTensorElement(inputs[i].values_string[i_str].c_str(), i_str);
177 }
178 } break;
179 default:
180 throw std::runtime_error(MakeString(
181 "Unable to handle input ", i, " type ", inputs[i].element_type,
182 " is not implemented yet."));
183 }
184 }
185 for (size_t index = 0; index < outputs.size(); ++index) {
186 output_names.push_back(outputs[index].name);
187 }
188
189 std::vector<Ort::Value> ort_outputs;
190 ort_outputs = session_object.Run(Ort::RunOptions{nullptr},
191 input_names.data(), ort_inputs.data(), ort_inputs.size(),
192 output_names.data(), outputs.size());
193 ASSERT_EQ(outputs.size(), ort_outputs.size());
194 for (size_t index = 0; index < outputs.size(); ++index) {
195 auto output_tensor = &ort_outputs[index];
196 const TestValue& expected = outputs[index];
197
198 auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
199 ONNXTensorElementDataType output_type = type_info.GetElementType();
200 ASSERT_EQ(output_type, expected.element_type);
201 std::vector<int64_t> dimension = type_info.GetShape();
202 ASSERT_EQ(dimension, expected.dims);
203 size_t total_len = type_info.GetElementCount();
204 switch (expected.element_type) {
205 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
206 _assert_eq(*output_tensor, expected.values_float, total_len);
207 break;
208 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
209 _assert_eq(*output_tensor, expected.values_int32, total_len);
210 break;
211 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
212 std::vector<std::string> output_string;
213 GetTensorMutableDataString(Ort::GetApi(), *output_tensor, output_string);
214 ASSERT_EQ(expected.values_string, output_string);
215 break;
216 }
217 default:
218 throw std::runtime_error(MakeString(
219 "Unable to handle output ", index, " type ", expected.element_type,
220 " is not implemented yet."));
221 }
222 }
223}
224
225void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
226 const std::vector<TestValue>& inputs,
227 const std::vector<TestValue>& outputs,
228 const char* custom_op_library_filename) {
229 Ort::SessionOptions session_options;
230 void* handle = nullptr;
231 if (custom_op_library_filename) {
232 Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &handle));
233 }
234
235 // if session creation passes, model loads fine
236 Ort::Session session(env, model_uri, session_options);
237
238 // Now run
239 RunSession(session, inputs, outputs);
240}
241
242static CustomOpOne op_1st;
243static CustomOpTwo op_2nd;
244
245TEST(utils, test_ort_case) {
246 auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
247
248 std::vector<TestValue> inputs(2);
249 inputs[0].name = "input_1";
250 inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
251 inputs[0].dims = {3, 5};
252 inputs[0].values_float = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
253 6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
254 11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
255 inputs[1].name = "input_2";
256 inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
257 inputs[1].dims = {3, 5};
258 inputs[1].values_float = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
259 10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
260 5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
261
262 // prepare expected inputs and outputs
263 std::vector<TestValue> outputs(1);
264 outputs[0].name = "output";
265 outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
266 outputs[0].dims = {3, 5};
267 outputs[0].values_int32 = {17, 17, 17, 17, 17,
268 17, 18, 18, 18, 17,
269 17, 17, 17, 17, 17};
270
271 std::filesystem::path model_path = __FILE__;
272 model_path = model_path.parent_path();
273 model_path /= "..";
274 model_path /= "data";
275 model_path /= "custom_op_test.onnx";
276 AddExternalCustomOp(&op_1st);
277 AddExternalCustomOp(&op_2nd);
278 TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
279}
280
281TEST(ustring, tensor_operator) {
282 OrtValue *tensor;
283 OrtAllocator* allocator;
284
285 const auto* api_base = OrtGetApiBase();
286 const auto* api = api_base->GetApi(ORT_API_VERSION);
287 api->GetAllocatorWithDefaultOptions(&allocator);
288 Ort::CustomOpApi custom_api(*api);
289
290 std::vector<int64_t> dim{2, 2};
291 api->CreateTensorAsOrtValue(allocator, dim.data(), dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &tensor);
292
293 std::vector<ustring> input_value{ustring("test"), ustring("测试"), ustring("Test de"), ustring("🧐")};
294 FillTensorDataString(*api, custom_api, nullptr, input_value, tensor);
295
296 std::vector<ustring> output_value;
297 GetTensorMutableDataString(*api, custom_api, nullptr, tensor, output_value);
298
299 EXPECT_EQ(input_value, output_value);
300}
301