microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.7.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/dlib/inverse.hpp

55 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#pragma once

#include <algorithm>
#include <stdexcept>

#include <dlib/matrix.h>
#include "ocos.h"
8
9struct KernelInverse : BaseKernel {
10 KernelInverse(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
11 }
12
13 void Compute(OrtKernelContext* context) {
14 // Setup inputs
15 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
16 const float* X = ort_.GetTensorData<float>(input_X);
17
18 // Setup output
19 OrtTensorDimensions dimensions(ort_, input_X);
20 if (dimensions.size() != 2) {
21 throw std::runtime_error("Only 2-d matrix supported.");
22 }
23
24 OrtValue* output0 = ort_.KernelContext_GetOutput(
25 context, 0, dimensions.data(), dimensions.size());
26 float* out0 = ort_.GetTensorMutableData<float>(output0);
27
28 dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
29 std::copy(X, X + dm_x.size(), dm_x.begin());
30 dlib::matrix<float> dm = dlib::inv(dm_x);
31 memcpy(out0, dm.steal_memory().get(), dm_x.size() * sizeof(float));
32 }
33};
34
35struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
36 const char* GetName() const {
37 return "Inverse";
38 }
39
40 size_t GetInputTypeCount() const {
41 return 1;
42 }
43
44 ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
45 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
46 }
47
48 size_t GetOutputTypeCount() const {
49 return 1;
50 }
51
52 ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
53 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
54 }
55};
56