File: specialize_trim_condition.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (78 lines) | stat: -rw-r--r-- 2,219 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include "Halide.h"
#include "HalideRuntime.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

using namespace Halide;

// Running total of load events seen by my_trace; main resets it to zero
// before each realization it measures.
int load_count{0};

// A trace that records the number of loads
int my_trace(JITUserContext *user_context, const halide_trace_event_t *ev) {

    if (ev->event == halide_trace_load) {
        load_count++;
    }
    return 0;
}

int main(int argc, char **argv) {
    Param<float> scale_factor_x, scale_factor_y;
    ImageParam input(UInt(8), 2);

    Var x, y;

    Func f;
    // upsample: both scale factors > 1. downsample: neither is > 1.
    // A mix (one > 1, the other not) sets neither flag.
    Expr upsample_x = scale_factor_x > cast<float>(1.0f);
    Expr upsample_y = scale_factor_y > cast<float>(1.0f);
    Expr upsample = upsample_x && upsample_y;
    Expr downsample = !upsample_x && !upsample_y;

    // The unspecialized definition contains two loads from `input`. Inside a
    // specialization whose condition fixes upsample/downsample, one select
    // branch should be trimmed so only one load remains.
    f(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                     select(downsample, input(x * 2, y * 2), 0));

    // Every load from `input` is reported to my_trace, which counts it.
    input.trace_loads();
    f.jit_handlers().custom_trace = &my_trace;

    // upsample and downsample can never both be true, so the three
    // specializations below cover every possible input.
    // Impossible condition
    // f.specialize(upsample && downsample);
    f.specialize(upsample && !downsample);
    f.specialize(!upsample && downsample);
    f.specialize(!upsample && !downsample);
    // Error raised if none of the specialization conditions hold.
    f.specialize_fail("Unreachable condition");

    Buffer<uint8_t> img(16, 16);
    input.set(img);

    // Realizes an 8x8 output with the given scale factors and compares the
    // number of traced loads with `expected`, printing a diagnostic on
    // mismatch. An explicit check is used instead of assert() so the test
    // still verifies something when compiled with NDEBUG.
    auto loads_match = [&](float sx, float sy, int expected) {
        load_count = 0;
        scale_factor_x.set(sx);
        scale_factor_y.set(sy);
        Buffer<uint8_t> out = f.realize({8, 8});
        if (load_count != expected) {
            printf("scale factors (%g, %g): expected %d loads, got %d\n",
                   sx, sy, expected, load_count);
            return false;
        }
        return true;
    };

    // Upsample in both dimensions: one of the select branches should be
    // trimmed, resulting in one load per output pixel (8 * 8 = 64).
    if (!loads_match(2.0f, 2.0f, 64)) {
        return 1;
    }
    // Mixed scale factors: no select can be trimmed, resulting in two loads
    // per output pixel (2 * 64 = 128).
    if (!loads_match(0.5f, 2.0f, 128)) {
        return 1;
    }
    // Downsample in both dimensions: one of the select branches should be
    // trimmed, resulting in one load per output pixel.
    if (!loads_match(0.5f, 0.5f, 64)) {
        return 1;
    }

    printf("Success!\n");
    return 0;
}