microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v0.4.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/string_concat.cc

60lines · modepreview

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_concat.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>


// Builds the kernel by handing the ORT API object to BaseKernel; the
// constructor body itself does nothing else.
KernelStringConcat::KernelStringConcat(OrtApi api)
    : BaseKernel(api) {}

// Concatenates two string tensors element by element.
//
// Reads input 0 ("left") and input 1 ("right") from the kernel context,
// requires both to have identical dimensions, appends right[i] to left[i]
// for every element, and writes the result to output 0, which is requested
// with the dimensions of the left input.
//
// A dimension mismatch is reported through ORT_CXX_API_THROW with
// ORT_INVALID_ARGUMENT.
void KernelStringConcat::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* left = ort_.KernelContext_GetInput(context, 0);
  const OrtValue* right = ort_.KernelContext_GetInput(context, 1);
  OrtTensorDimensions left_dim(ort_, left);
  OrtTensorDimensions right_dim(ort_, right);

  if (left_dim != right_dim) {
    ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
  }

  std::vector<std::string> left_value;
  std::vector<std::string> right_value;
  GetTensorMutableDataString(api_, ort_, context, left, left_value);
  GetTensorMutableDataString(api_, ort_, context, right, right_value);

  // Reuse left_value as the output buffer to save memory.
  // The index is size_t to match std::vector::size(): a signed int would mix
  // signed and unsigned in the comparison and overflow on very large tensors.
  for (size_t i = 0; i < left_value.size(); ++i) {
    left_value[i].append(right_value[i]);
  }

  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, left_dim.data(), left_dim.size());
  FillTensorDataString(api_, ort_, context, left_value, output);
}

// Creates the kernel object that performs the StringConcat computation.
// The kernel info argument is not used. The kernel is heap-allocated and
// handed back as an untyped pointer.
// NOTE(review): presumably the custom-op framework releases it - confirm.
void* CustomOpStringConcat::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
  auto* kernel = new KernelStringConcat(api);
  return kernel;
}

// Returns the operator name reported by this custom op.
const char* CustomOpStringConcat::GetName() const {
  return "StringConcat";
}

// Number of inputs the operator declares: the two tensors read by
// KernelStringConcat::Compute (inputs 0 and 1).
size_t CustomOpStringConcat::GetInputTypeCount() const {
  constexpr size_t kInputCount = 2;
  return kInputCount;
}

// Element type of an input. The index is ignored: every input is declared
// as a string tensor.
ONNXTensorElementDataType CustomOpStringConcat::GetInputType(size_t /*index*/) const {
  const ONNXTensorElementDataType input_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return input_type;
}

// Number of outputs the operator declares: the single tensor written by
// KernelStringConcat::Compute (output 0).
size_t CustomOpStringConcat::GetOutputTypeCount() const {
  constexpr size_t kOutputCount = 1;
  return kOutputCount;
}

// Element type of an output. The index is ignored: every output is declared
// as a string tensor.
ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
  const ONNXTensorElementDataType output_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return output_type;
}