File: isqrt.hpp

package info (click to toggle)
primecount 7.16%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,724 kB
  • sloc: cpp: 18,769; ansic: 102; makefile: 89; sh: 86
file content (145 lines) | stat: -rw-r--r-- 3,552 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
///
/// @file  isqrt.hpp
/// @brief Integer square root function
///
/// Copyright (C) 2024 Kim Walisch, <kim.walisch@gmail.com>
///
/// This file is distributed under the BSD License. See the COPYING
/// file in the top level directory.
///

#ifndef ISQRT_HPP
#define ISQRT_HPP

#include <macros.hpp>
#include <int128_t.hpp>

#include <algorithm>
#include <cmath>
#include <stdint.h>

namespace {

#if __cplusplus >= 202002L

/// C++20 variant of the compile time integer square root.
/// Binary search over the inclusive interval [lo, hi] for the
/// largest value whose square does not exceed x. Because the
/// function is consteval every call is evaluated by the compiler.
///
/// @param x   Number whose integer square root is searched.
/// @param lo  Inclusive lower bound of the search interval.
/// @param hi  Inclusive upper bound of the search interval.
/// @return    The value both bounds converge to.
///
template <typename T>
consteval T sqrt_helper(T x, T lo, T hi)
{
  while (lo != hi)
  {
    // Round the midpoint up, otherwise lo = mid would make
    // no progress once hi == lo + 1.
    const T mid = (lo + hi + 1) / 2;

    // x / mid < mid is the overflow free form of mid * mid > x
    if (x / mid < mid)
      hi = static_cast<T>(mid - 1);
    else
      lo = mid;
  }

  return lo;
}

/// Compile time integer square root (C++20, consteval).
/// Delegates to sqrt_helper() using the inclusive search
/// interval [0, x / 2 + 1].
///
/// @param x  Number whose integer square root is computed.
/// @return   Result of the binary search in sqrt_helper().
///
template <typename T>
consteval T ct_sqrt(T x)
{
  const T lower = 0;
  const T upper = x / 2 + 1;

  return sqrt_helper<T>(x, lower, upper);
}

#elif __cplusplus >= 201402L

/// C++14 variant of the compile time integer square root.
/// Binary search over the inclusive interval [lo, hi] for the
/// largest value whose square does not exceed x.
///
/// @param x   Number whose integer square root is searched.
/// @param lo  Inclusive lower bound of the search interval.
/// @param hi  Inclusive upper bound of the search interval.
/// @return    The value both bounds converge to.
///
template <typename T>
constexpr T sqrt_helper(T x, T lo, T hi)
{
  while (lo != hi)
  {
    // Round the midpoint up, otherwise lo = mid would make
    // no progress once hi == lo + 1.
    const T mid = (lo + hi + 1) / 2;

    // x / mid < mid is the overflow free form of mid * mid > x
    if (x / mid < mid)
      hi = static_cast<T>(mid - 1);
    else
      lo = mid;
  }

  return lo;
}

/// Compile time integer square root (C++14, constexpr).
/// Starts the binary search of sqrt_helper() on the
/// inclusive interval [0, x / 2 + 1].
///
/// @param x  Number whose integer square root is computed.
/// @return   Result of the binary search in sqrt_helper().
///
template <typename T>
constexpr T ct_sqrt(T x)
{
  const T search_lo = 0;
  const T search_hi = x / 2 + 1;

  return sqrt_helper<T>(x, search_lo, search_hi);
}

#else

// Rounded up midpoint of the search interval [lo, hi].
// A C++11 constexpr function may only consist of a single
// return statement, hence the midpoint cannot be stored in
// a local variable and a macro is used instead.
#define MID ((lo + hi + 1) / 2)

/// C++11 compile time square root using binary search.
/// Searches the inclusive interval [lo, hi] for the largest
/// value whose square does not exceed x.
///
/// @param x   Number whose integer square root is searched.
/// @param lo  Inclusive lower bound of the search interval.
/// @param hi  Inclusive upper bound of the search interval.
/// @return    The value both bounds converge to.
///
template <typename T>
constexpr T sqrt_helper(T x, T lo, T hi)
{
  // x / MID < MID is the overflow free form of MID * MID > x
  return lo == hi ? lo : ((x / MID < MID)
      ? sqrt_helper<T>(x, lo, MID - 1) : sqrt_helper<T>(x, MID, hi));
}

// MID is an implementation detail of sqrt_helper(), it must
// not leak into the files that include this header.
#undef MID

/// Compile time integer square root (C++11, constexpr).
/// Runs the binary search of sqrt_helper() on the inclusive
/// interval [0, x / 2 + 1]. The body is a single return
/// statement as required for C++11 constexpr functions.
///
/// @param x  Number whose integer square root is computed.
/// @return   Result of the binary search in sqrt_helper().
///
template <typename T>
constexpr T ct_sqrt(T x)
{
  return sqrt_helper<T>(x, static_cast<T>(0), static_cast<T>(x / 2 + 1));
}

#endif

/// Integer square root, returns floor(sqrt(x)).
///
/// The result is first estimated using the double precision
/// std::sqrt(), the estimate is then limited to
/// floor(sqrt(max(T))) and finally corrected using exact
/// integer arithmetic.
///
/// @param x  Number whose integer square root is computed.
///           NOTE(review): negative values are not handled
///           here, presumably all callers pass x >= 0.
/// @return   The largest r with r * r <= x.
///
template <typename T>
ALWAYS_INLINE T isqrt(T x)
{
  T s = (T) std::sqrt((double) x);

  // By using constexpr for the sqrt_max variable type it
  // is guaranteed that ct_sqrt() is evaluated at compile
  // time. Compilation will fail if the compiler fails to
  // evaluate ct_sqrt() at compile time. This is great,
  // ct_sqrt() must be evaluated at compile time otherwise
  // the runtime complexity of isqrt(x) would deteriorate
  // from O(1) to O(log2(x)).
  //
  // If sqrt_max were declared without constexpr then the
  // compiler would be free to compute ct_sqrt() either at
  // compile time or at run time e.g. GCC-11 computes
  // ct_sqrt(MAX_INT128) at compile time whereas Clang-12
  // computes ct_sqrt(MAX_INT128) at run time even at -O2.
  //
  // C++20 fixed this annoying issue by adding consteval
  // to C++. Hence if the compiler supports C++20 ct_sqrt()
  // is defined as consteval instead of constexpr. Hence
  // using C++20 ct_sqrt() will be evaluated at compile
  // time in all cases i.e. even if sqrt_max were declared
  // without constexpr.
  //
  constexpr T sqrt_max = ct_sqrt(pstd::numeric_limits<T>::max());

  // For 128-bit integers we use uint64_t as the
  // result type. For all other types we use the
  // same result type as the input type.
  using R = typename pstd::conditional<sizeof(T) / 2 == sizeof(uint64_t), uint64_t, T>::type;

  // Limiting the estimate to sqrt_max ensures that
  // r * r below cannot exceed the maximum value of T.
  R r = (R) std::min(s, sqrt_max);

  // In my tests the first corrections were needed above
  // 10^22 where the results were off by 1. Above 10^32 the
  // first results occurred that were off by > 1. Since
  // primecount only supports numbers up to 10^31 this is
  // not an issue for us.
  if (r * (T) r > x)
  {
    do { r--; }
    while (r * (T) r > x);
  }
  // Same as (r + 1)^2 <= x but overflow safe. r must be
  // converted to T before multiplying by 2: for 128-bit
  // integers r is an uint64_t and r * 2 would wrap around
  // if r >= 2^63 (i.e. x >= 2^126) causing wrong results.
  else if ((T) r * 2 < x - r * (T) r)
  {
    do { r++; }
    while ((T) r * 2 < x - r * (T) r);
  }

  return r;
}

} // namespace

#endif