microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d8581da434a333e573cdbc51b9558142203c9c8c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/dlib/inverse.hpp

56lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <dlib/matrix.h>
7#include "ocos.h"
8
9
10struct KernelInverse : BaseKernel {
11 KernelInverse(const OrtApi& api) : BaseKernel(api) {
12 }
13
14 void Compute(OrtKernelContext* context) {
15 // Setup inputs
16 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
17 const float* X = ort_.GetTensorData<float>(input_X);
18
19 // Setup output
20 OrtTensorDimensions dimensions(ort_, input_X);
21 if (dimensions.size() != 2) {
22 throw std::runtime_error("Only 2-d matrix supported.");
23 }
24
25 OrtValue* output0 = ort_.KernelContext_GetOutput(
26 context, 0, dimensions.data(), dimensions.size());
27 float* out0 = ort_.GetTensorMutableData<float>(output0);
28
29 dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
30 std::copy(X, X + dm_x.size(), dm_x.begin());
31 dlib::matrix<float> dm = dlib::inv(dm_x);
32 memcpy(out0, dm.steal_memory().get(), dm_x.size() * sizeof(float));
33 }
34};
35
36struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
37 const char* GetName() const {
38 return "Inverse";
39 }
40
41 size_t GetInputTypeCount() const {
42 return 1;
43 }
44
45 ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
46 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
47 }
48
49 size_t GetOutputTypeCount() const {
50 return 1;
51 }
52
53 ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
54 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
55 }
56};
57