File: FMA.h

package info (click to toggle)
llvm-toolchain-13 1%3A13.0.1-11
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,418,840 kB
  • sloc: cpp: 5,290,826; ansic: 996,570; asm: 544,593; python: 188,212; objc: 72,027; lisp: 30,291; f90: 25,395; sh: 24,898; javascript: 9,780; pascal: 9,398; perl: 7,484; ml: 5,432; awk: 3,523; makefile: 2,913; xml: 953; cs: 573; fortran: 539
file content (72 lines) | stat: -rw-r--r-- 3,128 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//===-- Common header for FMA implementations -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H
#define LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H

#include "utils/CPP/TypeTraits.h"

namespace __llvm_libc {
namespace fputil {
namespace generic {

/// Fused multiply-add for single precision: returns x * y + z, evaluated in
/// double precision and then narrowed to float.
///
/// This overload participates in overload resolution only when T is float
/// (see the cpp::EnableIfType / cpp::IsSame constraint on the return type).
///
/// The double-precision sum is patched before narrowing so that the final
/// double -> float conversion is not hit by a double rounding error. The
/// correction assumes the (default) round-to-nearest, ties-to-even mode.
template <typename T>
static inline cpp::EnableIfType<cpp::IsSame<T, float>::Value, T> fma(T x, T y,
                                                                     T z) {
  // The product of two floats is exact in double precision.
  double product = static_cast<double>(x) * static_cast<double>(y);
  double addend = static_cast<double>(z);
  double rounded_sum = product + addend;

  fputil::FPBits<double> product_bits(product);
  fputil::FPBits<double> addend_bits(addend);
  fputil::FPBits<double> result_bits(rounded_sum);

  // Infinities, NaNs and zeros are narrowed without any adjustment.
  if (result_bits.isInfOrNaN() || result_bits.isZero())
    return static_cast<float>(static_cast<double>(result_bits));

  // Why an adjustment can be needed at all:
  //
  // The addition above is itself rounded to double whenever the two operands
  // are far enough apart (for instance, when addend.exponent >
  // product.exponent + 5, or product.exponent > addend.exponent + 40).
  // Rounding that already-rounded value a second time, down to float, can
  // then give a different answer than rounding the exact value once.
  //
  // Worked example:
  //   x = y = 1 + 2^(-12), z = 2^(-53)
  //   exact x*y + z          = 1 + 2^(-11) + 2^(-24) + 2^(-53)
  //   correct fmaf(x, y, z)  = 1 + 2^(-11) + 2^(-23)
  //   double(x*y + z)        = 1 + 2^(-11) + 2^(-24)   (default rounding)
  //   float(double(x*y + z)) = 1 + 2^(-11)             (wrong)
  //
  // The remedy is to recover the rounding error of the double addition with
  // Dekker's 2Sum algorithm: it produces `error` such that
  //   rounded_sum - error = product + addend   exactly,
  // under round-to-nearest, ties-to-even.  That error is smaller than
  // eps(rounded_sum), i.e. error.exponent < rounded_sum.exponent - 52.  So a
  // non-zero error means the addition was inexact, and nudging the last bit of
  // the sum is enough to make the sticky bits seen by the double -> float
  // conversion correct (in the cases where it matters).

  // Subtract the operand with the larger exponent field first.
  double error;
  if (product_bits.getUnbiasedExponent() >= addend_bits.getUnbiasedExponent()) {
    error =
        (static_cast<double>(result_bits) - static_cast<double>(product_bits)) -
        static_cast<double>(addend_bits);
  } else {
    error =
        (static_cast<double>(result_bits) - static_cast<double>(addend_bits)) -
        static_cast<double>(product_bits);
  }
  fputil::FPBits<double> error_bits(error);

  // The sticky bits are only touched when the addition was inexact and the
  // lowest (52 - 23 - 1 = 28) mantissa bits of the sum are all zero.
  constexpr unsigned long long LOW_28_BITS_MASK = 0xfff'ffffULL;
  const auto sum_mantissa = result_bits.getMantissa();
  const bool low_bits_clear = (sum_mantissa & LOW_28_BITS_MASK) == 0;

  if (!error_bits.isZero() && low_bits_clear) {
    if (result_bits.getSign() != error_bits.getSign()) {
      // Opposite signs: the exact value is larger in magnitude than the sum.
      result_bits.setMantissa(sum_mantissa + 1);
    } else if (sum_mantissa != 0) {
      // Same sign: the exact value is smaller in magnitude than the sum.  A
      // zero mantissa is left alone.
      result_bits.setMantissa(sum_mantissa - 1);
    }
  }

  return static_cast<float>(static_cast<double>(result_bits));
}

} // namespace generic
} // namespace fputil
} // namespace __llvm_libc

#endif // LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H