File: rdom.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (69 lines) | stat: -rw-r--r-- 1,557 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import halide as hl
import numpy as np


def test_rdom():
    """Exercise a 2-D ``hl.RDom`` with a ``where`` predicate and its indexing.

    A func is defined as 1 everywhere, then updated by +2 over a 10x10
    reduction domain restricted to points with ``r.x <= r.y``.  The realized
    output is expected to be 3 where the predicate holds and 1 elsewhere.

    The test also checks that ``r[0]`` / ``r[1]`` name the same variables as
    ``r.x`` / ``r.y``, that out-of-range integer indices raise ``KeyError``,
    and that a non-integer index raises ``TypeError``.

    Returns:
        0 once every check has passed.
    """
    x = hl.Var("x")
    y = hl.Var("y")

    diagonal = hl.Func("diagonal")
    # Pure definition: every point starts out as 1.
    diagonal[x, y] = 1

    width, height = 10, 10

    r = hl.RDom([(0, width), (0, height)])
    # Restrict the reduction to points satisfying r.x <= r.y.
    r.where(r.x <= r.y)

    # Update definition: add 2 at every point the predicate lets through.
    diagonal[r.x, r.y] += 2
    output = diagonal.realize([width, height])

    for row in range(height):
        for col in range(width):
            expected = 3 if col <= row else 1
            assert output[col, row] == expected

    # Positional indexing must agree with the named accessors.
    assert r[0].name() == r.x.name()
    assert r[1].name() == r.y.name()

    # Indices outside [0, 2) are expected to be rejected with KeyError.
    try:
        r[-1].name()
    except KeyError:
        pass
    else:
        raise Exception("underflowing index should raise KeyError")

    try:
        r[2].name()
    except KeyError:
        pass
    else:
        raise Exception("overflowing index should raise KeyError")

    # A non-integer index is expected to be rejected with TypeError.
    try:
        r["foo"].name()
    except TypeError:
        pass
    else:
        raise Exception("bad index type should raise TypeError")

    return 0


def test_implicit_pure_definition():
    """Check that an update on a func with no explicit pure definition works.

    A random 2x3 float32 array is wrapped in an ``hl.Buffer`` and reduced
    with ``+=`` over an ``hl.RDom`` whose extent is the buffer's dimension-0
    extent.  The realized result is compared against ``a.sum(axis=1)`` from
    NumPy.
    """
    source = np.random.ranf((2, 3)).astype(np.float32)
    row_sums = source.sum(axis=1)

    ha = hl.Buffer(source, name="ha")
    # Extent of the buffer's dimension 0; used as the reduction length.
    # NOTE(review): presumably this is the 3-wide axis of the NumPy array --
    # confirm against the Buffer axis-order convention.
    reduction_extent = ha.dim(0).extent()

    x = hl.Var("x")
    k = hl.RDom([(0, reduction_extent)], name="k")

    hc = hl.Func("hc")
    # No `hc[x] = 0.0` here on purpose: the pure definition is left implicit
    # and the update below is the only definition given.
    hc[x] += ha[k, x]

    realized = hc.realize([2])
    assert np.allclose(np.array(realized), row_sums)


if __name__ == "__main__":
    # Script entry point: run each test in this file in definition order.
    for run_test in (test_rdom, test_implicit_pure_definition):
        run_test()