microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
base/ortx_common.h
118 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | #pragma once |
| 4 | #include <locale> |
| 5 | #include <optional> |
| 6 | #include <string> |
| 7 | #include <sstream> |
| 8 | |
| 9 | #include "string_utils.h" |
| 10 | #ifdef _WIN32 |
| 11 | #include <Windows.h> |
| 12 | #endif |
| 13 | |
// Evaluates `expr` exactly once; if the resulting status compares unequal to nullptr it is
// returned from the enclosing function, otherwise execution continues (null means success).
// The temporary uses a deliberately unusual name so an argument expression that refers to a
// caller variable with a common name (e.g. `_status` or `status`) is never captured by the
// macro's own local, which would otherwise be an `auto` variable used in its own initializer.
#define ORTX_RETURN_IF_ERROR(expr)              \
  do {                                          \
    auto ortx_status_internal_ = (expr);        \
    if (ortx_status_internal_ != nullptr) {     \
      return ortx_status_internal_;             \
    }                                           \
  } while (0)
| 21 | |
// Parses `str` into `value` using the classic ("C") locale.
// The entire string must be consumed: leading whitespace and any trailing characters are
// rejected, and for unsigned integral T a leading '-' is rejected so negatives cannot wrap.
// Returns true and assigns `value` on success; returns false and leaves `value` untouched
// on failure.
template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
  if (!str.empty()) {
    const char first = str.front();
    if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
      if (first == '-') {
        return false;  // a negative value would wrap around for unsigned T
      }
    }
    if (std::isspace(first, std::locale::classic())) {
      return false;
    }
  }

  std::istringstream stream{std::string{str}};
  stream.imbue(std::locale::classic());

  T result{};
  if (!(stream >> result)) {
    return false;
  }
  // Anything left after the extracted value means the input was not fully consumed.
  if (stream.get() != std::istringstream::traits_type::eof()) {
    return false;
  }

  value = std::move(result);
  return true;
}
| 50 | |
// String overload: every input is accepted verbatim, so this always copies `str` into
// `value` and reports success.
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
  value.assign(str.begin(), str.end());
  return true;
}
| 55 | |
// Boolean overload: accepts exactly "0"/"False"/"false" as false and "1"/"True"/"true" as
// true (case-sensitive, no surrounding whitespace). Any other input returns false and leaves
// `value` unchanged.
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
  const bool means_false = str == "0" || str == "False" || str == "false";
  const bool means_true = str == "1" || str == "True" || str == "true";
  if (!means_false && !means_true) {
    return false;
  }
  value = means_true;
  return true;
}
| 69 | |
| 70 | template <typename T> |
| 71 | std::optional<T> ParseEnvironmentVariable(const std::string& name) { |
| 72 | std::string buffer; |
| 73 | #ifdef _WIN32 |
| 74 | constexpr size_t kBufferSize = 32767; |
| 75 | |
| 76 | // Create buffer to hold the result |
| 77 | buffer.resize(kBufferSize, '\0'); |
| 78 | |
| 79 | // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. |
| 80 | // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. |
| 81 | // Therefore, If the function succeeds, kBufferSize should be larger than char_count. |
| 82 | auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize); |
| 83 | |
| 84 | if (kBufferSize > char_count) { |
| 85 | buffer.resize(char_count); |
| 86 | } else { |
| 87 | // Else either the call was failed, or the buffer wasn't large enough. |
| 88 | // TODO: Understand the reason for failure by calling GetLastError(). |
| 89 | // If it is due to the specified environment variable being found in the environment block, |
| 90 | // GetLastError() returns ERROR_ENVVAR_NOT_FOUND. |
| 91 | // For now, we assume that the environment variable is not found. |
| 92 | buffer.clear(); |
| 93 | } |
| 94 | #else |
| 95 | char* val = getenv(name.c_str()); |
| 96 | buffer = (val == nullptr) ? std::string() : std::string(val); |
| 97 | #endif |
| 98 | T parsed_value; |
| 99 | if (!TryParseStringWithClassicLocale(buffer, parsed_value)) { |
| 100 | OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL); |
| 101 | } |
| 102 | return parsed_value; |
| 103 | } |
| 104 | |
| 105 | template <typename T> |
| 106 | T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) { |
| 107 | const auto parsed = ParseEnvironmentVariable<T>(name); |
| 108 | if (parsed.has_value()) { |
| 109 | return *parsed; |
| 110 | } |
| 111 | |
| 112 | return default_value; |
| 113 | } |
| 114 | |
// True for a rank-0 shape, or for a rank-1 shape whose total element count is exactly one.
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  const bool is_scalar = num_dimensions == 0;
  const bool is_single_element_vector = num_dimensions == 1 && shape_size == 1;
  return is_scalar || is_single_element_vector;
}
| 119 | |