File: fast_pow.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (110 lines) | stat: -rw-r--r-- 3,605 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include "Halide.h"
#include "halide_benchmark.h"
#include <algorithm>
#include <cstdio>

using namespace Halide;
using namespace Halide::Tools;

// Ground-truth pow used for comparison against Halide's pow/fast_pow.
// This is a real function with C linkage instead of a direct use of powf(),
// because powf() may be a macro in some environments; wrapping it gives the
// pipeline a concrete exported symbol to call. It simply forwards to powf.
extern "C" HALIDE_EXPORT_SYMBOL float pow_ref(float base, float exponent) {
    const float result = powf(base, exponent);
    return result;
}
// Make pow_ref callable from Halide Funcs as an extern taking two floats.
HalideExtern_2(float, pow_ref, float, float);

int main(int argc, char **argv) {
    Target host = get_host_target();
    Target hl_target = get_target_from_environment();
    Target hl_jit_target = get_jit_target_from_environment();
    printf("host is:          %s\n", host.to_string().c_str());
    printf("HL_TARGET is:     %s\n", hl_target.to_string().c_str());
    printf("HL_JIT_TARGET is: %s\n", hl_jit_target.to_string().c_str());

    if (hl_jit_target.arch == Target::X86 &&
        !hl_jit_target.has_feature(Target::SSE41)) {
        printf("[SKIP] These intrinsics are known to be slow on x86 without sse 4.1.\n");
        return 0;
    }

    if (hl_jit_target.arch == Target::WebAssembly) {
        printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n");
        return 0;
    }

    Func f, g, h;
    Var x, y;

    Param<int> pows_per_pixel;

    RDom s(0, pows_per_pixel);
    f(x, y) = sum(pow_ref((x + 1) / 512.0f, (y + 1 + s) / 512.0f));
    g(x, y) = sum(pow((x + 1) / 512.0f, (y + 1 + s) / 512.0f));
    h(x, y) = sum(fast_pow((x + 1) / 512.0f, (y + 1 + s) / 512.0f));
    f.vectorize(x, 8);
    g.vectorize(x, 8);
    h.vectorize(x, 8);

    Buffer<float> correct_result(2048, 768);
    Buffer<float> fast_result(2048, 768);
    Buffer<float> faster_result(2048, 768);

    pows_per_pixel.set(1);

    f.realize(correct_result);
    g.realize(fast_result);
    h.realize(faster_result);

    pows_per_pixel.set(20);

    // All profiling runs are done into the same buffer, to avoid
    // cache weirdness.
    Buffer<float> timing_scratch(256, 256);
    double t1 = 1e3 * benchmark([&]() { f.realize(timing_scratch); });
    double t2 = 1e3 * benchmark([&]() { g.realize(timing_scratch); });
    double t3 = 1e3 * benchmark([&]() { h.realize(timing_scratch); });

    RDom r(correct_result);
    Func fast_error, faster_error;
    Expr fast_delta = correct_result(r.x, r.y) - fast_result(r.x, r.y);
    Expr faster_delta = correct_result(r.x, r.y) - faster_result(r.x, r.y);
    fast_error() += cast<double>(fast_delta * fast_delta);
    faster_error() += cast<double>(faster_delta * faster_delta);

    Buffer<double> fast_err = fast_error.realize();
    Buffer<double> faster_err = faster_error.realize();

    int timing_N = timing_scratch.width() * timing_scratch.height() * 10;
    int correctness_N = fast_result.width() * fast_result.height();
    fast_err() = sqrt(fast_err() / correctness_N);
    faster_err() = sqrt(faster_err() / correctness_N);

    printf("powf: %f ns per pixel\n"
           "Halide's pow: %f ns per pixel (rms error = %0.10f)\n"
           "Halide's fast_pow: %f ns per pixel (rms error = %0.10f)\n",
           1000000 * t1 / timing_N,
           1000000 * t2 / timing_N, fast_err(),
           1000000 * t3 / timing_N, faster_err());

    if (fast_err() > 0.000001) {
        printf("Error for pow too large\n");
        return 1;
    }

    if (faster_err() > 0.0001) {
        printf("Error for fast_pow too large\n");
        return 1;
    }

    if (t1 < t2) {
        printf("powf is faster than Halide's pow\n");
        return 1;
    }

    if (t2 * 1.5 < t3) {
        printf("pow is more than 1.5x faster than fast_pow\n");
        return 1;
    }

    printf("Success!\n");
    return 0;
}