microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
base/string_tensor.cc
57 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | #include "string_tensor.h" |
| 4 | #include "string_utils.h" |
| 5 | #include "ustring.h" |
| 6 | |
| 7 | void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context, |
| 8 | const OrtValue* value, std::vector<std::string>& output) { |
| 9 | (void)context; |
| 10 | OrtTensorDimensions dimensions(ort, value); |
| 11 | size_t len = static_cast<size_t>(dimensions.Size()); |
| 12 | size_t data_len; |
| 13 | OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len)); |
| 14 | output.resize(len); |
| 15 | std::vector<char> result(data_len + len + 1, '\0'); |
| 16 | std::vector<size_t> offsets(len); |
| 17 | OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size())); |
| 18 | output.resize(len); |
| 19 | for (int64_t i = (int64_t)len - 1; i >= 0; --i) { |
| 20 | if (i < static_cast<int64_t>(len) - 1) |
| 21 | result[offsets[static_cast<size_t>(i + (int64_t)1)]] = '\0'; |
| 22 | output[static_cast<size_t>(i)] = result.data() + offsets[static_cast<size_t>(i)]; |
| 23 | } |
| 24 | } |
| 25 | |
| 26 | void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context, |
| 27 | const std::vector<std::string>& value, OrtValue* output) { |
| 28 | (void)ort; |
| 29 | (void)context; |
| 30 | std::vector<const char*> temp(value.size()); |
| 31 | for (size_t i = 0; i < value.size(); ++i) { |
| 32 | temp[i] = value[i].c_str(); |
| 33 | } |
| 34 | |
| 35 | OrtW::ThrowOnError(api, api.FillStringTensor(output, temp.data(), value.size())); |
| 36 | } |
| 37 | |
| 38 | void GetTensorMutableDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context, |
| 39 | const OrtValue* value, std::vector<ustring>& output) { |
| 40 | std::vector<std::string> utf8_strings; |
| 41 | GetTensorMutableDataString(api, ort, context, value, utf8_strings); |
| 42 | |
| 43 | output.reserve(utf8_strings.size()); |
| 44 | for (auto& str : utf8_strings) { |
| 45 | output.emplace_back(str); |
| 46 | } |
| 47 | } |
| 48 | |
| 49 | void FillTensorDataString(const OrtApi& api, const OrtW::CustomOpApi& ort, const OrtKernelContext* context, |
| 50 | const std::vector<ustring>& value, OrtValue* output) { |
| 51 | std::vector<std::string> utf8_strings; |
| 52 | utf8_strings.reserve(value.size()); |
| 53 | for (const auto& str : value) { |
| 54 | utf8_strings.push_back(std::string(str)); |
| 55 | } |
| 56 | FillTensorDataString(api, ort, context, utf8_strings, output); |
| 57 | } |