File: Half_test.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (142 lines) | stat: -rw-r--r-- 3,925 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <cmath>
#include <limits>
#include <vector>

#include <c10/util/Half.h>
#include <c10/util/floating_point_utils.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>

namespace {

// Reference implementation of IEEE 754 binary16 -> binary32 conversion,
// written with plain integer bit manipulation so it can serve as an oracle
// for c10::Half.
//
// h: raw binary16 bit pattern (1 sign, 5 exponent, 10 fraction bits).
// Returns the float with the same value. Zeros, subnormals, normals and
// infinities are converted exactly (sign included). Every NaN is mapped to
// the positive NaN with bit pattern 0x7fffffff, so the NaN sign and payload
// are not preserved.
float halfbits2float(unsigned short h) {
  unsigned sign = ((h >> 15) & 1);
  unsigned exponent = ((h >> 10) & 0x1f);
  // Align the 10-bit half fraction with the top of the 23-bit float fraction.
  unsigned mantissa = ((h & 0x3ff) << 13);

  if (exponent == 0x1f) { /* NaN or Inf */
    // NaN -> canonical positive NaN; Inf keeps its sign and a zero fraction.
    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
    exponent = 0xff;
  } else if (!exponent) { /* Denorm or Zero */
    if (mantissa) {
      // Half subnormals are normal floats: shift the fraction left until its
      // leading 1 reaches the implicit-bit position (bit 23), lowering the
      // exponent by one per shift. 0x71 minus the shift count is the
      // resulting biased float exponent.
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      unsigned int msb;
      exponent = 0x71;
      do {
        msb = (mantissa & 0x400000);
        mantissa <<= 1; /* normalize */
        --exponent;
      } while (!msb);
      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
    }
  } else {
    // Rebias the exponent: float bias 127 - half bias 15 = 0x70.
    exponent += 0x70;
  }

  unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;

  return c10::detail::fp32_from_bits(result_bit);
}

// Reference implementation of IEEE 754 binary32 -> binary16 conversion with
// round-to-nearest, ties-to-even, written with integer bit manipulation so it
// can serve as an oracle for c10::Half.
//
// src: value to convert.
// Returns the binary16 bit pattern:
//   * any NaN (either sign)               -> 0x7fff
//   * |src| >= 65520 (bits >= 0x477ff000) -> signed infinity
//   * |src| <= 2^-25 (bits <= 0x33000000) -> signed zero
//   * otherwise the correctly rounded normal or subnormal half.
unsigned short float2halfbits(float src) {
  unsigned x = c10::detail::fp32_to_bits(src);

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  unsigned u = (x & 0x7fffffff), shift = 0;

  // Get rid of +NaN/-NaN case first.
  if (u > 0x7f800000) {
    return 0x7fffU;
  }

  unsigned sign = ((x >> 16) & 0x8000);

  // Get rid of +Inf/-Inf, +0/-0.
  if (u > 0x477fefff) {
    return sign | 0x7c00U;
  }
  if (u < 0x33000001) {
    return (sign | 0x0000);
  }

  unsigned exponent = ((u >> 23) & 0xff);
  unsigned mantissa = (u & 0x7fffff);

  if (exponent > 0x70) {
    // Normal half: drop 13 fraction bits and rebias (127 - 15 = 0x70).
    shift = 13;
    exponent -= 0x70;
  } else {
    // Subnormal half (unit 2^-24): make the implicit 1 explicit and shift so
    // that bit 0 of the result weighs 2^-24 (0x7e = 126).
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }
  unsigned lsb = (1 << shift);
  unsigned lsb_s1 = (lsb >> 1); // exactly half an ulp of the result
  unsigned lsb_m1 = (lsb - 1);  // mask of the bits being discarded

  // Round to nearest even.
  unsigned remainder = (mantissa & lsb_m1);
  mantissa >>= shift;
  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
    ++mantissa;
    if (!(mantissa & 0x3ff)) {
      // Rounding carried out of the 10-bit fraction: move to the next
      // binade (this also turns the largest subnormal into the smallest
      // normal).
      ++exponent;
      mantissa = 0;
    }
  }

  return (sign | (exponent << 10) | mantissa);
}

