microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f74770feed077546874ed7e66d1aba9e2509fea9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/string_concat.cc

60 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_concat.hpp"
5#include "string_tensor.h"
6#include <vector>
7#include <locale>
8#include <codecvt>
9#include <algorithm>
10
11
12KernelStringConcat::KernelStringConcat(OrtApi api) : BaseKernel(api) {
13}
14
15void KernelStringConcat::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* left = ort_.KernelContext_GetInput(context, 0);
18 const OrtValue* right = ort_.KernelContext_GetInput(context, 1);
19 OrtTensorDimensions left_dim(ort_, left);
20 OrtTensorDimensions right_dim(ort_, right);
21
22 if (left_dim != right_dim) {
23 throw std::runtime_error("Two input tensor should have the same dimension.");
24 }
25
26 std::vector<std::string> left_value;
27 std::vector<std::string> right_value;
28 GetTensorMutableDataString(api_, ort_, context, left, left_value);
29 GetTensorMutableDataString(api_, ort_, context, right, right_value);
30
31 // reuse left_value as output to save memory
32 for (int i = 0; i < left_value.size(); i++) {
33 left_value[i].append(right_value[i]);
34 }
35
36 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, left_dim.data(), left_dim.size());
37 FillTensorDataString(api_, ort_, context, left_value, output);
38}
39
40void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
41 return new KernelStringConcat(api);
42};
43
44const char* CustomOpStringConcat::GetName() const { return "StringConcat"; };
45
46size_t CustomOpStringConcat::GetInputTypeCount() const {
47 return 2;
48};
49
50ONNXTensorElementDataType CustomOpStringConcat::GetInputType(size_t /*index*/) const {
51 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
52};
53
54size_t CustomOpStringConcat::GetOutputTypeCount() const {
55 return 1;
56};
57
58ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
59 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
60};