microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b5ba84a185c5c85c436e744c40ac9a8cbd9d3f1f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

test/shared_test/test_ortops.cc

113 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "onnxruntime_cxx_api.h"
5#include "gtest/gtest.h"
6#include "ocos.h"
7
8#include "test_kernel.hpp"
9
10
// One named float input fed to a session: the graph input name, the tensor
// shape, and the element buffer that is wrapped into an Ort::Value.
struct Input {
  // Borrowed C string (not owned). RunSession stores this pointer in its name
  // list, so it must stay valid for the duration of that call.
  const char* name{nullptr};
  std::vector<int64_t> dims;   // tensor shape passed to CreateTensor
  std::vector<float> values;   // element data passed to CreateTensor
};
16
17void RunSession(Ort::Session& session_object,
18 const std::vector<Input>& inputs,
19 const char* output_name,
20 const std::vector<int64_t>& dims_y,
21 const std::vector<int32_t>& values_y) {
22
23 std::vector<Ort::Value> ort_inputs;
24 std::vector<const char*> input_names;
25
26 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
27
28 for (size_t i = 0; i < inputs.size(); i++) {
29 input_names.emplace_back(inputs[i].name);
30 ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
31 const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
32 }
33
34 std::vector<Ort::Value> ort_outputs;
35 ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1);
36 ASSERT_EQ(ort_outputs.size(), 1u);
37 auto output_tensor = &ort_outputs[0];
38
39 auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
40 ASSERT_EQ(type_info.GetShape(), dims_y);
41 size_t total_len = type_info.GetElementCount();
42 ASSERT_EQ(values_y.size(), total_len);
43
44 int32_t* f = output_tensor->GetTensorMutableData<int32_t>();
45 for (size_t i = 0; i != total_len; ++i) {
46 ASSERT_EQ(values_y[i], f[i]);
47 }
48}
49
50void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
51 const std::vector<Input>& inputs,
52 const char* output_name,
53 const std::vector<int64_t>& expected_dims_y,
54 const std::vector<int32_t>& expected_values_y,
55 const char* custom_op_library_filename) {
56 Ort::SessionOptions session_options;
57 void* handle = nullptr;
58 if (custom_op_library_filename) {
59 Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &handle));
60 }
61
62 // if session creation passes, model loads fine
63 Ort::Session session(env, model_uri, session_options);
64
65 // Now run
66 RunSession(session,
67 inputs,
68 output_name,
69 expected_dims_y,
70 expected_values_y);
71}
72
73static CustomOpOne op_1st;
74static CustomOpTwo op_2nd;
75
76TEST(utils, test_ort_case) {
77
78 auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
79 std::cout << "Running custom op inference" << std::endl;
80
81 std::vector<Input> inputs(2);
82 inputs[0].name = "input_1";
83 inputs[0].dims = {3, 5};
84 inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
85 6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
86 11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
87 inputs[1].name = "input_2";
88 inputs[1].dims = {3, 5};
89 inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
90 10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
91 5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
92
93 // prepare expected inputs and outputs
94 std::vector<int64_t> expected_dims_y = {3, 5};
95 std::vector<int32_t> expected_values_y =
96 {17, 17, 17, 17, 17,
97 17, 18, 18, 18, 17,
98 17, 17, 17, 17, 17};
99
100#if defined(_WIN32)
101 const char lib_name[] = "ortcustomops.dll";
102 const ORTCHAR_T model_path[] = L"data\\custom_op_test.onnx";
103#elif defined(__APPLE__)
104 const char lib_name[] = "libortcustomops.dylib";
105 const ORTCHAR_T model_path[] = "data/custom_op_test.onnx";
106#else
107 const char lib_name[] = "./libortcustomops.so";
108 const ORTCHAR_T model_path[] = "data/custom_op_test.onnx";
109#endif
110 AddExternalCustomOp(&op_1st);
111 AddExternalCustomOp(&op_2nd);
112 TestInference(*ort_env, model_path, inputs, "output", expected_dims_y, expected_values_y, lib_name);
113}
114