microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
8bc8e43da10ff92869ffd485c5b9f5a497a229a6

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/ortx_common.h

118 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdlib>
#include <limits>
#include <locale>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>

#include "string_utils.h"
#ifdef _WIN32
#include <Windows.h>
#endif

// Evaluates `expr` exactly once; if the resulting status compares unequal to
// nullptr, returns that status from the enclosing function.
// The do { ... } while (0) wrapper lets the macro be used as a single statement,
// e.g. inside an unbraced if/else.
#define ORTX_RETURN_IF_ERROR(expr)                             \
  do {                                                         \
    if (auto ortx_status = (expr); ortx_status != nullptr) {   \
      return ortx_status;                                      \
    }                                                          \
  } while (0)

/// Parses `str` into `value` using the classic ("C") locale, independent of the
/// process-wide locale.
///
/// The whole string must be consumed: leading whitespace and trailing characters
/// are rejected. For unsigned integral types a leading '-' is rejected, because
/// stream extraction would otherwise silently wrap negative input.
/// `signed char` / `unsigned char` (e.g. int8_t / uint8_t) are parsed as numbers,
/// not as single characters; plain `char` still reads one character.
///
/// @param str   text to parse.
/// @param value receives the parsed value; left untouched when parsing fails.
/// @return true on success, false if `str` is not a complete, in-range T.
template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
  if constexpr (std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>) {
    // operator>> extracts a single character for these types ("12" would fail and
    // "5" would yield 53), so parse through a wider integer and range-check it.
    using Wide = std::conditional_t<std::is_signed_v<T>, int, unsigned int>;
    Wide wide_value{};
    if (!TryParseStringWithClassicLocale(str, wide_value)) {
      return false;
    }

    // Parenthesized calls keep <Windows.h> min/max macros from expanding here.
    bool in_range = wide_value <= static_cast<Wide>((std::numeric_limits<T>::max)());
    if constexpr (std::is_signed_v<T>) {
      in_range = in_range && wide_value >= static_cast<Wide>((std::numeric_limits<T>::min)());
    }
    if (!in_range) {
      return false;
    }

    value = static_cast<T>(wide_value);
    return true;
  } else {
    if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
      // if T is unsigned integral type, reject negative values which will wrap
      if (!str.empty() && str[0] == '-') {
        return false;
      }
    }

    // don't allow leading whitespace
    if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
      return false;
    }

    std::istringstream is{std::string{str}};
    is.imbue(std::locale::classic());
    T parsed_value{};

    const bool parse_successful =
        is >> parsed_value &&
        is.get() == std::istringstream::traits_type::eof();  // don't allow trailing characters
    if (!parse_successful) {
      return false;
    }

    value = std::move(parsed_value);
    return true;
  }
}

/// std::string overload: the input is copied verbatim (no trimming, no locale
/// handling), so this overload always succeeds.
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
  value.assign(str.data(), str.size());
  return true;
}

/// bool overload. Accepts exactly "1", "True" or "true" as true and "0", "False"
/// or "false" as false (case-sensitive, no surrounding whitespace).
/// @return false, leaving `value` untouched, for any other input.
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
  const bool means_true = str == "1" || str == "True" || str == "true";
  const bool means_false = str == "0" || str == "False" || str == "false";
  if (!(means_true || means_false)) {
    return false;
  }

  value = means_true;
  return true;
}

70template <typename T>
71std::optional<T> ParseEnvironmentVariable(const std::string& name) {
72 std::string buffer;
73#ifdef _WIN32
74 constexpr size_t kBufferSize = 32767;
75
76 // Create buffer to hold the result
77 buffer.resize(kBufferSize, '\0');
78
79 // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
80 // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
81 // Therefore, If the function succeeds, kBufferSize should be larger than char_count.
82 auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize);
83
84 if (kBufferSize > char_count) {
85 buffer.resize(char_count);
86 } else {
87 // Else either the call was failed, or the buffer wasn't large enough.
88 // TODO: Understand the reason for failure by calling GetLastError().
89 // If it is due to the specified environment variable being found in the environment block,
90 // GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
91 // For now, we assume that the environment variable is not found.
92 buffer.clear();
93 }
94#else
95 char* val = getenv(name.c_str());
96 buffer = (val == nullptr) ? std::string() : std::string(val);
97#endif
98 T parsed_value;
99 if (!TryParseStringWithClassicLocale(buffer, parsed_value)) {
100 OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL);
101 }
102 return parsed_value;
103}
104
105template <typename T>
106T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
107 const auto parsed = ParseEnvironmentVariable<T>(name);
108 if (parsed.has_value()) {
109 return *parsed;
110 }
111
112 return default_value;
113}
114
/// Returns true when `num_dimensions` is 0, or when `num_dimensions` is 1 and
/// `shape_size` is 1 — i.e. a scalar or a single-element vector.
/// @param num_dimensions number of dimensions (rank).
/// @param shape_size     element count (presumably the product of the dims — confirm with callers).
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  switch (num_dimensions) {
    case 0:
      return true;
    case 1:
      return shape_size == 1;
    default:
      return false;
  }
}
