microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d8581da434a333e573cdbc51b9558142203c9c8c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/masked_fill.cc

80 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#include "masked_fill.hpp"

#include <algorithm>
#include <codecvt>
#include <locale>
#include <utility>
#include <vector>

#include "string_tensor.h"
10
11
12KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
13}
14
15void KernelMaskedFill::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* input_value = ort_.KernelContext_GetInput(context, 0);
18 const OrtValue* input_mask = ort_.KernelContext_GetInput(context, 1);
19
20 OrtTensorDimensions value_dimensions(ort_, input_value);
21 OrtTensorDimensions mask_dimensions(ort_, input_mask);
22
23 if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
24 ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
25 }
26
27 if (value_dimensions != mask_dimensions) {
28 ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
29 }
30
31 std::vector<std::string> value;
32 const bool * mask = nullptr;
33
34 GetTensorMutableDataString(api_, ort_, context, input_value, value);
35 mask = ort_.GetTensorData<bool>(input_mask);
36
37 std::vector<std::string> result;
38 std::vector<int64_t> result_dimension;
39
40 for (size_t i = 0; i < value.size(); i++) {
41 if (!mask[i]) {
42 continue;
43 }
44
45 result.push_back(value[i]);
46 }
47 result_dimension.push_back(result.size());
48
49 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, result_dimension.data(), result_dimension.size());
50
51 FillTensorDataString(api_, ort_, context, result, output);
52}
53
54void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
55 return CreateKernelImpl(api, info);
56};
57
58const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; };
59
60size_t CustomOpMaskedFill::GetInputTypeCount() const {
61 return 2;
62};
63
64ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
65 switch (index) {
66 case 0:
67 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
68 case 1:
69 return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
70 default:
71 ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
72 }};
73
74size_t CustomOpMaskedFill::GetOutputTypeCount() const {
75 return 1;
76};
77
78ONNXTensorElementDataType CustomOpMaskedFill::GetOutputType(size_t /*index*/) const {
79 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
80};
81