microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
c3145b8f52cb80bda7a3a073f73e818987196f77

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/string_tensor.cc

57 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3#include "string_tensor.h"
4#include "string_utils.h"
5#include "ustring.h"
6
7void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
8 const OrtValue* value, std::vector<std::string>& output) {
9 (void)context;
10 OrtTensorDimensions dimensions(ort, value);
11 size_t len = static_cast<size_t>(dimensions.Size());
12 size_t data_len;
13 OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
14 output.resize(len);
15 std::vector<char> result(data_len + len + 1, '\0');
16 std::vector<size_t> offsets(len);
17 OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
18 output.resize(len);
19 for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
20 if (i < static_cast<int64_t>(len) - 1)
21 result[offsets[static_cast<size_t>(i + (int64_t)1)]] = '\0';
22 output[static_cast<size_t>(i)] = result.data() + offsets[static_cast<size_t>(i)];
23 }
24}
25
26void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
27 const std::vector<std::string>& value, OrtValue* output) {
28 (void)ort;
29 (void)context;
30 std::vector<const char*> temp(value.size());
31 for (size_t i = 0; i < value.size(); ++i) {
32 temp[i] = value[i].c_str();
33 }
34
35 OrtW::ThrowOnError(api, api.FillStringTensor(output, temp.data(), value.size()));
36}
37
38void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
39 const OrtValue* value, std::vector<ustring>& output) {
40 std::vector<std::string> utf8_strings;
41 GetTensorMutableDataString(api, ort, context, value, utf8_strings);
42
43 output.reserve(utf8_strings.size());
44 for (auto& str : utf8_strings) {
45 output.emplace_back(str);
46 }
47}
48
49void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context,
50 const std::vector<ustring>& value, OrtValue* output) {
51 std::vector<std::string> utf8_strings;
52 utf8_strings.reserve(value.size());
53 for (const auto& str : value) {
54 utf8_strings.push_back(std::string(str));
55 }
56 FillTensorDataString(api, ort, context, utf8_strings, output);
57}