microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0647ce6d14b95209220714685d83659bc0d53aad

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/cv2/imdecode.hpp

85 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <cstring>
#include <limits>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>

#include "ocos.h"
#include "string_utils.h"

15struct KernelImageDecoder : BaseKernel {
16 KernelImageDecoder(const OrtApi& api) : BaseKernel(api) {}
17
18 void Compute(OrtKernelContext* context) {
19 // Setup inputs
20 const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
21 OrtTensorDimensions dimensions(ort_, inputs);
22 if (dimensions.size() != 1ULL) {
23 ORT_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
24 }
25
26 // Get data & the length
27 const uint8_t* const encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);
28
29 OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
30 const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
31 ort_.ReleaseTensorTypeAndShapeInfo(input_info);
32
33 // Decode the image
34 const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
35 const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
36 const_cast<void*>(static_cast<const void*>(encoded_image_data)));
37 const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
38
39 // Setup output & copy to destination
40 const cv::Size decoded_image_size = decoded_image.size();
41 const int64_t colors = 3;
42
43 const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
44 OrtValue *const output_value = ort_.KernelContext_GetOutput(
45 context, 0, output_dimensions.data(), output_dimensions.size());
46 uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
47 memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
48 }
49};
50
51struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
52 void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
53 return new KernelImageDecoder(api);
54 }
55
56 const char* GetName() const {
57 return "ImageDecoder";
58 }
59
60 size_t GetInputTypeCount() const {
61 return 1;
62 }
63
64 ONNXTensorElementDataType GetInputType(size_t index) const {
65 switch (index) {
66 case 0:
67 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
68 default:
69 ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
70 }
71 }
72
73 size_t GetOutputTypeCount() const {
74 return 1;
75 }
76
77 ONNXTensorElementDataType GetOutputType(size_t index) const {
78 switch (index) {
79 case 0:
80 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
81 default:
82 ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
83 }
84 }
85};
86