File: multi_method_module_test.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (49 lines) | stat: -rw-r--r-- 1,390 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np

from multi_method_module import simplecpp, user_context


def test_simplecpp():
    """Check that ``simplecpp`` writes ``input + float_arg`` to every output element.

    The input is a 2x2 ``uint8`` buffer filled with 123 and the float argument
    is 3.5, so every element of the 2x2 ``float32`` output must equal 126.5.
    """
    # np.full is the documented way to build a pre-filled array; the raw
    # np.ndarray(...) constructor is a low-level API.
    buffer_input = np.full((2, 2), 123, dtype=np.uint8)

    float_arg = 3.5

    # Pre-fill with NaN rather than leaving the memory uninitialized: NaN never
    # compares equal to anything, so any element simplecpp fails to write makes
    # the test fail instead of passing by accident on stale memory.
    simple_output = np.full((2, 2), np.nan, dtype=np.float32)

    simplecpp(buffer_input, float_arg, simple_output)

    expected = float_arg + 123
    # Check each element individually so a failure reports the offending index.
    for index in np.ndindex(simple_output.shape):
        assert simple_output[index] == expected, (index, simple_output[index])


def test_user_context():
    """Check that ``user_context`` fills a 4-byte buffer with the given byte.

    ``None`` is passed as the first argument (presumably the user-context
    value -- confirm against the generator), and ``ord("q")`` as the value.
    Afterwards every byte of the buffer must be ``q``.
    """
    fill_byte = ord("q")
    buf = bytearray("\0\0\0\0", "ascii")

    user_context(None, fill_byte, buf)

    expected = bytearray("qqqq", "ascii")
    assert buf == expected


def test_aot_call_failure_throws_exception():
    """Check that a failing AOT pipeline call surfaces as ``RuntimeError``.

    ``simplecpp`` is given a ``float32`` input buffer (``test_simplecpp`` uses
    ``uint8``). The call must raise ``RuntimeError`` whose message contains
    ``"Halide Runtime Error: -3"``. Any other exception type, or no exception
    at all, fails the test.
    """
    buffer_input = np.zeros([2, 2], dtype=np.float32)  # wrong type
    float_arg = 3.5
    simple_output = np.zeros([2, 2], dtype=np.float32)

    try:
        simplecpp(buffer_input, float_arg, simple_output)
    except RuntimeError as e:
        assert "Halide Runtime Error: -3" in str(e), str(e)
    except Exception as e:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let this branch pass silently.
        # Chaining keeps the unexpected exception's type and traceback visible.
        raise AssertionError("Did not see expected exception, saw: " + str(e)) from e
    else:
        # Same reason as above: this must fail even when asserts are disabled.
        raise AssertionError("Did not see ANY exception, but one was expected!")


if __name__ == "__main__":
    # Run each test function of this module in order when executed as a script.
    for _test in (
        test_simplecpp,
        test_user_context,
        test_aot_call_failure_throws_exception,
    ):
        _test()