1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
#pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent, and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#ifndef C10_EMBEDDED
#include <ostream>
#endif // C10_EMBEDDED
#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
namespace c10 {
namespace detail {
/// Widens a raw bfloat16 bit pattern to a float.
///
/// A bfloat16 value is the upper half of an IEEE-754 binary32 word, so the
/// 16 input bits are placed in the high half of a 32-bit word (the low half
/// is zero) and that word is reinterpreted as a float. No rounding occurs;
/// every bfloat16 bit pattern maps to exactly one float bit pattern.
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  uint32_t widened = static_cast<uint32_t>(src) << 16;
  float out = 0;
#if defined(USE_ROCM)
  // Strict aliasing calls for memcpy here, but memcpy fails in the HIP
  // environment, so the word is read back through a float pointer instead.
  out = *reinterpret_cast<float*>(&widened);
#else
  std::memcpy(&out, &widened, sizeof(widened));
#endif
  return out;
}
/// Narrows a float to bfloat16 bits by truncation (no rounding).
///
/// Returns the high 16 bits of the float's binary32 representation. The low
/// 16 mantissa bits are discarded, which rounds the magnitude toward zero.
/// Because nothing is rounded, a NaN whose payload sits only in the low 16
/// bits comes out as an infinity bit pattern. Use round_to_nearest_even for a
/// correctly rounded conversion.
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  uint32_t word = 0;
#if defined(USE_ROCM)
  // memcpy would respect the strict aliasing rule, but it fails in the HIP
  // environment, so the float is read through a uint32_t pointer.
  word = *reinterpret_cast<uint32_t*>(&src);
#else
  std::memcpy(&word, &src, sizeof(word));
#endif
  const uint32_t high_half = word >> 16;
  return static_cast<uint16_t>(high_half);
}
/// Converts a float to bfloat16 bits, rounding to nearest with ties to even.
///
/// - Any NaN input (of either sign) maps to the canonical quiet NaN 0x7FC0.
/// - Otherwise the low 16 bits of the binary32 word are rounded away. The
///   bias 0x7FFF plus the lowest kept bit means an exact tie (low half ==
///   0x8000) rounds up only when the kept value is odd.
/// - Finite values that round past the largest bfloat16 carry into the
///   exponent and become infinity. Infinities are preserved.
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
  if (src != src) {
#elif defined(_MSC_VER)
  if (isnan(src)) {
#else
  if (std::isnan(src)) {
#endif
    // NaN must be special-cased: rounding its bits could carry a payload
    // into a different pattern (or clear it), so emit a canonical quiet NaN.
    return UINT16_C(0x7FC0);
  }

  uint32_t bits = 0;
#if defined(USE_ROCM)
  // memcpy fails in the HIP environment, so ROCm keeps the union-based bit
  // reinterpretation that this function has always used there.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  union {
    uint32_t u32;
    float f32;
  } pun;
  pun.f32 = src;
  bits = pun.u32;
#else
  // Reading an inactive union member is undefined behaviour in standard C++;
  // memcpy is the well-defined way to obtain the binary32 bit pattern.
  std::memcpy(&bits, &src, sizeof(bits));
#endif

  // Cannot overflow uint32_t: the largest non-NaN word is 0xFF800000 (-inf)
  // and the bias is at most 0x8000.
  const uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
} // namespace detail
// 16-bit bfloat16 value type. It stores only the raw bit pattern in `x`;
// arithmetic goes through the conversion to float declared below.
// alignas(2) pins alignment to the 2-byte storage (presumably so BFloat16
// buffers share layout with raw uint16_t data — confirm against users).
struct alignas(2) BFloat16 {
// Raw bfloat16 bits: sign (1), exponent (8), mantissa (7).
uint16_t x;
// HIP wants __host__ __device__ tag, CUDA does not
// Defaulted, so a default-initialized BFloat16 leaves `x` indeterminate.
#if defined(USE_ROCM)
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif
// Tag type selecting the raw-bits constructor below, so a bit pattern is not
// confused with a numeric value passed to BFloat16(float).
struct from_bits_t {};
// Convenience factory for the tag: BFloat16(bits, BFloat16::from_bits()).
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
return from_bits_t();
}
// Stores `bits` verbatim as the bfloat16 bit pattern (no conversion).
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
: x(bits) {}
// Numeric conversions to and from float. These are only declared here; the
// definitions are presumably in c10/util/BFloat16-inl.h (included at the end
// of this header). Whether BFloat16(float) rounds or truncates is decided
// there — NOTE(review): confirm before relying on it.
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
// CUDA interop (non-ROCm nvcc builds only): implicit conversion from
// __nv_bfloat16, explicit conversion back to it.
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
// SYCL interop (only when the oneAPI bfloat16 math extension is available):
// implicit conversion from sycl::ext::oneapi::bfloat16, explicit conversion
// back to it.
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
#ifndef C10_EMBEDDED
/// Writes a BFloat16 to a stream by first converting it to float, so it is
/// printed with the stream's float formatting. Returns the same stream to
/// allow chaining. Not available in C10_EMBEDDED builds.
C10_API inline std::ostream& operator<<(
    std::ostream& out,
    const BFloat16& value) {
  return out << static_cast<float>(value);
}
#endif // C10_EMBEDDED
} // namespace c10
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
|