File: replay_record.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (113 lines) | stat: -rw-r--r-- 3,646 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import dataclasses
from dataclasses import field
from types import CellType, CodeType, ModuleType
from typing import Any, BinaryIO, Dict, IO, Tuple
from typing_extensions import Self

from torch.utils._import_utils import import_dill


# Handle to the optional ``dill`` dependency. The ``assert dill is not None``
# checks in ExecutionRecord.dump/load imply this can be None when dill is not
# installed (presumably what import_dill returns in that case -- confirm
# against torch.utils._import_utils).
dill = import_dill()


@dataclasses.dataclass
class ModuleRecord:
    """Pairs a module object with the attributes recorded as accessed on it."""

    # The module this record describes.
    module: ModuleType
    # Attribute name -> recorded value; every instance starts with its own
    # empty dict.
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class DummyModule:
    """Plain-data stand-in for a module, identified by ``name``.

    Instances answer ``__name__`` the way a real module would; the value
    always mirrors the ``name`` field.
    """

    name: str
    is_torch: bool = dataclasses.field(default=False)

    @property
    def __name__(self) -> str:
        """Return the ``name`` field."""
        return self.name


@dataclasses.dataclass
class ExecutionRecord:
    code: CodeType
    closure: Tuple[CellType]
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)

    def dump(self, f: IO[str]) -> None:
        assert dill is not None, "replay_record requires `pip install dill`"
        dill.dump(self, f)

    @classmethod
    def load(cls, f: BinaryIO) -> Self:
        assert dill is not None, "replay_record requires `pip install dill`"
        return dill.load(f)


@dataclasses.dataclass
class ExecutionRecorder:
    """Accumulates recorded state and turns it into an ``ExecutionRecord``.

    Module objects added as variables are not stored directly: each is
    wrapped in a ``ModuleRecord`` (one per module ``__name__``, tracked in
    ``name_to_modrec``) that collects the attributes recorded through
    :meth:`record_module_access`. :meth:`get_record` then replaces every
    ``ModuleRecord`` with a ``DummyModule`` carrying only those attributes.
    """

    # Not referenced inside this class.
    # NOTE(review): presumably a name prefix callers use for local modules
    # -- confirm against the call sites.
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    # Variable-length tuple of cells (``Tuple[CellType]`` would describe a
    # tuple holding exactly one cell).
    closure: Tuple[CellType, ...]
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    # Module ``__name__`` -> the single ModuleRecord tracking that module.
    name_to_modrec: Dict[str, ModuleRecord] = field(default_factory=dict)

    def add_local_var(self, name: str, var: Any) -> None:
        """Record ``var`` under ``name`` in ``locals``.

        Modules are stored as their ``ModuleRecord``; any other value is
        stored as-is.
        """
        self.locals[name] = (
            self._add_mod(var) if isinstance(var, ModuleType) else var
        )

    def add_global_var(self, name: str, var: Any) -> None:
        """Record ``var`` under ``name`` in ``globals``.

        Modules are stored as their ``ModuleRecord``; any other value is
        stored as-is.
        """
        self.globals[name] = (
            self._add_mod(var) if isinstance(var, ModuleType) else var
        )

    def add_local_mod(self, name: str, mod: ModuleType) -> None:
        """Record the module ``mod`` under ``name``.

        Note that, despite the method name, the entry is placed in
        ``globals`` (it delegates to :meth:`add_global_var`).

        Raises:
            AssertionError: If ``mod`` is not a module.
        """
        assert isinstance(mod, ModuleType)
        self.add_global_var(name, mod)

    def record_module_access(self, mod: ModuleType, name: str, val: Any) -> None:
        """Record that attribute ``name`` of ``mod`` evaluated to ``val``.

        The access is stored on the ``ModuleRecord`` of ``mod``. If ``mod``
        has not been registered with this recorder the call is a no-op,
        regardless of the type of ``val`` (previously a module-typed
        ``val`` raised ``KeyError`` in that situation). When ``val`` is
        itself a module it is registered and stored as its
        ``ModuleRecord`` so that its own accesses can be tracked.
        """
        modrec = self.name_to_modrec.get(mod.__name__)
        if modrec is None:
            # Parent module is unknown to this recorder; there is no
            # record to attach the access to.
            return

        modrec.accessed_attrs[name] = (
            self._add_mod(val) if isinstance(val, ModuleType) else val
        )

    def get_record(self) -> ExecutionRecord:
        """Build an ``ExecutionRecord`` from the state collected so far.

        ``globals`` and ``locals`` are rebuilt with every ``ModuleRecord``
        replaced by a ``DummyModule``; ``builtins`` and ``code_options``
        are shallow-copied; ``code`` and ``closure`` are passed through.
        """
        return ExecutionRecord(
            self.code,
            self.closure,
            ExecutionRecorder._resolve_modules(self.globals),
            ExecutionRecorder._resolve_modules(self.locals),
            self.builtins.copy(),
            self.code_options.copy(),
        )

    def _add_mod(self, mod: ModuleType) -> ModuleRecord:
        """Return the ``ModuleRecord`` for ``mod``, creating it on first use.

        Records are keyed by ``mod.__name__``.
        """
        if mod.__name__ not in self.name_to_modrec:
            self.name_to_modrec[mod.__name__] = ModuleRecord(mod)

        return self.name_to_modrec[mod.__name__]

    @classmethod
    def _resolve_modules(cls, vars: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``vars`` with ``ModuleRecord`` values resolved.

        Each ``ModuleRecord`` becomes a ``DummyModule`` named after the
        recorded module, with one attribute set per recorded access;
        nested ``ModuleRecord`` values are resolved recursively. Values
        that are not ``ModuleRecord`` instances are returned unchanged.

        Within one call every ``ModuleRecord`` is resolved exactly once, so
        records that reference each other in a cycle terminate instead of
        recursing without bound, and repeated references to one record
        share a single ``DummyModule``.
        """
        # Keyed by id() because ModuleRecord defines __eq__ without
        # __hash__ and is therefore unhashable. The records stay alive in
        # ``vars`` for the duration of the call, so ids are stable.
        resolved: Dict[int, DummyModule] = {}

        def resolve_module(var: Any) -> Any:
            if not isinstance(var, ModuleRecord):
                return var

            key = id(var)
            if key in resolved:
                return resolved[key]

            dummy_mod = DummyModule(var.module.__name__)
            # Publish before descending so a cyclic reference back to this
            # record finds the (still being populated) dummy.
            resolved[key] = dummy_mod
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))

            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}