1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
|
#pragma once
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// bias = 15
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/util/Half.h>
namespace c10 {
namespace detail {
/*
* Convert an 8-bit floating-point number in fp8 E5M2 format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, returned as a float value.
*
* @note This function itself performs only integer bit manipulation; the
* fp16 -> fp32 step is delegated to fp16_ieee_to_fp32_value.
*/
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
  /*
   * Widen the fp8 E5M2 byte to 16 bits and move it into the upper byte,
   * zero-filling the lower byte:
   *
   * +---+-----+----+-----------+
   * | S |EEEEE| MM | 0000 0000 |
   * +---+-----+----+-----------+
   * Bits 15  10-14  8-9    0-7
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa,
   * 0 - zero bits.
   *
   * The resulting 16-bit word is then decoded as an IEEE half-precision
   * number by fp16_ieee_to_fp32_value.
   */
  const uint16_t fp16_bits =
      static_cast<uint16_t>(static_cast<uint16_t>(input) << 8);
  return fp16_ieee_to_fp32_value(fp16_bits);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to an
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
  // Mask selecting the fp32 sign bit (bit 31).
  constexpr uint32_t kSignMask = UINT32_C(0x80000000);

  // Bit pattern of fp32 +infinity: 0 11111111 00000000000000000000000
  constexpr uint32_t kFp32Inf = UINT32_C(255) << 23;

  // Bit pattern of 65536.0f (0 10001111 00000000000000000000000), the first
  // magnitude that lies outside the fp8e5m2 range.
  constexpr uint32_t kOverflowThreshold = UINT32_C(143) << 23;

  // Bit pattern of 2^(-14), the smallest fp8e5m2 normal number.
  constexpr uint32_t kSmallestNormal = UINT32_C(113) << 23;

  // Magic addend used to build the denormal encoding of inputs below the
  // normal range; exponent field is ((127 - 15) + (23 - 2) + 1) = 134.
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;

  const uint32_t raw_bits = fp32_to_bits(f);

  // Split the input into its sign (kept in bit 31) and its magnitude (sign
  // bit cleared), and pre-compute the sign as it appears in the fp8 byte
  // (bit 7).
  const uint32_t sign = raw_bits & kSignMask;
  const uint8_t sign_byte = static_cast<uint8_t>(sign >> 24);
  uint32_t magnitude = raw_bits ^ sign;

  // Out-of-range inputs: anything strictly above the fp32 infinity pattern is
  // a NaN and becomes 0x7F (all exponent and mantissa bits set); everything
  // else at or above the threshold becomes 0x7C (exponent all ones, mantissa
  // zero).
  if (magnitude >= kOverflowThreshold) {
    const uint8_t special =
        magnitude > kFp32Inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
    return static_cast<uint8_t>(special | sign_byte);
  }

  // Inputs below the smallest fp8e5m2 normal: add the magic constant as a
  // float so the sum's low bits hold the denormal encoding, then strip the
  // magic constant's bit pattern again.
  if (magnitude < kSmallestNormal) {
    const uint32_t biased_bits = fp32_to_bits(
        fp32_from_bits(magnitude) + fp32_from_bits(kDenormMagic));
    const uint8_t denormal = static_cast<uint8_t>(biased_bits - kDenormMagic);
    return static_cast<uint8_t>(denormal | sign_byte);
  }

  // Normal range. Remember whether the mantissa that will be kept is odd
  // (its lowest surviving bit is bit 21 of the fp32 word).
  const uint32_t kept_mantissa_is_odd = (magnitude >> 21) & 1;

  // Re-bias the exponent from 127 to 15 and add the first part of the
  // rounding bias (just under half of the discarded range)...
  magnitude += (static_cast<uint32_t>(15 - 127) << 23) + 0xFFFFF;

  // ...then the second part, which rounds exact ties up only when the kept
  // mantissa is odd.
  magnitude += kept_mantissa_is_odd;

  // Drop the 21 discarded mantissa bits; the low byte is the fp8 magnitude.
  const uint8_t normal = static_cast<uint8_t>(magnitude >> 21);
  return static_cast<uint8_t>(normal | sign_byte);
}
} // namespace detail
struct alignas(1) Float8_e5m2 {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2() = default;
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {}
inline C10_HOST_DEVICE Float8_e5m2(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
/// Streams a Float8_e5m2 by converting it to float and writing that float to
/// `out`; returns the stream to allow chaining.
C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e5m2& value) {
  return out << static_cast<float>(value);
}
} // namespace c10
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
|