File: test_step_closures.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (91 lines) | stat: -rw-r--r-- 2,318 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Owner(s): ["oncall: jit"]

from threading import Event
from time import sleep

import torch._lazy
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase

# Initialize the lazy-tensor `ts_backend` once, at import time, before any test
# below calls torch._lazy.add_step_closure() / torch._lazy.mark_step().
# NOTE(review): presumably mark_step needs an initialized backend to run
# against — confirm.
torch._lazy.ts_backend.init()


class ClosuresTest(TestCase):
    """Tests for ``torch._lazy.add_step_closure`` in sync and async modes.

    Each test registers a closure with ``torch._lazy.add_step_closure`` (with
    or without ``run_async=True``), calls ``torch._lazy.mark_step()``, and uses
    a ``threading.Event`` to observe when -- and whether -- the closure ran
    relative to the test's own code.
    """

    def test_synchronous(self):
        """A synchronous closure finishes before ``mark_step`` returns."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure():
            # Sleep so that, if mark_step did not wait for the closure, the
            # test thread would reach its own check first.
            sleep(1)
            # Runs inside mark_step(); errors raised here propagate out of it
            # (the behavior test_synchronous_exception checks).
            self.assertFalse(flag.is_set())
            flag.set()

        torch._lazy.add_step_closure(closure)
        torch._lazy.mark_step()

        # should not get to this part before closure is finished running
        self.assertTrue(flag.is_set())

    def test_asynchronous(self):
        """An asynchronous closure is still running after ``mark_step`` returns."""
        flag = Event()
        # Set by the closure when it is done, so the test can wait for it.
        closure_done = Event()
        # The closure records what it saw instead of asserting: it runs off the
        # test's control flow, so an assert inside it would not fail this test.
        flag_seen_by_closure = []
        self.assertFalse(flag.is_set())

        def closure():
            try:
                sleep(1)
                flag_seen_by_closure.append(flag.is_set())
            finally:
                closure_done.set()

        torch._lazy.add_step_closure(closure, run_async=True)
        torch._lazy.mark_step()

        # should get to this part and complete before closure is finished running
        self.assertFalse(flag.is_set())
        flag.set()

        # Wait for the closure so its observation is checked by this test.
        # Without the wait, a failure inside the closure would presumably only
        # surface later as a RuntimeError from another test's mark_step() (the
        # behavior test_asynchronous_exception relies on).
        self.assertTrue(
            closure_done.wait(timeout=10), "asynchronous closure did not finish"
        )
        self.assertEqual(flag_seen_by_closure, [True])

    def test_synchronous_exception(self):
        """An exception raised by a synchronous closure propagates out of ``mark_step``."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure():
            flag.set()
            raise RuntimeError("Simulating exception in closure")

        torch._lazy.add_step_closure(closure)
        with self.assertRaises(RuntimeError):
            torch._lazy.mark_step()

        self.assertTrue(flag.is_set(), "Should have caught exception from closure")

    def test_asynchronous_exception(self):
        """An async closure's exception is raised as RuntimeError on a later step."""
        flag = Event()
        self.assertFalse(flag.is_set())

        def closure1():
            flag.set()
            raise RuntimeError("Simulating exception in closure1")

        torch._lazy.add_step_closure(closure1, run_async=True)
        torch._lazy.mark_step()

        # Fail clearly if closure1 never ran, rather than later on a
        # misleading "RuntimeError not raised" failure.
        # NOTE(review): closure1 sets the flag *before* raising, so in principle
        # the step below could run before the exception has been recorded;
        # possible race -- confirm against the async closure handler.
        self.assertTrue(flag.wait(timeout=5), "closure1 did not run")

        def closure2():  # Should never execute
            flag.clear()

        # Should raise the exception from closure1. Both calls stay inside the
        # context manager because either one may be where it is reported.
        with self.assertRaises(RuntimeError):
            torch._lazy.add_step_closure(closure2, run_async=True)
            torch._lazy.mark_step()

        # closure2 has not cleared the flag.
        self.assertTrue(flag.is_set())


if __name__ == "__main__":
    # Run the tests in this module via the shared runner imported from
    # torch.testing._internal.common_utils.
    run_tests()