File: float_precision_test.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (60 lines) | stat: -rw-r--r-- 1,697 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import math
import warnings

import halide as hl
import numpy


class AssertWarnsContext:
    """Context manager that records whether a warning of a given category occurred.

    A lightweight stand-in for ``unittest.TestCase.assertWarns`` that does not
    assert by itself. After the ``with`` block exits, inspect ``occurred``
    (or the instance's truthiness) and assert on it yourself.

    Args:
        warn: Warning class (or tuple of classes) to look for. Subclasses also
            match, because each recorded message is checked with ``isinstance``.
    """

    def __init__(self, warn):
        self.expected = warn
        self.occurred = False

    def __bool__(self):
        return self.occurred

    def __enter__(self):
        self.warnings_manager = warnings.catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        # Record every warning regardless of the ambient filters. Otherwise:
        # - an outer "ignore" filter (e.g. -W ignore) would hide the warning;
        # - an "error" filter would raise it as an exception instead of
        #   recording it;
        # - "once"-style filters would drop a message already seen earlier.
        # The filter change is undone when catch_warnings exits.
        warnings.simplefilter("always")
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        # An exception raised in the with-body propagates unchanged
        # (returning None does not suppress it); occurred is left untouched.
        if exc_type is not None:
            return
        self.occurred = any(
            isinstance(record.message, self.expected) for record in self.warnings
        )


def test():
    """Check float64 precision handling in Halide expressions.

    Part 1: for each test constant ``c``, realize
    ``x * f64(c) * (f64(0.1) + f64(0.2))`` over 10 points. Compare every
    element with the same arithmetic done in Python, using ``math.isclose``.

    Part 2: check that adding the bare Python float 0.123456789012345678 to a
    ``Var`` emits a ``RuntimeWarning``. Wrapping that value in ``hl.f64`` or
    adding 0.75 must not warn.
    """

    def test_pattern(c):
        # A fresh Var/Func pair is created for every constant.
        var = hl.Var("x")
        func = hl.Func("f")
        func[var] = var * hl.f64(c) * (hl.f64(0.1) + hl.f64(0.2))
        realized = numpy.asarray(func.realize([10]))
        for i, hl_value in enumerate(realized):
            # Same operand order as the Halide expression.
            py_value = i * c * (0.1 + 0.2)
            assert math.isclose(hl_value, py_value), (
                f"{i}[{c}]: {hl_value} != {py_value}"
            )

    for c in (0.123456789012345678, 0.987654321098765432):
        test_pattern(c)

    x = hl.Var("x")
    # Each entry: (thunk building the expression, whether a RuntimeWarning is expected).
    warning_cases = (
        # Bare Python float; presumably warns because the literal cannot be
        # kept exactly at Halide's default float width -- confirm in bindings.
        (lambda: x + 0.123456789012345678, True),
        # Explicitly typed as float64: no warning expected.
        (lambda: x + hl.f64(0.123456789012345678), False),
        # 0.5 + 0.25 is exact in binary floating point: no warning expected.
        (lambda: x + 0.75, False),
    )
    for build_expr, expect_warning in warning_cases:
        with AssertWarnsContext(RuntimeWarning) as ctx:
            build_expr()
        if expect_warning:
            assert ctx.occurred, "RuntimeWarning didn't occur."
        else:
            assert not ctx.occurred, "RuntimeWarning occurred."


if __name__ == "__main__":
    # Run the checks directly when this file is executed as a script.
    test()