// Cross-checks the reference helpers above against c10's portable
// (non-intrinsic) conversion routines on bit patterns covering zeros,
// subnormals, normals, the finite extremes and infinities. NaNs are left out
// because NaN never compares equal to itself.
// NOTE: "Porable" (sic) is kept so existing --gtest_filter patterns match.
TEST(HalfConversionTest, TestPorableConversion) {
  const std::vector<uint16_t> inputs = {
      0x0000, // +0
      0x8000, // -0
      0x0001, // smallest positive subnormal
      0x8001, // smallest negative subnormal
      0x03ff, // largest subnormal
      0x0400, // smallest positive normal
      0x3c00, // 1.0
      0x7bff, // 0111 1011 1111 1111: largest finite (65504)
      0xfbff, // 1111 1011 1111 1111: lowest finite (-65504)
      0x7c00, // +inf
      0xfc00, // -inf
  };
  for (auto x : inputs) {
    auto target = c10::detail::fp16_ieee_to_fp32_value(x);
    EXPECT_EQ(halfbits2float(x), target)
        << "Test failed for uint16 to float " << x << "\n";
    EXPECT_EQ(
        float2halfbits(target), c10::detail::fp16_ieee_from_fp32_value(target))
        << "Test failed for float to uint16 " << target << "\n";
  }
}

// Checks c10::Half -> float against halfbits2float for every bit pattern.
TEST(HalfConversion, TestNativeConversionToFloat) {
  // There are only 2**16 possible values, so test them all
  for (auto i : c10::irange(std::numeric_limits<uint16_t>::max() + 1)) {
    const auto bits = static_cast<uint16_t>(i);
    const float expected = halfbits2float(bits);
    const float actual =
        static_cast<float>(c10::Half(bits, c10::Half::from_bits()));
    // NaNs are not equal to each other; only require both sides to be NaN.
    if (std::isnan(expected) && std::isnan(actual)) {
      continue;
    }
    EXPECT_EQ(expected, actual) << "Conversion error using " << bits;
  }
}

// Checks float -> c10::Half against float2halfbits. The inputs are:
//   * every exactly representable half;
//   * every rounding tie between adjacent finite halves, plus the floats just
//     either side of each tie, for both signs;
//   * the overflow boundary and values outside the half range.
TEST(HalfConversion, TestNativeConversionToHalf) {
  auto check_conversion = [](float f) {
    auto h = c10::Half(f);
    auto h_bits = float2halfbits(f);
    // NaNs are not equal to each other, just check that half is NaN
    if (std::isnan(f)) {
      EXPECT_TRUE(std::isnan(static_cast<float>(h)))
          << "Conversion error using " << f;
    } else {
      EXPECT_EQ(h.x, h_bits) << "Conversion error using " << f;
    }
  };

  // Exactly representable values: no rounding takes place.
  for (auto x : c10::irange(std::numeric_limits<uint16_t>::max() + 1)) {
    check_conversion(halfbits2float(static_cast<uint16_t>(x)));
  }

  // Values that must be rounded. Adjacent halves a < b need at most 11
  // significant bits, so their midpoint is exact in float. The midpoint must
  // round to even; its float neighbours must round down/up respectively.
  // 0x7bff (the largest finite half) is handled separately below.
  for (auto lo : c10::irange(0x7bff)) {
    const float a = halfbits2float(static_cast<uint16_t>(lo));
    const float b = halfbits2float(static_cast<uint16_t>(lo + 1));
    const float mid = a + (b - a) / 2;
    for (const float v :
         {std::nextafter(mid, 0.0f), mid, std::nextafter(mid, b)}) {
      check_conversion(v);
      check_conversion(-v);
    }
  }

  // Overflow boundary: 65520 lies halfway between the largest finite half
  // (65504, odd fraction) and 65536, so it rounds to infinity, while the
  // float just below it rounds to 65504.
  const float overflow_tie = 65520.0f;
  for (const float v : {std::nextafter(overflow_tie, 0.0f), overflow_tie}) {
    check_conversion(v);
    check_conversion(-v);
  }

  // Magnitudes beyond the largest finite half.
  check_conversion(std::numeric_limits<float>::max());
  check_conversion(std::numeric_limits<float>::lowest());
  // Magnitudes far below the smallest half subnormal (2^-24).
  check_conversion(std::numeric_limits<float>::min());
  check_conversion(std::numeric_limits<float>::denorm_min());
  // 2^-23: exactly representable as the half subnormal 0x0002.
  check_conversion(std::numeric_limits<float>::epsilon());
}

} // namespace