File: test_generator.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (104 lines) | stat: -rw-r--r-- 3,145 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Owner(s): ["oncall: jit"]

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


# Initialize the TorchScript lazy-tensor backend once, at import time, before
# any test below places tensors on the "lazy" device (presumably required for
# that device to be usable -- confirm against torch._lazy.ts_backend).
torch._lazy.ts_backend.init()


class LazyGeneratorTest(TestCase):
    def test_generator(self):
        """
        Test that generators are being inserted into the TorchScript
        graph by setting different seeds before each call to
        generate_tensor but the resulting tensor is the same
        """

        def generate_tensor():
            g1 = torch.Generator()
            g1.manual_seed(2023)
            t1 = torch.tensor(1.0)
            t1.uniform_(generator=g1)

            g2 = torch.Generator()
            g2.manual_seed(2024)
            t2 = torch.tensor(1.0)
            t2.normal_(generator=g2)

            return t1, t2

        torch.manual_seed(1)

        with torch.device("cpu"):
            cpu_t1, cpu_t2 = generate_tensor()

        torch.manual_seed(2)

        with torch.device("lazy"):
            lazy_t1, lazy_t2 = generate_tensor()

        torch._lazy.mark_step()

        assert torch.allclose(
            cpu_t1, lazy_t1.to("cpu")
        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
        assert torch.allclose(
            cpu_t2, lazy_t2.to("cpu")
        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"

    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
    def test_generator_causes_multiple_compiles(self):
        """
        Test that inserting generators with different seed caused recompile
        """

        def generate_tensor(seed):
            t = torch.tensor(1.0)
            g = torch.Generator()
            g.manual_seed(seed)
            t.uniform_(-1, 1, generator=g)
            return t

        metrics.reset()

        with torch.device("lazy"):
            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 1
            ), f"Expected 1 uncached compiles, got {uncached_compile}"

            t = generate_tensor(2)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"

            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"
            cached_compile = metrics.counter_value("CachedCompile")
            assert (
                cached_compile == 1
            ), f"Expected 1 cached compile, got {cached_compile}"

        metrics.reset()

        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
        assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
        assert "aten::uniform" in latest_graph


if __name__ == "__main__":
    # When executed directly as a script, run this file's tests through the
    # run_tests helper imported from torch.testing._internal.common_utils.
    run_tests()