#pragma once
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
 * Encode an IEEE single-precision value as an fp8 E4M3FNUZ bit pattern
 * (1 sign bit, 4 exponent bits with bias 8, 3 mantissa bits).
 *
 * Returns:
 *  - 0x80 (the only NaN encoding: sign bit set, everything else zero) when
 *    the magnitude bits of `f` are at or above those of 256.0f. This covers
 *    infinities and NaNs as well, since their bit patterns compare larger.
 *  - 0 when the value rounds to zero, whatever its sign, because the format
 *    has no negative zero.
 *  - otherwise the rounded magnitude with the input sign placed in bit 7.
 */
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
  // Mask selecting the fp32 sign bit (bit 31); magnitude lives in bits 0-30.
  constexpr uint32_t kSignMask = UINT32_C(0x80000000);

  // fp32 bits of 256.0f (0 10000111 000...0): the first magnitude that is
  // rejected outright as not representable.
  constexpr uint32_t kOverflowBits = UINT32_C(0x87) << 23;

  // fp32 bits of 2^-7, the smallest magnitude encoded as a normal number.
  constexpr uint32_t kMinNormalBits = UINT32_C(0x78) << 23;

  // Magic addend for the subnormal path; its biased exponent is
  // (127 - 8) + (23 - 3) + 1 = 0x8C.
  constexpr uint32_t kDenormMagic = UINT32_C(0x8C) << 23;

  const uint32_t raw_bits = fp32_to_bits(f);
  const uint32_t sign_bit = raw_bits & kSignMask;
  uint32_t magnitude = raw_bits & ~kSignMask;

  if (magnitude >= kOverflowBits) {
    // NaN -- sign bit set to 1, rest 0s.
    return 0x80;
  }

  uint32_t encoded = 0u;
  if (magnitude >= kMinNormalBits) {
    // Normal range. Bit 20 is the lowest mantissa bit that survives, so it
    // says whether the truncated mantissa is odd; adding it on top of
    // 0x7FFFF makes exact ties round to even. The same addition rebiases
    // the exponent from 127 to 8 (unsigned wrap-around is intended), and a
    // carry out of the mantissa simply bumps the exponent field.
    const uint32_t tie_breaker = (magnitude >> 20) & 1;
    magnitude +=
        (static_cast<uint32_t>(8 - 127) << 23) + 0x7FFFF + tie_breaker;
    encoded = static_cast<uint8_t>(magnitude >> 20);
  } else {
    // Below 2^-7 the result is subnormal. Adding the magic constant as a
    // float lets the floating-point addition do the alignment and rounding;
    // subtracting its bits back leaves the subnormal mantissa in the low
    // byte.
    magnitude = fp32_to_bits(
        fp32_from_bits(magnitude) + fp32_from_bits(kDenormMagic));
    encoded = static_cast<uint8_t>(magnitude - kDenormMagic);
    if (encoded == 0) {
      // fnuz types don't have negative zero.
      return 0;
    }
  }

  // Move the fp32 sign (bit 31) down to bit 7 of the fp8 byte.
  return static_cast<uint8_t>(encoded | (sign_bit >> 24));
}
} // namespace detail
struct alignas(1) Float8_e4m3fnuz {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fnuz() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
/// Streams the value by converting it to float and inserting that float
/// into `out`; returns `out` to allow chaining.
C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e4m3fnuz& value) {
  return out << static_cast<float>(value);
}
} // namespace c10
#include <c10/util/Float8_e4m3fnuz-inl.h> // IWYU pragma: keep