1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
import dataclasses
from dataclasses import field
from types import CellType, CodeType, ModuleType
from typing import Any, BinaryIO, Dict, IO, Tuple
from typing_extensions import Self
from torch.utils._import_utils import import_dill
dill = import_dill()
@dataclasses.dataclass
class ModuleRecord:
    """Pair a live module object with the attribute values recorded for it."""

    # The actual module being tracked.
    module: ModuleType
    # Attribute name -> recorded value. Built by a factory so that every
    # instance owns a separate dict.
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class DummyModule:
    """Plain-data placeholder for a module, identified only by its name."""

    # Name reported through ``__name__``.
    name: str
    # Defaults to False; nothing in this class reads it.
    is_torch: bool = dataclasses.field(default=False)

    @property
    def __name__(self) -> str:
        """Expose ``name`` under the dunder attribute real modules provide."""
        module_name = self.name
        return module_name
@dataclasses.dataclass
class ExecutionRecord:
code: CodeType
closure: Tuple[CellType]
globals: Dict[str, Any] = field(default_factory=dict)
locals: Dict[str, Any] = field(default_factory=dict)
builtins: Dict[str, Any] = field(default_factory=dict)
code_options: Dict[str, Any] = field(default_factory=dict)
def dump(self, f: IO[str]) -> None:
assert dill is not None, "replay_record requires `pip install dill`"
dill.dump(self, f)
@classmethod
def load(cls, f: BinaryIO) -> Self:
assert dill is not None, "replay_record requires `pip install dill`"
return dill.load(f)
@dataclasses.dataclass
class ExecutionRecorder:
    """Collect the variables and module attribute reads tied to a code object.

    Module objects stored through this class are wrapped in ``ModuleRecord``
    entries (one per module ``__name__``) so attribute reads can be logged
    against them. ``get_record`` converts those entries into ``DummyModule``
    stand-ins and returns an ``ExecutionRecord``.
    """

    # Not referenced inside this class; presumably the name prefix callers use
    # together with ``add_local_mod`` -- confirm against call sites.
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    # A closure holds any number of cells, hence the variadic tuple type.
    closure: Tuple[CellType, ...]
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    # Module ``__name__`` -> its ModuleRecord.
    name_to_modrec: Dict[str, ModuleRecord] = field(default_factory=dict)

    def add_local_var(self, name: str, var: Any) -> None:
        """Store ``var`` in ``locals``; modules are stored as their ModuleRecord."""
        self.locals[name] = (
            self._add_mod(var) if isinstance(var, ModuleType) else var
        )

    def add_global_var(self, name: str, var: Any) -> None:
        """Store ``var`` in ``globals``; modules are stored as their ModuleRecord."""
        self.globals[name] = (
            self._add_mod(var) if isinstance(var, ModuleType) else var
        )

    def add_local_mod(self, name: str, mod: ModuleType) -> None:
        """Register module ``mod`` under ``name``.

        Despite the method name, the entry is written to ``globals`` (this
        simply delegates to ``add_global_var``).
        """
        assert isinstance(mod, ModuleType)
        self.add_global_var(name, mod)

    def record_module_access(self, mod: ModuleType, name: str, val: Any) -> None:
        """Log that attribute ``name`` of ``mod`` had the value ``val``.

        Accesses on a module that was never registered are ignored, whatever
        the type of ``val``. Previously a module-valued ``val`` on an
        unregistered ``mod`` raised ``KeyError`` while any other value was
        silently dropped.

        Args:
            mod: Module the attribute was read from.
            name: Attribute name.
            val: Attribute value; a module value is registered and stored as
                its ModuleRecord so accesses on it can be logged in turn.
        """
        modrec = self.name_to_modrec.get(mod.__name__)
        if modrec is None:
            return
        modrec.accessed_attrs[name] = (
            self._add_mod(val) if isinstance(val, ModuleType) else val
        )

    def get_record(self) -> ExecutionRecord:
        """Build an ExecutionRecord from the state collected so far.

        ``globals`` and ``locals`` are rebuilt with every ModuleRecord replaced
        by a DummyModule; ``builtins`` and ``code_options`` are shallow-copied.
        """
        return ExecutionRecord(
            self.code,
            self.closure,
            ExecutionRecorder._resolve_modules(self.globals),
            ExecutionRecorder._resolve_modules(self.locals),
            self.builtins.copy(),
            self.code_options.copy(),
        )

    def _add_mod(self, mod: ModuleType) -> ModuleRecord:
        """Return the ModuleRecord for ``mod``, creating it on first use."""
        mod_name = mod.__name__
        if mod_name not in self.name_to_modrec:
            self.name_to_modrec[mod_name] = ModuleRecord(mod)
        return self.name_to_modrec[mod_name]

    @classmethod
    def _resolve_modules(cls, vars: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``vars`` with ModuleRecords turned into DummyModules.

        Each ModuleRecord is converted at most once per call, so records that
        reference each other (``a.b`` and ``b.a``) no longer recurse without
        bound, and a record reachable through several paths maps to a single
        shared DummyModule.
        """
        # id(ModuleRecord) -> DummyModule already built for it. The records
        # stay alive for the whole call, so their ids are stable keys.
        resolved: Dict[int, DummyModule] = {}

        def resolve_module(var: Any) -> Any:
            if not isinstance(var, ModuleRecord):
                return var
            cached = resolved.get(id(var))
            if cached is not None:
                return cached
            dummy_mod = DummyModule(var.module.__name__)
            # Register before descending so a cycle finds this instance.
            resolved[id(var)] = dummy_mod
            # NOTE(review): an attribute literally named ``__name__`` cannot be
            # assigned here because DummyModule exposes it as a read-only
            # property -- confirm whether callers ever record it.
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))
            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}
|