1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
|
#pragma once
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// bias = 7
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cmath>
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <climits>
#include <iostream>
namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E4M3FN number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)input << 24;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
* Renorm shift is the number of bits to shift mantissa left to make the
* half-precision number normalized. If the initial number is normalized, some
* of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case
* renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(__SYCL_DEVICE_ONLY__)
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#elif defined(_MSC_VER)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#endif
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
/*
* Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
* the addition overflows it into bit 31, and the subsequent shift turns the
* high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
* is Nan, 0x00000000 otherwise
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
/*
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
* 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
* 0x00000000 otherwise
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
* bits of the 23-bit mantissa of IEEE single-precision number.
* 3. Add 0x78 to the exponent (starting at bit 23) to compensate the
* different in exponent bias (0x7F for single-precision number less 0x07
* for fp8e4m3fn number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x78, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
/*
* Binary representation of 480.0f, which is the first value
* not representable in fp8e4m3fn range:
* 0 1111 111 - fp8e4m3fn
* 0 10000111 11100000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
* into denorm representation
* magic number: ((127 - 7) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = 0x7f;
} else {
if (f_bits < (UINT32_C(121) << 23)) {
// Input number is smaller than 2^(-6), which is the smallest
// fp8e4m3fn normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
} // namespace detail
struct alignas(1) Float8_e4m3fn {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fn() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e4m3fn& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
|