#pragma once
/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting the operands to
/// float32 and performing the operation in float32.
/// Binary configuration remains the same as e5m2:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// The key differences that e5m2fnuz brings are:
/// bias = 16
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
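///
/// As a sketch derived from the layout above (assuming the usual convention
/// that an all-zero exponent field encodes subnormals), a bit pattern with
/// sign s, exponent field e and mantissa field m decodes as:
///   normal    (e != 0): (-1)^s * 2^(e - 16) * (1 + m / 4)
///   subnormal (e == 0): (-1)^s * 2^(-15) * (m / 4)
/// which gives these boundary values:
///   0 00000 01 -> 2^-17 (smallest subnormal)
///   0 00001 00 -> 2^-15 (smallest normal)
///   0 10000 00 -> 1.0
///   0 11111 11 -> 1.75 * 2^15 = 57344 (largest finite value)
///   1 00000 00 -> NaN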
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
 * 8-bit floating-point number in fp8 E5M2FNUZ format, in bit representation.
 */
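/*
 * Worked examples, traced by hand through the conversion below (a sanity
 * sketch for readers, not an exhaustive specification):
 *   1.0f      (0x3F800000) -> 0x40  (0 10000 00)
 *   57344.0f  (0x47600000) -> 0x7F  (0 11111 11, largest finite value)
 *   65536.0f  (0x47800000) -> 0x80  (NaN, input is >= fnuz_max)
 *   0x1p-17f  (0x37000000) -> 0x01  (smallest subnormal)
 *   -0.0f     (0x80000000) -> 0x00  (no negative zero)
 */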
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
  /*
   * Binary representation of 65536.0f, which is the first value not
   * representable in the fp8e5m2fnuz range (i.e. the first value that would
   * overflow into the sign bit, resulting in a NaN):
   * 1 00000 00 - fp8e5m2fnuz
   * 0 10001111 00000000000000000000000 - fp32
   */
  constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23;
  /*
   * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range
   * into denormalized representation.
   * magic number: ((127 - 16) + (23 - 2) + 1) = 133 = 0x85
   */
  constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
  uint32_t f_bits = fp32_to_bits(f);
  uint32_t result = 0u;
  /*
   * Extract the sign of the input number into the high bit of the 32-bit word:
   *
   *      +---+----------------------------------+
   *      | S |0000000 00000000 00000000 00000000|
   *      +---+----------------------------------+
   * Bits  31                 0-30
   */
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  /*
   * Set sign bit to 0
   */
  f_bits ^= sign;
  if (f_bits >= fnuz_max) {
    // Overflow, infinity, or NaN input: return NaN (sign bit set to 1,
    // rest 0s)
    return 0x80;
  }
  if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) {
    // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the
    // number will become subnormal.
    f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
    result = static_cast<uint8_t>(f_bits - denorm_mask);
    if (result == 0) {
      // fnuz types don't have negative zero.
      return 0;
    }
  } else {
    // resulting mantissa is odd
    uint8_t mant_odd = (f_bits >> 21) & 1;
    // update exponent, rounding bias part 1
    f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;
    // rounding bias part 2
    f_bits += mant_odd;
    // take the bits!
    result = static_cast<uint8_t>(f_bits >> 21);
  }
  result |= sign >> 24;
  return result;
}
} // namespace detail
struct alignas(1) Float8_e5m2fnuz {
  uint8_t x;
  // Tag type that selects the constructor taking a raw bit pattern.
  struct from_bits_t {};
  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }
  Float8_e5m2fnuz() = default;
  constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t)
      : x(bits) {}
  inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
  inline C10_HOST_DEVICE operator float() const;
  inline C10_HOST_DEVICE bool isnan() const;
  inline C10_HOST_DEVICE bool isinf() const;
};
C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e5m2fnuz& value) {
  out << (float)value;
  return out;
}
} // namespace c10
#include <c10/util/Float8_e5m2fnuz-inl.h> // IWYU pragma: keep