File: test_jit_disabled.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (91 lines) | stat: -rw-r--r-- 2,357 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Owner(s): ["oncall: jit"]

import sys
import os
import contextlib
import subprocess
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName


@contextlib.contextmanager
def _jit_disabled():
    """Temporarily set ``PYTORCH_JIT=0`` in this process's environment.

    Child processes started inside the ``with`` block via ``subprocess``
    (without an explicit ``env=``) inherit ``os.environ``, so they see the
    variable set to ``"0"``.

    On exit, including exit via an exception, the previous state is restored
    exactly. The old value is put back if the variable was set. If it was
    unset, it is removed again rather than left at some default value.
    """
    cur_env = os.environ.get("PYTORCH_JIT")
    os.environ["PYTORCH_JIT"] = "0"
    try:
        yield
    finally:
        if cur_env is None:
            # The variable did not exist before entry; don't leak a value
            # into the rest of this process or into later subprocesses.
            os.environ.pop("PYTORCH_JIT", None)
        else:
            os.environ["PYTORCH_JIT"] = cur_env


class TestJitDisabled(TestCase):
    """
    These tests are separate from the rest of the JIT tests because we need to
    run a new subprocess and `import torch` with the correct environment
    variables set.
    """

    def compare_enabled_disabled(self, src):
        """
        Runs the script in `src` with PYTORCH_JIT enabled and disabled and
        compares their stdout for equality.

        Args:
            src: Python source for a standalone script. It is executed twice
                in fresh interpreters (``sys.executable``). The first run is
                inside ``_jit_disabled()`` (``PYTORCH_JIT=0``); the second is
                outside it.

        Raises:
            subprocess.CalledProcessError: if either run exits non-zero.
            AssertionError: if the two runs print different stdout.
        """
        # Write `src` out to a temporary so our source inspection logic works
        # correctly.
        with TemporaryFileName() as fname:
            # Child interpreters decode source as UTF-8 by default, so write
            # it that way regardless of the locale's preferred encoding.
            with open(fname, 'w', encoding='utf-8') as f:
                f.write(src)
            # The file must be closed (and therefore flushed) *before* the
            # children read it. Spawning them while it is still open lets them
            # see an empty/partial file: both runs then print nothing and the
            # comparison below passes vacuously.
            with _jit_disabled():
                out_disabled = subprocess.check_output([sys.executable, fname])
            out_enabled = subprocess.check_output([sys.executable, fname])
        self.assertEqual(out_disabled, out_enabled)

    def test_attribute(self):
        """Printing a `torch.jit.Attribute`-initialized field of a ScriptModule
        must give the same output with the JIT enabled and disabled."""
        _program_string = """
import torch

class Foo(torch.jit.ScriptModule):
    def __init__(self, x):
        super().__init__()
        self.x = torch.jit.Attribute(x, torch.Tensor)

    def forward(self, input):
        return input

s = Foo(torch.ones(2, 3))
print(s.x)
"""
        self.compare_enabled_disabled(_program_string)

    def test_script_module_construction(self):
        """Constructing a ScriptModule with a `@torch.jit.script_method`
        forward must not raise in either mode."""
        _program_string = """
import torch

class AModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        pass

AModule()
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)

    def test_recursive_script(self):
        """Calling `torch.jit.script` on a plain `nn.Module` must not raise in
        either mode."""
        _program_string = """
import torch

class AModule(torch.nn.Module):
    def forward(self, input):
        pass

sm = torch.jit.script(AModule())
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)

if __name__ == "__main__":
    # Only when executed directly (not on import): hand control to the
    # `run_tests` entry point imported from common_utils.
    run_tests()