microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
0169129b19715e12031e1f6378121bd671ea7ce3

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/cv2/gaussian_blur.hpp

82lines · modecode

1#include <opencv2/core.hpp>
2#include <opencv2/imgproc.hpp>
3
4
5struct KernelGaussianBlur : BaseKernel {
6 KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
7 }
8
9 void Compute(OrtKernelContext* context) {
10 size_t input_c = ort_.KernelContext_GetInputCount(context);
11 const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0);
12 const float* p_input_data = ort_.GetTensorData<float>(input_data);
13 std::int64_t ksize[] = {3, 3};
14 double sigma[] = {0., 0.};
15 if (input_c > 1) {
16 const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
17 OrtTensorDimensions dim_ksize(ort_, input_ksize);
18 if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
19 ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
20 }
21 std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
22 }
23
24 if (input_c > 2) {
25 const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
26 OrtTensorDimensions dim_sigma(ort_, input_sigma);
27 if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
28 ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
29 }
30 std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
31 }
32
33 OrtTensorDimensions input_data_dimensions(ort_, input_data);
34
35 int n = static_cast<int>(input_data_dimensions[0]);
36 int h = static_cast<int>(input_data_dimensions[1]);
37 int w = static_cast<int>(input_data_dimensions[2]);
38 int c = static_cast<int>(input_data_dimensions[3]);
39 (void)n;
40 (void)c;
41
42 cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
43 cv::Mat output_image;
44 cv::GaussianBlur(input_image,
45 output_image,
46 cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
47 sigma[0], sigma[1], cv::BORDER_DEFAULT);
48
49 OrtValue* image_y = ort_.KernelContext_GetOutput(context,
50 0, input_data_dimensions.data(), input_data_dimensions.size());
51 float* p_output_image = ort_.GetTensorMutableData<float>(image_y);
52 memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
53 }
54};
55
56struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
57 size_t GetInputTypeCount() const {
58 return 3;
59 }
60
61 size_t GetOutputTypeCount() const {
62 return 1;
63 }
64
65 ONNXTensorElementDataType GetInputType(size_t index) const {
66 if (index == 0) {
67 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
68 } else if (index == 1) {
69 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
70 } else {
71 return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
72 }
73 }
74
75 ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
76 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
77 }
78
79 const char* GetName() const{
80 return "GaussianBlur";
81 }
82};
83