microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d8581da434a333e573cdbc51b9558142203c9c8c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/string_concat.cc

56 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_concat.hpp"
5#include "string_tensor.h"
6#include <vector>
7#include <locale>
8#include <codecvt>
9#include <algorithm>
10
11
12KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
13}
14
15void KernelStringConcat::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* left = ort_.KernelContext_GetInput(context, 0);
18 const OrtValue* right = ort_.KernelContext_GetInput(context, 1);
19 OrtTensorDimensions left_dim(ort_, left);
20 OrtTensorDimensions right_dim(ort_, right);
21
22 if (left_dim != right_dim) {
23 ORTX_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
24 }
25
26 std::vector<std::string> left_value;
27 std::vector<std::string> right_value;
28 GetTensorMutableDataString(api_, ort_, context, left, left_value);
29 GetTensorMutableDataString(api_, ort_, context, right, right_value);
30
31 // reuse left_value as output to save memory
32 for (size_t i = 0; i < left_value.size(); i++) {
33 left_value[i].append(right_value[i]);
34 }
35
36 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, left_dim.data(), left_dim.size());
37 FillTensorDataString(api_, ort_, context, left_value, output);
38}
39
40const char* CustomOpStringConcat::GetName() const { return "StringConcat"; };
41
42size_t CustomOpStringConcat::GetInputTypeCount() const {
43 return 2;
44};
45
46ONNXTensorElementDataType CustomOpStringConcat::GetInputType(size_t /*index*/) const {
47 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
48};
49
50size_t CustomOpStringConcat::GetOutputTypeCount() const {
51 return 1;
52};
53
54ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
55 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
56};