1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
import halide as hl
def _check_realized_tuple(f, expected_a, expected_b, size=200):
    """Realize the two-element Tuple Func ``f`` and compare it pointwise.

    Args:
        f: A Func over ``(x, y)`` whose definition is a two-element Tuple.
        expected_a: Reference function ``(xx, yy) -> int`` for element 0.
        expected_b: Reference function ``(xx, yy) -> int`` for element 1.
        size: Extent used for both dimensions of the realization.

    Raises:
        AssertionError: At the first coordinate where an output differs from
            its reference value; the message carries that coordinate.
    """
    a, b = f.realize([size, size])
    # Buffers are indexed [x, y]: x runs over dimension 0 (width()) and y over
    # dimension 1 (height()). The original loops had these swapped, which only
    # went unnoticed because the realization is square.
    for xx in range(a.width()):
        for yy in range(a.height()):
            assert a[xx, yy] == expected_a(xx, yy), (xx, yy)
            assert b[xx, yy] == expected_b(xx, yy), (xx, yy)


def test_tuple_select():
    """Check ``hl.select`` when the selected values are Tuples.

    Covers the ternary and multiway forms, each with a single Expr condition
    and with a Tuple of per-element conditions. Also covers two error cases:
    mixing Expr and Tuple conditions, and passing Tuples of different sizes.
    """
    x = hl.Var("x")
    y = hl.Var("y")

    # ternary select with Expr condition
    f = hl.Func("f")
    f[x, y] = hl.select(x + y < 30, (x, y), (x - 1, y - 2))
    _check_realized_tuple(
        f,
        lambda xx, yy: xx if xx + yy < 30 else xx - 1,
        lambda xx, yy: yy if xx + yy < 30 else yy - 2,
    )

    # ternary select with Tuple condition: each element uses its own condition
    f = hl.Func("f")
    f[x, y] = hl.select((x < 30, y < 30), (x, y), (x - 1, y - 2))
    _check_realized_tuple(
        f,
        lambda xx, yy: xx if xx < 30 else xx - 1,
        lambda xx, yy: yy if yy < 30 else yy - 2,
    )

    # multiway select with Expr condition
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select(x + y < 30, (x, y),
                        x + y < 100, (x-1, y-2),
                        (x-100, y-200))
    # fmt: on
    _check_realized_tuple(
        f,
        lambda xx, yy: xx if xx + yy < 30 else xx - 1 if xx + yy < 100 else xx - 100,
        lambda xx, yy: yy if xx + yy < 30 else yy - 2 if xx + yy < 100 else yy - 200,
    )

    # multiway select with Tuple condition
    f = hl.Func("f")
    # fmt: off
    f[x, y] = hl.select((x < 30, y < 30), (x, y),
                        (x < 100, y < 100), (x-1, y-2),
                        (x-100, y-200))
    # fmt: on
    _check_realized_tuple(
        f,
        lambda xx, yy: xx if xx < 30 else xx - 1 if xx < 100 else xx - 100,
        lambda xx, yy: yy if yy < 30 else yy - 2 if yy < 100 else yy - 200,
    )

    # Failure case: mixing Expr and Tuple in multiway
    f = hl.Func("f")
    try:
        # fmt: off
        f[x, y] = hl.select((x < 30, y < 30), (x, y),
                            x + y < 100, (x-1, y-2),
                            (x-100, y-200))
        # fmt: on
    except hl.HalideError as e:
        assert (
            "select() on Tuples may not mix Expr and Tuple for the condition elements."
            in str(e)
        ), str(e)
    else:
        # An explicit raise, unlike `assert False`, is not stripped by `python -O`.
        raise AssertionError("Did not see expected exception!")

    # Failure case: Tuples of mixed sizes
    f = hl.Func("f")
    try:
        f[x, y] = hl.select((x < 30, y < 30), (x, y, 0), (1, 2, 3, 4))
    except hl.HalideError as e:
        assert (
            "select() on Tuples requires all Tuples to have identical sizes."
            in str(e)
        ), str(e)
    else:
        raise AssertionError("Did not see expected exception!")
if __name__ == "__main__":
    # Executing this module as a script runs the test once, outside pytest.
    test_tuple_select()
|