microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
operators/cv2/gaussian_blur.hpp
82 lines · mode: code
| 1 | #include <opencv2/core.hpp> |
| 2 | #include <opencv2/imgproc.hpp> |
| 3 | |
| 4 | |
| 5 | struct KernelGaussianBlur : BaseKernel { |
| 6 | KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) { |
| 7 | } |
| 8 | |
| 9 | void Compute(OrtKernelContext* context) { |
| 10 | size_t input_c = ort_.KernelContext_GetInputCount(context); |
| 11 | const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0); |
| 12 | const float* p_input_data = ort_.GetTensorData<float>(input_data); |
| 13 | std::int64_t ksize[] = {3, 3}; |
| 14 | double sigma[] = {0., 0.}; |
| 15 | if (input_c > 1) { |
| 16 | const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1); |
| 17 | OrtTensorDimensions dim_ksize(ort_, input_ksize); |
| 18 | if (dim_ksize.size() != 1 || dim_ksize[0] != 2) { |
| 19 | ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT); |
| 20 | } |
| 21 | std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize); |
| 22 | } |
| 23 | |
| 24 | if (input_c > 2) { |
| 25 | const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2); |
| 26 | OrtTensorDimensions dim_sigma(ort_, input_sigma); |
| 27 | if (dim_sigma.size() != 1 || dim_sigma[0] != 2) { |
| 28 | ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT); |
| 29 | } |
| 30 | std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma); |
| 31 | } |
| 32 | |
| 33 | OrtTensorDimensions input_data_dimensions(ort_, input_data); |
| 34 | |
| 35 | int n = static_cast<int>(input_data_dimensions[0]); |
| 36 | int h = static_cast<int>(input_data_dimensions[1]); |
| 37 | int w = static_cast<int>(input_data_dimensions[2]); |
| 38 | int c = static_cast<int>(input_data_dimensions[3]); |
| 39 | (void)n; |
| 40 | (void)c; |
| 41 | |
| 42 | cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data); |
| 43 | cv::Mat output_image; |
| 44 | cv::GaussianBlur(input_image, |
| 45 | output_image, |
| 46 | cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])), |
| 47 | sigma[0], sigma[1], cv::BORDER_DEFAULT); |
| 48 | |
| 49 | OrtValue* image_y = ort_.KernelContext_GetOutput(context, |
| 50 | 0, input_data_dimensions.data(), input_data_dimensions.size()); |
| 51 | float* p_output_image = ort_.GetTensorMutableData<float>(image_y); |
| 52 | memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize()); |
| 53 | } |
| 54 | }; |
| 55 | |
| 56 | struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> { |
| 57 | size_t GetInputTypeCount() const { |
| 58 | return 3; |
| 59 | } |
| 60 | |
| 61 | size_t GetOutputTypeCount() const { |
| 62 | return 1; |
| 63 | } |
| 64 | |
| 65 | ONNXTensorElementDataType GetInputType(size_t index) const { |
| 66 | if (index == 0) { |
| 67 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; |
| 68 | } else if (index == 1) { |
| 69 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; |
| 70 | } else { |
| 71 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { |
| 76 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; |
| 77 | } |
| 78 | |
| 79 | const char* GetName() const{ |
| 80 | return "GaussianBlur"; |
| 81 | } |
| 82 | }; |
| 83 | |