File: _composable_state.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (44 lines) | stat: -rw-r--r-- 1,394 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import weakref
from typing import cast, Optional

import torch.nn as nn


class _State:
    """
    Marker base type for state objects tracked by this file.

    The class defines no attributes or methods of its own. ``_get_module_state``
    returns a module unchanged when the module is itself a ``_State`` instance,
    and ``_insert_module_state`` stores only a weak reference to the state it
    is given.
    """


# Module-level registry associating a module with the ``_State`` inserted for
# it. Keys are held weakly (``WeakKeyDictionary``) and each value is itself a
# ``weakref.ref`` to the state, so the registry keeps neither side alive.
_module_state_mapping: weakref.WeakKeyDictionary[
    nn.Module, weakref.ReferenceType[_State]
] = weakref.WeakKeyDictionary()


def _insert_module_state(module: nn.Module, state: _State) -> None:
    """
    Register ``state`` as the ``_State`` associated with ``module``.

    Only a weak reference to ``state`` is stored in ``_module_state_mapping``,
    and ``module`` is used as a weak key, so this function keeps neither
    object alive.

    Args:
        module: The module to associate the state with.
        state: The state object to register for ``module``.

    Raises:
        AssertionError: If ``module`` already has an entry in the mapping.
    """
    # Raise explicitly instead of using ``assert`` so the duplicate check is
    # not stripped when Python runs with ``-O``; the exception type and message
    # are unchanged.
    if module in _module_state_mapping:
        raise AssertionError(f"Inserting {module} more than once.")
    _module_state_mapping[module] = weakref.ref(state)


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Look up the ``_State`` associated with ``module``.

    Two situations yield a state:

    * ``module`` is itself a ``_State`` instance, in which case ``module`` is
      returned (cast to ``_State`` for the type checker).
    * ``module`` has an entry in ``_module_state_mapping``, in which case the
      weakly referenced ``_State`` stored there is returned.

    Returns:
        The state for ``module``, or ``None`` when ``module`` is not a
        ``_State`` and has no entry in the mapping.

    Raises:
        AssertionError: If the mapping has an entry for ``module`` but the
            referenced ``_State`` has already been garbage collected.
    """
    if isinstance(module, _State):
        return cast(_State, module)

    # https://github.com/pytorch/pytorch/issues/107054
    # NOTE(review): the membership test followed by indexing (instead of a
    # single ``.get()``) appears to be tied to the issue above -- confirm
    # before collapsing the two lookups.
    if module not in _module_state_mapping:
        return None

    resolved = _module_state_mapping[module]()
    if resolved is None:
        raise AssertionError("State has already been garbage collected")
    return resolved