1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
#pragma once
#include <c10/util/floating_point_utils.h>
#include <cstdint>
#if defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp>
#endif
namespace c10::detail {
/*
 * Decode a single FNUZ-style 8-bit float, supplied as its raw byte, into a
 * 32-bit float.
 *
 * The template arguments give the exponent width (we) and fraction width (wm)
 * of the input: <4, 3> selects f8 E4M3FNUZ and <5, 2> selects bf8 E5M2FNUZ.
 * Any other combination is rejected at compile time.
 *
 * Byte 0x00 decodes to zero and byte 0x80 decodes to NaN; every other bit
 * pattern is decoded as a signed finite value (sign in bit 7).
 */
template <uint32_t we, uint32_t wm>
inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
  static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));

  // Field widths of the fp32 result.
  constexpr uint32_t kOutExpBits = 8;
  constexpr uint32_t kOutFracBits = 23;
  // Selects the low wm (fraction) bits of the input byte.
  constexpr uint32_t kFracMask = (1 << wm) - 1;
  // Offset added to the input's biased exponent to get the fp32 biased
  // exponent.
  constexpr uint32_t kExpRebias =
      (1 << (kOutExpBits - 1)) - (1 << (we - 1)) - 1;
  // fp32 bit pattern with an all-ones exponent and a non-zero fraction.
  constexpr uint32_t kNaNBits = 0x7F800001;

  // The two encodings that are not ordinary finite values.
  switch (x) {
    case 0x00:
      return 0;
    case 0x80:
      return fp32_from_bits(kNaNBits);
    default:
      break;
  }

  const uint32_t sign_bit = x >> 7;
  uint32_t biased_exp = (x & 0x7F) >> wm;
  uint32_t frac = x & kFracMask;

  if (biased_exp == 0) {
    // Subnormal input. frac cannot be zero here because 0x00 and 0x80 were
    // already returned above, so counting leading zeros is well defined.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    uint32_t lead_zeros = __clz(frac);
#elif defined(__SYCL_DEVICE_ONLY__)
    uint32_t lead_zeros = sycl::clz(frac);
#elif defined(_MSC_VER)
    // _BitScanReverse yields the index of the highest set bit; XOR with 31
    // turns that index into a leading-zero count for a 32-bit value.
    unsigned long top_bit_index;
    _BitScanReverse(&top_bit_index, (unsigned long)frac);
    uint32_t lead_zeros = (uint32_t)top_bit_index ^ 31;
#else
    uint32_t lead_zeros = __builtin_clz(frac);
#endif
    // Shift that moves the highest set bit of frac up to bit position wm,
    // where the mask below discards it.
    const uint32_t shift = lead_zeros + 1 - (32 - wm);
    frac = (frac << shift) & kFracMask;
    // biased_exp is 0 on entry; this may wrap below zero in unsigned
    // arithmetic and is brought back into range by the re-bias that follows.
    biased_exp = 1 - shift;
  }

  const uint32_t out_exp = biased_exp + kExpRebias;
  const uint32_t out_frac = frac << (kOutFracBits - wm);
  const uint32_t out_bits =
      (sign_bit << 31) | (out_exp << kOutFracBits) | out_frac;
  return fp32_from_bits(out_bits);
}
} // namespace c10::detail
|