File: nl_means_generator.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (162 lines) | stat: -rw-r--r-- 5,874 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include "Halide.h"

namespace {

using namespace Halide;

// Halide generator for non-local means denoising.
//
// For every output pixel, a square search window around it is scanned. Each
// candidate pixel in the window is weighted by how similar the patch around
// it is to the patch around the output pixel, and the output is the weighted
// average of the candidates, clamped to [0, 1].
class NonLocalMeans : public Halide::Generator<NonLocalMeans> {
public:
    // Image to denoise, indexed as (x, y, c). Channels 0, 1 and 2 are read.
    Input<Buffer<float, 3>> input{"input"};
    // Side length, in pixels, of the square patch compared around each pixel.
    Input<int> patch_size{"patch_size"};
    // Side length, in pixels, of the square window searched for similar patches.
    Input<int> search_area{"search_area"};
    // Scales how quickly a candidate's weight falls off as its patch
    // difference grows (larger sigma means a slower fall-off).
    Input<float> sigma{"sigma"};

    // Denoised image, indexed as (x, y, c). Constrained to exactly 3 channels.
    Output<Buffer<float, 3>> non_local_means{"non_local_means"};

    // Defines the algorithm, the size estimates, and one of three schedules
    // (autoscheduler / GPU / CPU) selected from the generator's target.
    void generate() {
        /* THE ALGORITHM */

        // This implements the basic description of non-local means found at
        // https://en.wikipedia.org/wiki/Non-local_means

        Var x("x"), y("y"), c("c");

        // Note the sign: this is the *negated* reciprocal, so it can be
        // multiplied directly with the patch difference inside the exponential.
        Expr inv_sigma_sq = -1.0f / (sigma * sigma * patch_size * patch_size);

        // Add a boundary condition
        Func clamped = BoundaryConditions::repeat_edge(input);

        // Define the difference images: for each offset (dx, dy), the squared
        // difference between a pixel and the pixel at that offset, per channel.
        // The Func is named "dc" so it does not share a name with the
        // channel-summed Func "d" below. The square is written as a multiply
        // rather than pow(..., 2) so this innermost expression does not go
        // through the general-purpose power function.
        Var dx("dx"), dy("dy");
        Func dc("dc");
        Expr diff = clamped(x, y, c) - clamped(x + dx, y + dy, c);
        dc(x, y, dx, dy, c) = diff * diff;

        // Sum across color channels
        RDom channels(0, 3);
        Func d("d");
        d(x, y, dx, dy) = sum(dc(x, y, dx, dy, channels));

        // Find the patch differences by blurring the difference images.
        // The box blur is done separably: first along y, then along x, both
        // over the same window of patch_size taps starting at -(patch_size / 2).
        RDom patch_dom(-(patch_size / 2), patch_size);
        Func blur_d_y("blur_d_y");
        blur_d_y(x, y, dx, dy) = sum(d(x, y + patch_dom, dx, dy));

        Func blur_d("blur_d");
        blur_d(x, y, dx, dy) = sum(blur_d_y(x + patch_dom, y, dx, dy));

        // Compute the weights from the patch differences
        Func w("w");
        w(x, y, dx, dy) = fast_exp(blur_d(x, y, dx, dy) * inv_sigma_sq);

        // Add an alpha channel. Channel 3 is the constant 1, so that the
        // weighted sum below accumulates the total weight in channel 3.
        Func clamped_with_alpha("clamped_with_alpha");
        clamped_with_alpha(x, y, c) = mux(c, {clamped(x, y, 0), clamped(x, y, 1), clamped(x, y, 2), 1.0f});

        // Define a reduction domain for the search area
        RDom s_dom(-(search_area / 2), search_area, -(search_area / 2), search_area);

        // Compute the sum of the pixels in the search area
        Func non_local_means_sum("non_local_means_sum");
        non_local_means_sum(x, y, c) += w(x, y, s_dom.x, s_dom.y) * clamped_with_alpha(x + s_dom.x, y + s_dom.y, c);

        // Normalize by the accumulated weight (channel 3) and clamp to [0, 1].
        non_local_means(x, y, c) =
            clamp(non_local_means_sum(x, y, c) / non_local_means_sum(x, y, 3), 0.0f, 1.0f);

        /* THE SCHEDULE */

        // Require 3 channels for output
        non_local_means.dim(2).set_bounds(0, 3);

        Var tx("tx"), ty("ty"), xi("xi"), yi("yi");

        /* ESTIMATES */
        // (This can be useful in conjunction with RunGen and benchmarks as well
        // as auto-schedule, so we do it in all cases.)
        // Provide estimates on the input image
        input.set_estimates({{0, 1536}, {0, 2560}, {0, 3}});
        // Provide estimates on the parameters
        patch_size.set_estimate(7);
        search_area.set_estimate(7);
        sigma.set_estimate(0.12f);
        // Provide estimates on the output pipeline
        non_local_means.set_estimates({{0, 1536}, {0, 2560}, {0, 3}});

        if (using_autoscheduler()) {
            // nothing
        } else if (get_target().has_gpu_feature()) {
            // 22 ms on a 2060 RTX
            Var xii, yii;

            // We'll use 32x16 thread blocks throughout. This was
            // found by just trying lots of sizes, but large thread
            // blocks are particularly good in the blur_d stage to
            // avoid doing wasted blurring work at tile boundaries
            // (especially for large patch sizes).

            non_local_means.compute_root()
                .reorder(c, x, y)
                .unroll(c)
                .gpu_tile(x, y, xi, yi, 32, 16);

            // The pure (initialization) stage and the update stage of the
            // weighted sum are scheduled separately; both use 32x16 tiles.
            non_local_means_sum.compute_root()
                .gpu_tile(x, y, xi, yi, 32, 16)
                .update()
                .reorder(c, s_dom.x, x, y, s_dom.y)
                .tile(x, y, xi, yi, 32, 16)
                .gpu_blocks(x, y)
                .gpu_threads(xi, yi)
                .unroll(c);

            // The patch size we're benchmarking for is 7, which
            // implies an expansion of 6 pixels for footprint of the
            // blur, so we'll size tiles of blur_d to be a multiple of
            // the thread block size minus 6.
            blur_d.compute_at(non_local_means_sum, s_dom.y)
                .tile(x, y, xi, yi, 128 - 6, 32 - 6)
                .tile(xi, yi, xii, yii, 32, 16)
                .gpu_threads(xii, yii)
                .gpu_blocks(x, y, dx);

            blur_d_y.compute_at(blur_d, x)
                .tile(x, y, xi, yi, 32, 16)
                .gpu_threads(xi, yi);

            d.compute_at(blur_d, x)
                .tile(x, y, xi, yi, 32, 16)
                .gpu_threads(xi, yi);

        } else {
            // 64 ms on an Intel i9-9960X using 32 threads at 3.0 GHz

            const int vec = natural_vector_size<float>();

            // Output is computed in 16x8 tiles, parallel over rows of tiles.
            non_local_means.compute_root()
                .reorder(c, x, y)
                .tile(x, y, tx, ty, x, y, 16, 8)
                .parallel(ty)
                .vectorize(x, vec);
            blur_d_y.compute_at(non_local_means, tx)
                .hoist_storage(non_local_means, ty)
                .reorder(y, x)
                .vectorize(x, vec);
            d.compute_at(non_local_means, tx)
                .hoist_storage(non_local_means, ty)
                .vectorize(x, vec);
            // The sum has 4 channels: the 3 color channels plus the
            // accumulated weight in channel 3.
            non_local_means_sum.compute_at(non_local_means, x)
                .reorder(c, x, y)
                .bound(c, 0, 4)
                .unroll(c)
                .vectorize(x, vec);
            non_local_means_sum.update(0)
                .reorder(c, x, y, s_dom.x, s_dom.y)
                .unroll(c)
                .vectorize(x, vec);
            blur_d.compute_at(non_local_means_sum, x)
                .vectorize(x, vec);
        }
    }
};

}  // namespace

// Register the NonLocalMeans generator class under the registry name "nl_means".
HALIDE_REGISTER_GENERATOR(NonLocalMeans, nl_means)