File: Float8_fnuz_cvt.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (64 lines) | stat: -rw-r--r-- 1,732 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma once

#include <c10/util/floating_point_utils.h>

#include <cstdint>

#if defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp>
#endif

namespace c10::detail {

/*
 * Decode an 8-bit FNUZ floating-point value, given as its raw bit pattern,
 * into a 32-bit IEEE-754 float. The template parameters are the exponent and
 * mantissa field widths of the 8-bit format, and select it:
 *   we = 4, wm = 3 : f8  E4M3FNUZ (exponent bias 8)
 *   we = 5, wm = 2 : bf8 E5M2FNUZ (exponent bias 16)
 *
 * Special encodings:
 *   0x00 -> +0.0f (the only zero pattern)
 *   0x80 -> the NaN whose fp32 bit pattern is 0x7F800001 (the only NaN)
 * Every other pattern, including one with an all-ones exponent field, decodes
 * to a finite value. Subnormal inputs (exponent field 0) are renormalized into
 * an fp32 normal number; the conversion performs no rounding.
 */
template <uint32_t we, uint32_t wm>
inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
  static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));

  constexpr uint32_t kF32ExpWidth = 8;
  constexpr uint32_t kF32ManWidth = 23;
  constexpr uint32_t kManMask = (1u << wm) - 1;
  // Adding this to an 8-bit biased exponent yields the fp32 biased exponent:
  // it is the fp32 bias (127) minus the 8-bit format's bias (2^(we-1)).
  constexpr uint32_t kRebias =
      (1u << (kF32ExpWidth - 1)) - (1u << (we - 1)) - 1;
  constexpr uint32_t kNaNBits = 0x7F800001;

  if (x == 0) {
    return 0;
  }
  if (x == 0x80) {
    // FNUZ formats reuse the "negative zero" pattern as their single NaN.
    return fp32_from_bits(kNaNBits);
  }

  const uint32_t sign_bit = static_cast<uint32_t>(x) >> 7;
  uint32_t exp_bits = (x & 0x7F) >> wm;
  uint32_t man_bits = x & kManMask;

  if (exp_bits == 0) {
    // Subnormal input. man_bits is nonzero here because the only patterns
    // with a zero exponent and zero mantissa (0x00, 0x80) returned above.
    // Count leading zeros of the 32-bit mantissa with the platform intrinsic.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    const uint32_t lead_zeros = __clz(man_bits);
#elif defined(__SYCL_DEVICE_ONLY__)
    const uint32_t lead_zeros = sycl::clz(man_bits);
#elif defined(_MSC_VER)
    unsigned long msb_index;
    _BitScanReverse(&msb_index, static_cast<unsigned long>(man_bits));
    // For an index in [0, 31], 31 - index == index ^ 31.
    const uint32_t lead_zeros = static_cast<uint32_t>(msb_index) ^ 31;
#else
    const uint32_t lead_zeros = __builtin_clz(man_bits);
#endif
    // Move the leading 1 just above the wm-bit field, where it becomes the
    // implicit bit, and mask it off.
    const uint32_t shift = lead_zeros - (31 - wm);
    man_bits = (man_bits << shift) & kManMask;
    // The effective 8-bit exponent is now 1 - shift (<= 0). The unsigned
    // wrap-around this produces is cancelled when kRebias is added below.
    exp_bits = 1u - shift;
  }

  exp_bits += kRebias;
  const uint32_t f32_bits = (sign_bit << 31) | (exp_bits << kF32ManWidth) |
      (man_bits << (kF32ManWidth - wm));
  return fp32_from_bits(f32_bits);
}

} // namespace c10::detail