File: tuple_select.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (93 lines) | stat: -rw-r--r-- 3,071 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import halide as hl


def _check_tuple_func(f, expected_a, expected_b, size=200):
    """Realize a two-output Func and compare every element to a reference.

    Args:
        f: An ``hl.Func`` whose definition yields a 2-element Tuple.
        expected_a: Callable ``(x, y) -> int`` giving the expected value of
            the first Tuple element at coordinate ``(x, y)``.
        expected_b: Same as ``expected_a``, for the second Tuple element.
        size: Extent of both realized dimensions.

    Raises:
        AssertionError: if any element differs from its reference value.
    """
    a, b = f.realize([size, size])
    # Buffers are indexed [x, y]: dimension 0 (width) is x and dimension 1
    # (height) is y, so each coordinate iterates over its own extent.
    for yy in range(a.height()):
        for xx in range(a.width()):
            assert a[xx, yy] == expected_a(xx, yy), (xx, yy)
            assert b[xx, yy] == expected_b(xx, yy), (xx, yy)


def _expect_halide_error(build, message):
    """Require that ``build()`` raises ``hl.HalideError`` whose text contains ``message``.

    Raises:
        AssertionError: if no HalideError is raised, or if its text does not
            contain ``message``.
    """
    try:
        build()
    except hl.HalideError as e:
        assert message in str(e), str(e)
    else:
        raise AssertionError("Did not see expected exception!")


def test_tuple_select():
    """Exercise hl.select() with Tuple-valued arms.

    Covers the ternary and multiway forms, each with a single Expr condition
    and with a Tuple of per-element conditions. Also checks two misuses that
    must be rejected with a HalideError.
    """
    x = hl.Var("x")
    y = hl.Var("y")

    # ternary select with Expr condition
    f = hl.Func("f")
    f[x, y] = hl.select(x + y < 30, (x, y), (x - 1, y - 2))
    _check_tuple_func(
        f,
        lambda xx, yy: xx if xx + yy < 30 else xx - 1,
        lambda xx, yy: yy if xx + yy < 30 else yy - 2,
    )

    # ternary select with Tuple condition: each element uses its own condition
    f = hl.Func("f")
    f[x, y] = hl.select((x < 30, y < 30), (x, y), (x - 1, y - 2))
    _check_tuple_func(
        f,
        lambda xx, yy: xx if xx < 30 else xx - 1,
        lambda xx, yy: yy if yy < 30 else yy - 2,
    )

    # multiway select with Expr condition: the first true condition wins
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select(x + y < 30,  (x, y),
                        x + y < 100, (x-1, y-2),
                                     (x-100, y-200))
    # fmt: on
    _check_tuple_func(
        f,
        lambda xx, yy: xx if xx + yy < 30 else xx - 1 if xx + yy < 100 else xx - 100,
        lambda xx, yy: yy if xx + yy < 30 else yy - 2 if xx + yy < 100 else yy - 200,
    )

    # multiway select with Tuple condition
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select((x < 30, y < 30),   (x, y),
                        (x < 100, y < 100), (x-1, y-2),
                                            (x-100, y-200))
    # fmt: on
    _check_tuple_func(
        f,
        lambda xx, yy: xx if xx < 30 else xx - 1 if xx < 100 else xx - 100,
        lambda xx, yy: yy if yy < 30 else yy - 2 if yy < 100 else yy - 200,
    )

    # Failure case: mixing Expr and Tuple in multiway
    def mixed_condition_kinds():
        g = hl.Func("f")
        # fmt: off
        g[x, y] = hl.select((x < 30, y < 30), (x, y),
                             x + y < 100,     (x-1, y-2),
                                              (x-100, y-200))
        # fmt: on

    _expect_halide_error(
        mixed_condition_kinds,
        "select() on Tuples may not mix Expr and Tuple for the condition elements.",
    )

    # Failure case: Tuples of mixed sizes
    def mismatched_tuple_sizes():
        g = hl.Func("f")
        g[x, y] = hl.select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4))

    _expect_halide_error(
        mismatched_tuple_sizes,
        "select() on Tuples requires all Tuples to have identical sizes.",
    )


def _main():
    """Run the tuple-select checks when this file is executed directly."""
    test_tuple_select()


if __name__ == "__main__":
    _main()