File: image_of_lists.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (62 lines) | stat: -rw-r--r-- 1,868 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include "Halide.h"
#include <stdio.h>

#include <list>

using namespace Halide;

// Heap-allocates a fresh, empty std::list<int> and returns it; the caller
// takes ownership (in this test, main() deletes each list after checking it).
// The int argument is deliberately ignored. NOTE(review): it presumably exists
// so the Halide call depends on x and is made once per output element —
// confirm against Halide's treatment of HalideExtern calls.
extern "C" HALIDE_EXPORT_SYMBOL std::list<int> *list_create(int /* ignored */) {
    std::list<int> *fresh = new std::list<int>;
    return fresh;
}
HalideExtern_1(std::list<int> *, list_create, int);

// Appends `value` to `*list` only when `insert` is true, then returns the same
// list pointer so it can be threaded through a Halide update definition.
// The condition is evaluated here (rather than with Halide::select) because
// select evaluates both branches, which would insert unconditionally.
extern "C" HALIDE_EXPORT_SYMBOL std::list<int> *list_maybe_insert(std::list<int> *list, bool insert, int value) {
    if (!insert) {
        return list;  // Nothing to add; pass the list through unchanged.
    }
    list->push_back(value);
    return list;
}
HalideExtern_3(std::list<int> *, list_maybe_insert, std::list<int> *, bool, int);

// Builds, for every x in [0, 100), a heap-allocated std::list<int> holding the
// values r in [1, 99] that divide x, by threading a list pointer through extern
// "C" functions in a Halide reduction. Verifies that each list contains exactly
// the expected factors in ascending order, deletes every list (also on
// failure), and returns 0 on success or 1 on failure.
int main(int argc, char **argv) {
    if (get_jit_target_from_environment().arch == Target::WebAssembly) {
        printf("[SKIP] WebAssembly JIT does not support passing arbitrary pointers to/from HalideExtern code.\n");
        return 0;
    }

    // Compute the list of factors of all numbers < 100
    Func factors;
    Var x;

    // Ideally this would only iterate up to the square root of x, but
    // we don't have dynamic reduction bounds yet.
    RDom r(1, 99);

    // Create an std::list for each result
    factors(x) = list_create(x);

    // Because Halide::select evaluates both paths, we need to move
    // the condition into the C function.
    factors(x) = list_maybe_insert(factors(x), x % r == 0, r);

    Buffer<std::list<int> *> result = factors.realize({100});

    // Inspect the results for correctness. After the first failure we stop
    // checking but keep looping, so every list is still deleted before we
    // return (returning from inside the loop would leak the remaining lists).
    bool ok = true;
    for (int i = 0; i < 100; i++) {
        std::list<int> *list = result(i);

        if (ok) {
            // Every entry must divide i.
            for (int factor : *list) {
                if (i % factor) {
                    printf("Error: %d is not a factor of %d\n", factor, i);
                    ok = false;
                    break;
                }
            }
        }

        if (ok) {
            // No factor may be missing or duplicated either. No schedule is
            // specified, so the update visits r = 1..99 serially in increasing
            // order and the list should be sorted ascending.
            std::list<int> expected;
            for (int f = 1; f < 100; f++) {
                if (i % f == 0) {
                    expected.push_back(f);
                }
            }
            if (*list != expected) {
                printf("Error: wrong list of factors for %d\n", i);
                ok = false;
            }
        }

        delete list;
    }

    if (!ok) {
        return 1;
    }

    printf("Success!\n");
    return 0;
}