File: fast_math.cpp

package info (click to toggle)
nmodl 0.6-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,992 kB
  • sloc: cpp: 28,492; javascript: 9,841; yacc: 2,804; python: 1,967; lex: 1,674; xml: 181; sh: 136; ansic: 37; makefile: 18; pascal: 7
file content (126 lines) | stat: -rw-r--r-- 3,942 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*************************************************************************
 * Copyright (C) 2018-2022 Blue Brain Project
 *
 * This file is part of NMODL distributed under the terms of the GNU
 * Lesser General Public License. See top-level LICENSE file for details.
 *************************************************************************/

#define CATCH_CONFIG_MAIN

#include "codegen/fast_math.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <type_traits>

#include <catch2/catch_test_macros.hpp>

namespace nmodl {
namespace fast_math {

/**
 * @brief Check that @p f_test matches the reference @p f_ref over a sampled interval.
 *
 * The span is sampled at @p npoints evenly spaced points
 * x_i = low_limit + (high_limit - low_limit) * i / npoints, i = 0 .. npoints-1
 * (high_limit itself is never evaluated). At each point the two results must
 * agree within nULP (= 4) machine epsilons relative to the larger of the two
 * magnitudes. Non-finite results must agree exactly: equal infinities, or NaN
 * on both sides, pass; any other non-finite result is a failure.
 *
 * @tparam T floating-point type; deduced from the limits when the function
 *           arguments are overload sets such as std::exp.
 * @param f_ref      reference implementation
 * @param f_test     implementation under test
 * @param low_limit  first sample point
 * @param high_limit exclusive upper end of the sampled span
 * @param npoints    number of samples (0 samples trivially succeeds)
 * @return true when every sample is within tolerance, false otherwise
 */
template <class T, class = typename std::enable_if<std::is_floating_point<T>::value>::type>
bool check_over_span(T f_ref(T),
                     T f_test(T),
                     const T low_limit,
                     const T high_limit,
                     const size_t npoints) {
    constexpr T nULP = 4;
    constexpr T eps = std::numeric_limits<T>::epsilon();
    constexpr T one_o_eps = T(1) / eps;
    // Magnitude below which mag * eps would drop towards the subnormal range;
    // there the test is done as diff / eps > tol instead of diff > tol * eps.
    constexpr T low = std::numeric_limits<T>::min() * one_o_eps * T(100);

    const T range = high_limit - low_limit;

    for (size_t i = 0; i < npoints; ++i) {
        // Scale the fraction, not the range: range * i overflows to infinity
        // when the span covers most of T's range (e.g. [min, max] for log10).
        const T fraction = static_cast<T>(i) / static_cast<T>(npoints);
        const T x = low_limit + range * fraction;
        const T ref = f_ref(x);
        const T test = f_test(x);

        // Exact agreement, including infinities of the same sign.
        if (ref == test) {
            continue;
        }
        // NaN compares false with everything and would otherwise slip through
        // the tolerance test below; NaN on both sides counts as agreement.
        if (std::isnan(ref) || std::isnan(test)) {
            if (std::isnan(ref) && std::isnan(test)) {
                continue;
            }
            return false;
        }
        // An infinity against any other value: a relative tolerance is meaningless.
        if (std::isinf(ref) || std::isinf(test)) {
            return false;
        }

        T diff = std::abs(ref - test);
        const T mag = std::max(std::abs(ref), std::abs(test));
        T tol;
        if (mag * nULP > low) {
            // Usual case. Multiply by eps before nULP so the tolerance cannot
            // overflow to infinity for magnitudes close to the maximum of T.
            tol = mag * eps * nULP;
        } else {
            // Tiny magnitudes: mag * eps would be subnormal, so scale diff instead.
            tol = mag * nULP;
            diff *= one_o_eps;
        }
        if (diff > tol) {
            return false;
        }
    }
    return true;
}

/**
 * @brief Reference implementation of exprelr(x) = x / (exp(x) - 1).
 *
 * Evaluated as x / expm1(x): forming exp(x) - 1 explicitly cancels most
 * significant digits for small |x|, which makes a poor reference. When
 * 1 + x rounds to 1 the quotient tends to 0/0 and the limit value 1 is
 * returned instead. The guard is evaluated in T's own precision; a
 * double-precision guard lets tiny float x through to a zero denominator.
 *
 * @param x argument
 * @return x / expm1(x), or 1 when 1 + x == 1 in type T
 */
template <class T, class = typename std::enable_if<std::is_floating_point<T>::value>::type>
T exprelr_ref(const T x) {
    if (T(1) + x == T(1)) {
        return T(1);
    }
    return x / std::expm1(x);
}

SCENARIO("Check fast_math") {
    // Spans for the exponential-family tests; presumably chosen so that exp(x)
    // stays finite in each precision.
    constexpr double low_limit = -708.0;
    constexpr double high_limit = 708.0;
    constexpr float low_limit_f = -87.0f;
    constexpr float high_limit_f = 88.0f;
    constexpr size_t npoints = 2000;
    constexpr double min_double = std::numeric_limits<double>::min();
    constexpr double max_double = std::numeric_limits<double>::max();
    // These must be float, not double: check_over_span takes its precision T
    // from the limits (the std:: functions are overload sets), so double limits
    // would make the "log10 (float)" case silently re-test the double overloads.
    constexpr float min_float = std::numeric_limits<float>::min();
    constexpr float max_float = std::numeric_limits<float>::max();

    GIVEN("vexp (double)") {
        auto test = check_over_span(std::exp, vexp, low_limit, high_limit, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("vexp (float)") {
        auto test = check_over_span(std::exp, vexp, low_limit_f, high_limit_f, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("expm1 (double)") {
        auto test = check_over_span(std::expm1, vexpm1, low_limit, high_limit, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("expm1 (float)") {
        auto test = check_over_span(std::expm1, vexpm1, low_limit_f, high_limit_f, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("exprelr (double)") {
        auto test = check_over_span(exprelr_ref, exprelr, low_limit, high_limit, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("exprelr (float)") {
        auto test = check_over_span(exprelr_ref, exprelr, low_limit_f, high_limit_f, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("log10 (double)") {
        auto test = check_over_span(std::log10, log10, min_double, max_double, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
    GIVEN("log10 (float)") {
        auto test = check_over_span(std::log10, log10, min_float, max_float, npoints);

        THEN("error inside threshold") {
            REQUIRE(test);
        }
    }
}

}  // namespace fast_math
}  // namespace nmodl