microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
leca/gqa

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

base/ortx_common.h

117 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdlib>
#include <locale>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include "string_utils.h"
#ifdef _WIN32
#include <Windows.h>
#endif
12
// Evaluates `expr` exactly once. If the resulting status compares unequal to
// nullptr (i.e. an error was reported), that status is returned from the
// enclosing function; otherwise execution continues after the macro.
// The do/while(0) wrapper makes the macro behave like a single statement.
#define ORTX_RETURN_IF_ERROR(expr)                               \
  do {                                                           \
    if (auto ortx_status_ = (expr); ortx_status_ != nullptr) {   \
      return ortx_status_;                                       \
    }                                                            \
  } while (0)
20
/// Parses `str` into `value` using the classic ("C") locale.
///
/// The entire string must be consumed: leading whitespace and any trailing
/// characters both make the parse fail. For unsigned integral T, a leading
/// '-' is rejected up front because stream extraction would otherwise wrap it
/// to a large positive value.
///
/// @tparam T     A type with a stream extraction operator (operator>>).
/// @param str    Text to parse.
/// @param value  Receives the parsed value; left untouched on failure.
/// @return true if the whole string parsed as a T, false otherwise.
template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
  const std::locale& classic = std::locale::classic();
  const bool has_first = !str.empty();

  if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
    if (has_first && str.front() == '-') {
      return false;
    }
  }

  // operator>> would silently skip leading whitespace; we want it rejected.
  if (has_first && std::isspace(str.front(), classic)) {
    return false;
  }

  std::istringstream stream{std::string{str}};
  stream.imbue(classic);

  T result{};
  if (!(stream >> result)) {
    return false;
  }
  // Any character left after the extracted value counts as trailing junk.
  if (stream.get() != std::istringstream::traits_type::eof()) {
    return false;
  }

  value = std::move(result);
  return true;
}
49
/// Overload for std::string: copies `str` verbatim into `value` with no
/// trimming or validation, so it always succeeds.
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
  value.assign(str.data(), str.size());
  return true;
}
54
/// Overload for bool. Accepts exactly "1", "true", "True" (-> true) and
/// "0", "false", "False" (-> false); matching is case-sensitive and no other
/// spellings are recognized.
/// @return true and sets `value` on a recognized token; false otherwise,
///         leaving `value` unchanged.
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
  const bool is_true = str == "1" || str == "true" || str == "True";
  const bool is_false = str == "0" || str == "false" || str == "False";
  if (!is_true && !is_false) {
    return false;
  }
  value = is_true;
  return true;
}
68
69template <typename T>
70std::optional<T> ParseEnvironmentVariable(const std::string& name) {
71 std::string buffer;
72#ifdef _WIN32
73 constexpr size_t kBufferSize = 32767;
74
75 // Create buffer to hold the result
76 buffer.resize(kBufferSize, '\0');
77
78 // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
79 // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
80 // Therefore, If the function succeeds, kBufferSize should be larger than char_count.
81 auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize);
82
83 if (kBufferSize > char_count) {
84 buffer.resize(char_count);
85 } else {
86 // Else either the call was failed, or the buffer wasn't large enough.
87 // TODO: Understand the reason for failure by calling GetLastError().
88 // If it is due to the specified environment variable being found in the environment block,
89 // GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
90 // For now, we assume that the environment variable is not found.
91 buffer.clear();
92 }
93#else
94 char* val = getenv(name.c_str());
95 buffer = (val == nullptr) ? std::string() : std::string(val);
96#endif
97 T parsed_value;
98 if (!TryParseStringWithClassicLocale(buffer, parsed_value)) {
99 OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL);
100 }
101 return parsed_value;
102}
103
104template <typename T>
105T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
106 const auto parsed = ParseEnvironmentVariable<T>(name);
107 if (parsed.has_value()) {
108 return *parsed;
109 }
110
111 return default_value;
112}
113
/// Returns true when `num_dimensions` is 0, or when `num_dimensions` is 1 and
/// `shape_size` is 1. `shape_size` is ignored for every other rank.
/// NOTE(review): presumably used to accept a scalar or single-element 1-D
/// tensor, with shape_size being the size of that one dimension; confirm against callers.
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  switch (num_dimensions) {
    case 0:
      return true;
    case 1:
      return shape_size == 1;
    default:
      return false;
  }
}
118