microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code | Commits | Issues | Pull requests | Actions | Insights | Security
0647ce6d14b95209220714685d83659bc0d53aad

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/math/negpos.hpp

62lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include "ocos.h"
7
8struct KernelNegPos : BaseKernel {
9 KernelNegPos(const OrtApi& api) : BaseKernel(api) {
10 }
11
12 void Compute(OrtKernelContext* context){
13 // Setup inputs
14 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
15 const float* X = ort_.GetTensorData<float>(input_X);
16
17 // Setup output
18 OrtTensorDimensions dimensions(ort_, input_X);
19
20 OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
21 float* out0 = ort_.GetTensorMutableData<float>(output0);
22 OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
23 float* out1 = ort_.GetTensorMutableData<float>(output1);
24
25 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
26 int64_t size = ort_.GetTensorShapeElementCount(output_info);
27 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
28
29 // Do computation
30 for (int64_t i = 0; i < size; i++) {
31 if (X[i] > 0) {
32 out0[i] = 0;
33 out1[i] = X[i];
34 } else {
35 out0[i] = X[i];
36 out1[i] = 0;
37 }
38 }
39 }
40};
41
42struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
43 const char* GetName() const{
44 return "NegPos";
45 }
46
47 size_t GetInputTypeCount() const{
48 return 1;
49 }
50
51 ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
52 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
53 }
54
55 size_t GetOutputTypeCount() const{
56 return 2;
57 }
58
59 ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
60 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
61 }
62};
63