microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
operators/text/masked_fill.cc
80 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
#include "masked_fill.hpp"
#include "string_tensor.h"
#include <algorithm>
#include <codecvt>
#include <locale>
#include <utility>
#include <vector>
| 10 | |
| 11 | |
| 12 | KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) { |
| 13 | } |
| 14 | |
| 15 | void KernelMaskedFill::Compute(OrtKernelContext* context) { |
| 16 | // Setup inputs |
| 17 | const OrtValue* input_value = ort_.KernelContext_GetInput(context, 0); |
| 18 | const OrtValue* input_mask = ort_.KernelContext_GetInput(context, 1); |
| 19 | |
| 20 | OrtTensorDimensions value_dimensions(ort_, input_value); |
| 21 | OrtTensorDimensions mask_dimensions(ort_, input_mask); |
| 22 | |
| 23 | if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) { |
| 24 | ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT); |
| 25 | } |
| 26 | |
| 27 | if (value_dimensions != mask_dimensions) { |
| 28 | ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT); |
| 29 | } |
| 30 | |
| 31 | std::vector<std::string> value; |
| 32 | const bool * mask = nullptr; |
| 33 | |
| 34 | GetTensorMutableDataString(api_, ort_, context, input_value, value); |
| 35 | mask = ort_.GetTensorData<bool>(input_mask); |
| 36 | |
| 37 | std::vector<std::string> result; |
| 38 | std::vector<int64_t> result_dimension; |
| 39 | |
| 40 | for (size_t i = 0; i < value.size(); i++) { |
| 41 | if (!mask[i]) { |
| 42 | continue; |
| 43 | } |
| 44 | |
| 45 | result.push_back(value[i]); |
| 46 | } |
| 47 | result_dimension.push_back(result.size()); |
| 48 | |
| 49 | OrtValue* output = ort_.KernelContext_GetOutput(context, 0, result_dimension.data(), result_dimension.size()); |
| 50 | |
| 51 | FillTensorDataString(api_, ort_, context, result, output); |
| 52 | } |
| 53 | |
| 54 | void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { |
| 55 | return CreateKernelImpl(api, info); |
| 56 | }; |
| 57 | |
| 58 | const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; }; |
| 59 | |
| 60 | size_t CustomOpMaskedFill::GetInputTypeCount() const { |
| 61 | return 2; |
| 62 | }; |
| 63 | |
| 64 | ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const { |
| 65 | switch (index) { |
| 66 | case 0: |
| 67 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 68 | case 1: |
| 69 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; |
| 70 | default: |
| 71 | ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); |
| 72 | }}; |
| 73 | |
| 74 | size_t CustomOpMaskedFill::GetOutputTypeCount() const { |
| 75 | return 1; |
| 76 | }; |
| 77 | |
| 78 | ONNXTensorElementDataType CustomOpMaskedFill::GetOutputType(size_t /*index*/) const { |
| 79 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 80 | }; |