microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
5d53f91f11c643be1228af6384fd81a518502d3a

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/vision/decode_image.cc

41lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "decode_image.hpp"
5
6#include <opencv2/imgcodecs.hpp>
7
8namespace ort_extensions {
9
10void KernelDecodeImage::Compute(OrtKernelContext* context) {
11 // Setup inputs
12 const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
13 OrtTensorDimensions dimensions(ort_, inputs);
14 if (dimensions.size() != 1ULL) {
15 ORTX_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
16 }
17
18 OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
19 const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
20 ort_.ReleaseTensorTypeAndShapeInfo(input_info);
21
22 // Decode the image
23 const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
24 const void* encoded_image_data = ort_.GetTensorData<uint8_t>(inputs); // uint8 data
25 const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1, const_cast<void*>(encoded_image_data));
26 const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
27
28 if (decoded_image.data == nullptr) {
29 ORTX_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT);
30 };
31
32 // Setup output & copy to destination
33 const cv::Size decoded_image_size = decoded_image.size();
34 const int64_t colors = decoded_image.elemSize(); // == 3 as it's BGR
35
36 const std::vector<int64_t> output_dims{decoded_image_size.height, decoded_image_size.width, colors};
37 OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
38 uint8_t* decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
39 memcpy(decoded_image_data, decoded_image.data, decoded_image_size.height * decoded_image_size.width * colors);
40}
41} // namespace ort_extensions
42