File: extern_consumer.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (121 lines) | stat: -rw-r--r-- 3,365 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include "Halide.h"
#include "halide_test_dirs.h"

#include <cstdio>

using namespace Halide;

// Extern consumer stage used by this test.
//
// During a bounds query, requests the range
// [desired_min, desired_min + desired_extent) along dimension 0 of `input`.
// Otherwise, writes those input values (as ints, one per line) to `filename`.
//
// Returns 0 on success. Returns nonzero if the output file cannot be opened
// or closed, so Halide reports the failure instead of crashing.
extern "C" HALIDE_EXPORT_SYMBOL int dump_to_file(halide_buffer_t *input, const char *filename,
                                                 int desired_min, int desired_extent,
                                                 halide_buffer_t *) {
    // Note the final output buffer argument is unused.
    if (input->is_bounds_query()) {
        // Request some range of the input buffer
        input->dim[0].min = desired_min;
        input->dim[0].extent = desired_extent;
        return 0;
    }

    FILE *f = fopen(filename, "w");
    if (!f) {
        // Signal the error to Halide rather than passing a null FILE* to fprintf.
        return -1;
    }

    // Depending on the schedule, other consumers, etc, Halide may
    // have evaluated more than we asked for, so don't assume that
    // the min and extents match what we requested. Index relative to the
    // buffer's actual min (and stride) instead of forming a pointer that
    // may lie before the start of the allocation.
    const int *data = reinterpret_cast<const int *>(input->host);
    const int buffer_min = input->dim[0].min;
    const int stride = input->dim[0].stride;
    for (int i = desired_min; i < desired_min + desired_extent; i++) {
        fprintf(f, "%d\n", data[(i - buffer_min) * stride]);
    }

    // fclose flushes buffered output; a failure here means the file is incomplete.
    if (fclose(f) != 0) {
        return -1;
    }
    return 0;
}

// Reads the file written by dump_to_file and checks that it contains the
// squares 0, 1, 4, ..., 81, one per line.
//
// Returns true on an exact match. Otherwise prints the actual contents (or an
// open failure) and returns false.
bool check_result() {
    // Check the right thing happened
    const char *correct =
        "0\n"
        "1\n"
        "4\n"
        "9\n"
        "16\n"
        "25\n"
        "36\n"
        "49\n"
        "64\n"
        "81\n";

    std::string path = Internal::get_test_tmp_dir() + "halide_test_extern_consumer.txt";
    Internal::assert_file_exists(path);
    FILE *f = fopen(path.c_str(), "r");
    if (!f) {
        // The file may exist yet still be unreadable; don't pass a null FILE* to fread.
        printf("Could not open %s for reading\n", path.c_str());
        return false;
    }

    // Leave room for the terminating NUL so `result` is always a valid C string.
    char result[1024];
    size_t bytes_read = fread(&result[0], 1, sizeof(result) - 1, f);
    result[bytes_read] = 0;
    fclose(f);

    if (strncmp(result, correct, sizeof(result) - 1)) {
        printf("Incorrect output: %s\n", result);
        return false;
    }

    return true;
}

// Test driver. Builds pipelines whose output stage is the extern consumer
// dump_to_file:
//   1. fed by a Func computed at the extern stage's outermost var;
//   2. fed by an ImageParam bound to a pre-realized buffer.
// After each run it checks that the first ten squares were written to the
// temp file. Returns 0 on success (or skip), 1 on failure.
int main(int argc, char **argv) {
    if (get_jit_target_from_environment().arch == Target::WebAssembly) {
        printf("[SKIP] WebAssembly JIT does not support passing arbitrary pointers to/from HalideExtern code.\n");
        return 0;
    }

    // Define a pipeline that dumps some squares to a file using an
    // external consumer stage.
    Func source;
    Var x;
    source(x) = x * x;

    Param<int> min, extent;
    Param<const char *> filename;

    Func sink;
    std::vector<ExternFuncArgument> args;
    args.push_back(source);
    args.push_back(filename);
    args.push_back(min);
    args.push_back(extent);
    sink.define_extern("dump_to_file", args, Int(32), 0);

    // Extern stages still have an outermost var.
    source.compute_at(sink, Var::outermost());

    sink.compile_jit();

    // Dump the first 10 squares to a file
    std::string path = Internal::get_test_tmp_dir() + "halide_test_extern_consumer.txt";
    Internal::ensure_no_file_exists(path);

    filename.set(path.c_str());
    min.set(0);
    extent.set(10);
    sink.realize();

    if (!check_result())
        return 1;

    // Test ImageParam ExternFuncArgument via passed in image.
    Buffer<int32_t> buf = source.realize({10});
    ImageParam passed_in(Int(32), 1);
    passed_in.set(buf);

    Func sink2;
    std::vector<ExternFuncArgument> args2;
    args2.push_back(passed_in);
    args2.push_back(filename);
    args2.push_back(min);
    args2.push_back(extent);
    sink2.define_extern("dump_to_file", args2, Int(32), 0);

    // Remove the file produced by the first pipeline. Otherwise check_result()
    // could pass on stale output even if sink2 wrote nothing.
    Internal::ensure_no_file_exists(path);
    sink2.realize();

    if (!check_result())
        return 1;

    printf("Success!\n");
    return 0;
}