microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
223400118478b0d6512ae8491292723cf61c8fe8

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_f16.h

238lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include "onnxruntime_c_api.h"
7#if ORT_API_VERSION >= 16
8
9#include "onnxruntime_float16.h"
10#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
11#include <cuda_bf16.h>
12#endif
13
14namespace Ort {
15namespace Custom {
16
17// MFloat16
18struct MFloat16 : onnxruntime_float16::Float16Impl<MFloat16> {
19 private:
20 constexpr explicit MFloat16(uint16_t v) noexcept { val = v; }
21
22 public:
23 using Base = onnxruntime_float16::Float16Impl<MFloat16>;
24
25 MFloat16() = default;
26
27 constexpr static MFloat16 FromBits(uint16_t v) noexcept { return MFloat16(v); }
28
29 explicit MFloat16(float v) noexcept { val = Base::ToUint16Impl(v); }
30
31 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
32
33 using Base::Abs;
34 using Base::AreZero;
35 using Base::IsFinite;
36 using Base::IsInfinity;
37 using Base::IsNaN;
38 using Base::IsNaNOrZero;
39 using Base::IsNegative;
40 using Base::IsNegativeInfinity;
41 using Base::IsNormal;
42 using Base::IsPositiveInfinity;
43 using Base::IsSubnormal;
44 using Base::Negate;
45
46 explicit operator float() const noexcept { return ToFloat(); }
47
48 using Base::operator==;
49 using Base::operator!=;
50 using Base::operator<;
51};
52
// ORTC_HOST_DEVICE decorates functions that must be callable from both host
// and device code. It expands to `__host__ __device__` when either __CUDACC__
// or __HIPCC__ is defined, and to nothing otherwise.
#if !defined(__CUDACC__) && !defined(__HIPCC__)
#define ORTC_HOST_DEVICE
#else
#define ORTC_HOST_DEVICE __host__ __device__
#endif
58
59// BFloat16
60struct BFloat16 : onnxruntime_float16::BFloat16Impl<BFloat16> {
61 using Base = onnxruntime_float16::BFloat16Impl<BFloat16>;
62
63#if defined(__HIP__)
64 ORTC_HOST_DEVICE BFloat16() = default;
65#else
66 BFloat16() = default;
67#endif
68
69 struct FromBitsT {};
70 static constexpr ORTC_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); }
71 constexpr ORTC_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) noexcept { val = bits; }
72
73 static constexpr ORTC_HOST_DEVICE BFloat16 FromBits(uint16_t bits) noexcept {
74 return BFloat16(bits, FromBits());
75 }
76
77 inline ORTC_HOST_DEVICE BFloat16(float v) noexcept {
78#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
79 val = __bfloat16_as_ushort(__float2bfloat16(v));
80#elif defined(__HIP__)
81 // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
82 if (v != v) { // isnan
83 val = UINT16_C(0x7FC0);
84 } else {
85 union {
86 uint32_t U32;
87 float F32;
88 };
89
90 F32 = v;
91 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
92 val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
93 }
94#else
95
96 // Use C isnan to work both in host and device
97 if (std::isnan(v)) {
98 val = kPositiveQNaNBits;
99 } else {
100 auto get_msb_half = [](float fl) {
101 uint16_t result;
102 if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) {
103 std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
104 } else {
105 std::memcpy(&result, &fl, sizeof(uint16_t));
106 }
107 return result;
108 };
109
110 uint16_t upper_bits = get_msb_half(v);
111 union {
112 uint32_t U32;
113 float F32;
114 };
115 F32 = v;
116 U32 += (upper_bits & 1) + kRoundToNearest;
117 val = get_msb_half(F32);
118 }
119#endif
120 }
121
122 inline ORTC_HOST_DEVICE float ToFloat() const noexcept {
123#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
124 return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
125#elif defined(__HIP__)
126 // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
127 float result = 0;
128 uint32_t tmp = val;
129 tmp <<= 16;
130 float* tempRes = reinterpret_cast<float*>(&tmp);
131 result = *tempRes;
132 return result;
133#else
134
135 if (IsNaNHostDevice()) {
136 return std::numeric_limits<float>::quiet_NaN();
137 }
138
139 float result = 0;
140 char* const first = reinterpret_cast<char*>(&result);
141 if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) {
142 char* const second = first + sizeof(uint16_t);
143 std::memcpy(second, &val, sizeof(uint16_t));
144 } else {
145 std::memcpy(first, &val, sizeof(uint16_t));
146 }
147 return result;
148#endif
149 }
150
151 static const BFloat16 NaN;
152 static const BFloat16 NegativeNaN;
153 static const BFloat16 Infinity;
154 static const BFloat16 NegativeInfinity;
155 static const BFloat16 Epsilon;
156 static const BFloat16 MinValue;
157 static const BFloat16 MaxValue;
158 static const BFloat16 Zero;
159 static const BFloat16 One;
160 static const BFloat16 MinusOne;
161
162 using Base::IsNegative;
163
164 using Base::IsNaN;
165
166 using Base::IsFinite;
167
168 using Base::IsPositiveInfinity;
169
170 using Base::IsNegativeInfinity;
171
172 using Base::IsInfinity;
173
174 using Base::IsNaNOrZero;
175
176 using Base::IsNormal;
177
178 using Base::IsSubnormal;
179
180 using Base::Abs;
181
182 using Base::Negate;
183
184 ORTC_HOST_DEVICE operator float() const noexcept { return ToFloat(); }
185
186#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
187 ORTC_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
188 explicit ORTC_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
189#endif
190
191 ORTC_HOST_DEVICE bool operator==(const BFloat16& rhs) const noexcept {
192 if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
193 // IEEE defines that NaN is not equal to anything, including itself.
194 return false;
195 }
196 return val == rhs.val;
197 }
198
199 ORTC_HOST_DEVICE bool operator!=(const BFloat16& rhs) const noexcept {
200 return !(*this == rhs);
201 }
202
203 ORTC_HOST_DEVICE bool operator<(const BFloat16& rhs) const noexcept {
204 if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
205 // IEEE defines that NaN is unordered with respect to everything, including itself.
206 return false;
207 }
208
209 const bool left_is_negative = IsNegativeHostDevice();
210 if (left_is_negative != rhs.IsNegativeHostDevice()) {
211 // When the signs of left and right differ, we know that left is less than right if it is
212 // the negative value. The exception to this is if both values are zero, in which case IEEE
213 // says they should be equal, even if the signs differ.
214 return left_is_negative && !AreZeroHostDevice(*this, rhs);
215 }
216 return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
217 }
218
219 ORTC_HOST_DEVICE bool IsNegativeHostDevice() const noexcept {
220 return (val & kSignMask) != 0;
221 }
222
223 ORTC_HOST_DEVICE bool IsNaNHostDevice() const noexcept {
224 return static_cast<uint16_t>(val & ~kSignMask) > kPositiveInfinityBits;
225 }
226
227 ORTC_HOST_DEVICE static bool AreZeroHostDevice(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
228 // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
229 // for two values by or'ing the private bits together and stripping the sign. They are both zero,
230 // and therefore equivalent, if the resulting value is still zero.
231 return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
232 }
233};
234
235} // namespace Custom
236} // namespace Ort
237
238#endif