microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
skottmckay/BuildInfra_AndTestImageLibs

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/cv2/super_resolution_postprocess.cc

99 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "super_resolution_postprocess.hpp"
#include "string_utils.h"

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>

#include <cstdint>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <vector>

13KernelSuperResolutionPostProcess::KernelSuperResolutionPostProcess(const OrtApi& api) : BaseKernel(api) {}
14
15void KernelSuperResolutionPostProcess::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* const input_y = ort_.KernelContext_GetInput(context, 0ULL);
18 const OrtValue* const input_cr = ort_.KernelContext_GetInput(context, 1ULL);
19 const OrtValue* const input_cb = ort_.KernelContext_GetInput(context, 2ULL);
20
21 const OrtTensorDimensions dimensions_y(ort_, input_y);
22 const OrtTensorDimensions dimensions_cr(ort_, input_cr);
23 const OrtTensorDimensions dimensions_cb(ort_, input_cb);
24 if ((dimensions_y.size() != 4ULL) || (dimensions_cr.size() != 4ULL) || (dimensions_cb.size() != 4ULL)) {
25 throw std::runtime_error("Expecting 3 channels y, cr, and cb.");
26 }
27
28 // Get data & the length
29 const float* const channel_y_data = ort_.GetTensorData<float>(input_y);
30 const float* const channel_cr_data = ort_.GetTensorData<float>(input_cr);
31 const float* const channel_cb_data = ort_.GetTensorData<float>(input_cb);
32
33 cv::Mat y(
34 std::vector<int32_t>{static_cast<int32_t>(dimensions_y[2]), static_cast<int32_t>(dimensions_y[3])},
35 CV_32F, const_cast<void*>(static_cast<const void*>(channel_y_data)));
36 cv::Mat cr(
37 std::vector<int32_t>{static_cast<int32_t>(dimensions_cr[2]), static_cast<int32_t>(dimensions_cr[3])},
38 CV_32F, const_cast<void*>(static_cast<const void*>(channel_cr_data)));
39 cv::Mat cb(
40 std::vector<int32_t>{static_cast<int32_t>(dimensions_cb[2]), static_cast<int32_t>(dimensions_cb[3])},
41 CV_32F, const_cast<void*>(static_cast<const void*>(channel_cb_data)));
42
43 // Scale the individual channels
44 y *= 255.0;
45 cv::resize(cr, cr, y.size(), 0, 0, cv::INTER_CUBIC);
46 cv::resize(cb, cb, y.size(), 0, 0, cv::INTER_CUBIC);
47
48 // Merge the channels
49 const cv::Mat channels[] = {y, cr, cb};
50 cv::Mat ycrcb_image;
51 cv::merge(channels, 3, ycrcb_image);
52
53 // Convert it back to BGR format
54 cv::Mat bgr_image;
55 cv::cvtColor(ycrcb_image, bgr_image, cv::COLOR_YCrCb2BGR);
56
57 // Encode it as jpg
58 std::vector<uchar> encoded_image;
59 cv::imencode(".jpg", bgr_image, encoded_image);
60
61 // Setup output & copy to destination
62 const std::vector<int64_t> output_dimensions{1LL, static_cast<int64_t>(encoded_image.size())};
63 OrtValue* const output_value = ort_.KernelContext_GetOutput(
64 context, 0, output_dimensions.data(), output_dimensions.size());
65 float* const data = ort_.GetTensorMutableData<float>(output_value);
66 memcpy(data, encoded_image.data(), encoded_image.size());
67}
68
69const char* CustomOpSuperResolutionPostProcess::GetName() const {
70 return "SuperResolutionPostProcess";
71}
72
73size_t CustomOpSuperResolutionPostProcess::GetInputTypeCount() const {
74 return 3;
75}
76
77ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetInputType(size_t index) const {
78 switch (index) {
79 case 0:
80 case 1:
81 case 2:
82 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
83 default:
84 ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
85 }
86}
87
88size_t CustomOpSuperResolutionPostProcess::GetOutputTypeCount() const {
89 return 1;
90}
91
92ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetOutputType(size_t index) const {
93 switch (index) {
94 case 0:
95 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
96 default:
97 ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
98 }
99}
100