microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0169129b19715e12031e1f6378121bd671ea7ce3

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/cv2/super_resolution_preprocess.cc

91 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#include "super_resolution_preprocess.hpp"
#include "string_utils.h"

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>
12
13KernelSuperResolutionPreProcess::KernelSuperResolutionPreProcess(const OrtApi& api) : BaseKernel(api) {}
14
15void KernelSuperResolutionPreProcess::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
18 OrtTensorDimensions dimensions(ort_, inputs);
19 if (dimensions.size() != 1ULL) {
20 throw std::runtime_error("Only raw image formats are supported.");
21 }
22
23 // Get data & the length
24 const uint8_t* const encoded_bgr_image_data = ort_.GetTensorData<uint8_t>(inputs);
25
26 OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
27 const int64_t encoded_bgr_image_data_len = ort_.GetTensorShapeElementCount(input_info);
28 ort_.ReleaseTensorTypeAndShapeInfo(input_info);
29
30 // Decode the image
31 const std::vector<int32_t> encoded_bgr_image_sizes{1, static_cast<int32_t>(encoded_bgr_image_data_len)};
32 const cv::Mat encoded_bgr_image(encoded_bgr_image_sizes, CV_8UC1,
33 const_cast<void*>(static_cast<const void*>(encoded_bgr_image_data)));
34 // OpenCV decodes images in BGR format.
35 // Ref: https://stackoverflow.com/a/44359400
36 const cv::Mat decoded_bgr_image = cv::imdecode(encoded_bgr_image, cv::IMREAD_COLOR);
37
38 cv::Mat normalized_bgr_image;
39 decoded_bgr_image.convertTo(normalized_bgr_image, CV_32F);
40
41 cv::Mat ycrcb_image;
42 cv::cvtColor(normalized_bgr_image, ycrcb_image, cv::COLOR_BGR2YCrCb);
43
44 cv::Mat channels[3];
45 cv::split(ycrcb_image, channels);
46 channels[0] /= 255.0;
47
48 // Setup output & copy to destination
49 for (int32_t i = 0; i < 3; ++i) {
50 const cv::Mat& channel = channels[i];
51 const cv::Size size = channel.size();
52
53 const std::vector<int64_t> output_dimensions{1LL, 1LL, size.height, size.width};
54 OrtValue* const output_value = ort_.KernelContext_GetOutput(
55 context, i, output_dimensions.data(), output_dimensions.size());
56 float* const data = ort_.GetTensorMutableData<float>(output_value);
57 memcpy(data, channel.data, channel.total() * channel.elemSize());
58 }
59}
60
61const char* CustomOpSuperResolutionPreProcess::GetName() const {
62 return "SuperResolutionPreProcess";
63}
64
65size_t CustomOpSuperResolutionPreProcess::GetInputTypeCount() const {
66 return 1;
67}
68
69ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetInputType(size_t index) const {
70 switch (index) {
71 case 0:
72 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
73 default:
74 ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
75 }
76}
77
78size_t CustomOpSuperResolutionPreProcess::GetOutputTypeCount() const {
79 return 3;
80}
81
82ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetOutputType(size_t index) const {
83 switch (index) {
84 case 0:
85 case 1:
86 case 2:
87 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
88 default:
89 ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
90 }
91}
92