1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
// clang-format off
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/irange.h>
// clang-format on
#include <gtest/gtest.h>
namespace {
// Assembles a 32-bit word as (((sign << 8) | exponent) << 23) | fraction and
// reinterprets it as a float. The arguments are combined without masking, so
// a caller may also pass a complete bit pattern through `fraction` while
// leaving `sign` and `exponent` at zero.
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
  const uint32_t word = (((sign << 8) | exponent) << 23) | fraction;
  float value = 0.0f;
  // Copy the raw bits into the float object instead of casting pointers.
  std::memcpy(&value, &word, sizeof(value));
  return value;
}
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
  // Round-trips the values 1.25, 2.25, ..., 100.25 through
  // c10::detail::bits_from_f32 / c10::detail::f32_from_bits and checks how far
  // each result drifted from its input.
  constexpr int kNumSamples = 100;
  // The relative error must not exceed 1/(2^7) since BFloat16 has 7 bits of
  // mantissa.
  constexpr double kMaxRelativeError = 1.0 / 128;
  for (const auto idx : c10::irange(kNumSamples)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    const float original = idx + 1.25;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    c10::BFloat16 converted;
    converted.x = c10::detail::bits_from_f32(original);
    const float restored = c10::detail::f32_from_bits(converted.x);
    EXPECT_LE(fabs(restored - original) / original, kMaxRelativeError);
  }
}
TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
  // Same sweep over 1.25, 2.25, ..., 100.25 as the plain conversion test, but
  // the 16-bit pattern is produced by c10::detail::round_to_nearest_even
  // before being widened back with c10::detail::f32_from_bits.
  constexpr int kSampleCount = 100;
  // BFloat16 keeps 7 mantissa bits, so the relative error is bounded by
  // 1/(2^7).
  constexpr double kRelativeErrorBound = 1.0 / 128;
  for (const auto n : c10::irange(kSampleCount)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
    const float source = n + 1.25;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    c10::BFloat16 rounded;
    rounded.x = c10::detail::round_to_nearest_even(source);
    const float widened = c10::detail::f32_from_bits(rounded.x);
    EXPECT_LE(fabs(widened - source) / source, kRelativeErrorBound);
  }
}
TEST(BFloat16Conversion, NaN) {
  // Exponent field all ones plus a non-zero fraction: sanity-check that the
  // helper really produced a NaN before converting it.
  const float nan_input = float_from_bytes(0, 0xFF, 0x7FFFFF);
  EXPECT_TRUE(std::isnan(nan_input));

  // Constructing a BFloat16 from it and widening the stored bits back to
  // float must still report NaN.
  const c10::BFloat16 converted(nan_input);
  const float round_tripped = c10::detail::f32_from_bits(converted.x);
  EXPECT_TRUE(std::isnan(round_tripped));
}
TEST(BFloat16Conversion, Inf) {
  // Exponent field all ones with a zero fraction: confirm the helper built an
  // infinity before converting it.
  const float inf_input = float_from_bytes(0, 0xFF, 0);
  EXPECT_TRUE(std::isinf(inf_input));

  // The value must still be infinite after passing through BFloat16.
  const c10::BFloat16 converted(inf_input);
  const float round_tripped = c10::detail::f32_from_bits(converted.x);
  EXPECT_TRUE(std::isinf(round_tripped));
}
TEST(BFloat16Conversion, SmallestDenormal) {
  // denorm_min() is the smallest non-zero subnormal float.
  const float tiny = std::numeric_limits<float>::denorm_min();
  // Convert to BFloat16, widen the stored bits back to float, and compare
  // with the input using gtest's approximate float comparison.
  const c10::BFloat16 converted(tiny);
  const float round_tripped = c10::detail::f32_from_bits(converted.x);
  EXPECT_FLOAT_EQ(tiny, round_tripped);
}
TEST(BFloat16Math, Addition) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after addition, we should have no loss in precision.

  // Operand bits, passed field by field (word 0x40480000):
  //   S | Exponent | Mantissa
  //   0 | 10000000 | 10010000000000000000000 = 3.125
  const float operand = float_from_bytes(0, 0x80, 0x480000);

  // Expected bits (word 0x40c80000):
  //   S | Exponent | Mantissa
  //   0 | 10000001 | 10010000000000000000000 = 6.25
  const float expected = float_from_bytes(0, 0x81, 0x480000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 addend;
  addend.x = c10::detail::bits_from_f32(operand);
  const c10::BFloat16 sum = addend + addend;

  const float res = c10::detail::f32_from_bits(sum.x);
  EXPECT_EQ(res, expected);
}
TEST(BFloat16Math, Subtraction) {
  // This test verifies that if only first 7 bits of float's mantissa are
  // changed after subtraction, we should have no loss in precision.

  // Operand bits, passed field by field (word 0x40f40000):
  //   S | Exponent | Mantissa
  //   0 | 10000001 | 11101000000000000000000 = 7.625
  const float operand = float_from_bytes(0, 0x81, 0x740000);

  // Expected bits (word 0x40280000):
  //   S | Exponent | Mantissa
  //   0 | 10000000 | 01010000000000000000000 = 2.625
  const float expected = float_from_bytes(0, 0x80, 0x280000);

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  c10::BFloat16 minuend;
  minuend.x = c10::detail::bits_from_f32(operand);
  const c10::BFloat16 difference = minuend - 5;

  const float res = c10::detail::f32_from_bits(difference.x);
  EXPECT_EQ(res, expected);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, NextAfterZero) {
  // For every pairing of zero and its negation, std::nextafter(from, to) is
  // expected to return `to` bit for bit.
  const c10::BFloat16 zero{0};
  const c10::BFloat16 negated_zero = -zero;

  const auto expect_nextafter =
      [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
        const c10::BFloat16 actual = std::nextafter(from, to);
        // Check for bitwise equality!
        ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
      };

  expect_nextafter(zero, zero, /*expected=*/zero);
  expect_nextafter(zero, negated_zero, /*expected=*/negated_zero);
  expect_nextafter(negated_zero, zero, /*expected=*/zero);
  expect_nextafter(negated_zero, negated_zero, /*expected=*/negated_zero);
}
// Reinterprets the 32-bit pattern `bytes` as a float and returns it.
float BinaryToFloat(uint32_t bytes) {
  float value = 0.0f;
  // Copy the raw bits into the float object instead of casting pointers.
  std::memcpy(&value, &bytes, sizeof(float));
  return value;
}
// One case for the parameterized round-to-nearest-even test.
struct BFloat16TestParam {
// Raw 32-bit pattern; the test reinterprets it as a float via BinaryToFloat.
uint32_t input;
// 16-bit result that c10::detail::round_to_nearest_even is expected to
// return for that float.
uint16_t rne;
};
class BFloat16Test : public ::testing::Test,
public ::testing::WithParamInterface<BFloat16TestParam> {};
TEST_P(BFloat16Test, BFloat16RNETest) {
  // Reinterpret the case's 32-bit pattern as a float, round it, and compare
  // the 16-bit result with the expectation stored in the same case.
  const BFloat16TestParam& param = GetParam();
  const float value = BinaryToFloat(param.input);
  const uint16_t rounded = c10::detail::round_to_nearest_even(value);
  EXPECT_EQ(param.rne, rounded);
}
// Cases for BFloat16RNETest as {float bit pattern, expected 16-bit result}.
// A low half of exactly 0x8000 is a tie between the two neighbouring 16-bit
// values; the expectations pick the neighbour whose last bit is 0.
// INSTANTIATE_TEST_SUITE_P replaces the deprecated INSTANTIATE_TEST_CASE_P
// spelling; the instantiation name is kept so full test names do not change.
INSTANTIATE_TEST_SUITE_P(
    BFloat16Test_Instantiation,
    BFloat16Test,
    ::testing::Values(
        // Tie with an even upper half: stays at 0x3F84.
        BFloat16TestParam{0x3F848000, 0x3F84},
        // Just above the tie: goes up to 0x3F85.
        BFloat16TestParam{0x3F848010, 0x3F85},
        // Low half is zero: upper half is returned unchanged.
        BFloat16TestParam{0x3F850000, 0x3F85},
        // Tie with an odd upper half: goes up to the even 0x3F86.
        BFloat16TestParam{0x3F858000, 0x3F86},
        // Tie with an odd upper half where the increment carries to 0x4000.
        BFloat16TestParam{0x3FFF8000, 0x4000}));
} // namespace
|