microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/dlib/inverse.hpp

64 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <dlib/matrix.h>
7#include "ocos.h"
8
9
10struct KernelInverse : BaseKernel {
11 KernelInverse(OrtApi api) : BaseKernel(api) {
12 }
13
14 void Compute(OrtKernelContext* context) {
15 // Setup inputs
16 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
17 const float* X = ort_.GetTensorData<float>(input_X);
18
19 // Setup output
20 OrtTensorDimensions dimensions(ort_, input_X);
21 if (dimensions.size() != 2) {
22 throw std::runtime_error("Only 2-d matrix supported.");
23 }
24
25 OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
26 float* out0 = ort_.GetTensorMutableData<float>(output0);
27
28 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
29 int64_t size = ort_.GetTensorShapeElementCount(output_info);
30 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
31
32 dlib::matrix<float> dm(dimensions[0], dimensions[1]);
33 // Do computation
34 for (int64_t i = 0; i < size; i++) {
35 out0[i] = dm(i / dimensions[1], i % dimensions[1]);
36 }
37 }
38};
39
40struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
41 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
42 return new KernelInverse(api);
43 }
44
45 const char* GetName() const {
46 return "Inverse";
47 }
48
49 size_t GetInputTypeCount() const {
50 return 1;
51 }
52
53 ONNXTensorElementDataType GetInputType(size_t index) const {
54 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
55 }
56
57 size_t GetOutputTypeCount() const {
58 return 1;
59 }
60
61 ONNXTensorElementDataType GetOutputType(size_t index) const {
62 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
63 }
64};
65