File: flatten_params_wrapper.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.

import contextlib
from typing import Any, Dict, Generator, List

import torch
import torch.nn as nn
from torch.distributed.utils import _replace_by_prefix

from .flat_param import FlatParamHandle, HandleConfig

FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"

__all__ = ["FlattenParamsWrapper"]


def _post_state_dict_hook(
    module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after state_dict() has executed and
    before the state_dict is returned to the user.
    This API post-processes the keys of the state_dict to remove the
    FlattenParamsWrapper internal prefix.
    """
    # Move everything from FPW_MODULE up one level.
    _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
    return state_dict
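
# A minimal sketch of the key cleaning performed by the hook above
# (illustrative only; the prefix "layer1." and parameter name "weight" are
# invented for this example): the internal FPW_MODULE segment is stripped so
# that users see the original parameter names.
#
#   sd = {"layer1._fpw_module.weight": torch.zeros(1)}
#   _post_state_dict_hook(nn.Identity(), sd, "layer1.")
#   assert list(sd.keys()) == ["layer1.weight"]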


def _pre_load_state_dict_hook(
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
    executed. This API pre-processes the keys of the state_dict to add the
    FlattenParamsWrapper internal prefix.
    """
    # Push everything down to FPW_MODULE level.
    _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
    # The flat_param_* keys actually need to move one level up.
    flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
    for k in list(state_dict.keys()):
        if k.startswith(flat_param_key):
            last_part = k.split(".")[-1]
            assert last_part.startswith(
                FLAT_PARAM
            ), f"Expected key to contain flat_param, but key name is {k}"
            _replace_by_prefix(state_dict, k, prefix + last_part)
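
# A contrived round trip for the hook above (illustrative only; the prefix
# "layer1." and the key names are invented): ordinary keys are pushed under
# the FPW_MODULE prefix, while keys naming the flattened parameter are hoisted
# back up to the wrapper level.
#
#   sd = {"layer1.weight": torch.zeros(1), "layer1.flat_param": torch.zeros(2)}
#   _pre_load_state_dict_hook(sd, "layer1.")
#   assert sorted(sd.keys()) == ["layer1._fpw_module.weight", "layer1.flat_param"]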


class FlattenParamsWrapper(nn.Module):
    """
    This is a wrapper for flattening parameters in a ``nn.Module`` 's subtree
    into a single flattened parameter and is based on [1]. This is used for
    :class:`FullyShardedDataParallel` 's recursive wrapping.
    [1] https://github.com/SsnL/PyTorch-Reparam-Module

    Args:
        module (nn.Module): Module to wrap.
        params (List[nn.Parameter]): Parameters in ``module`` 's subtree to
            flatten into a single flattened parameter.
        device (torch.device): The compute and communication device for this
            wrapper's handle.
        config (HandleConfig): A config customizing this wrapper's handle based
            on FSDP's available features.

    Attributes:
        flat_param (Optional[FlatParameter]): The flattened parameter.
            ``flat_param`` is ``None`` either when (1) this wrapper manages no
            parameters or (2) the wrapped module's parameters are unflattened.
        _fpw_module (nn.Module): The wrapped module.
        _flat_param_handle (FlatParamHandle): A handle for the flattened
            parameter; only present if this wrapper manages parameters.
    """

    def __init__(
        self,
        module: nn.Module,
        params: List[nn.Parameter],
        device: torch.device,
        config: HandleConfig,
    ) -> None:
        super().__init__()
        self._fpw_module = module
        self.flat_param = None
        # Register hooks to clean parameter names for state dict (even if this
        # wrapper itself manages no parameters since it must clean names from
        # submodules)
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
        if len(params) == 0:
            return
        self._flat_param_handle = FlatParamHandle(params, module, device, config)
        # Defining `self.flat_param` registers the `FlatParameter` and makes it
        # visible to `named_parameters()`
        self.flat_param = self._flat_param_handle.flat_param
        assert getattr(self, FPW_MODULE) is self._fpw_module
        assert getattr(self, FLAT_PARAM) is self.flat_param

    @property
    def has_params(self) -> bool:
        """Returns whether this wrapper manages any parameters."""
        return hasattr(self, "_flat_param_handle")

    @property
    def handle(self) -> FlatParamHandle:
        assert hasattr(self, "_flat_param_handle"), (
            "Accessing the handle of a `FlattenParamsWrapper` that does not "
            "manage any parameters"
        )
        return self._flat_param_handle

    @property
    def module(self) -> Any:
        """Returns the wrapped module (like DDP)."""
        return self._fpw_module

    @contextlib.contextmanager
    def unflatten_as_params(self) -> Generator:
        """
        Assumes that the flattened parameter is unsharded. When in the context,
        unflattens the original parameters as ``nn.Parameter`` views into the
        flattened parameter and de-registers the flattened parameter. After the
        context, restores the original parameters as ``Tensor`` views into the
        flattened parameter and re-registers the flattened parameter.
        """
        if getattr(self, "flat_param", None) is None:
            yield
        else:
            # De-register the `FlatParameter` from this wrapper to hide it from
            # `named_parameters()` (though it still exists in memory)
            del self.flat_param
            try:
                with self._flat_param_handle.unflatten_as_params():
                    yield
            finally:
                # Re-register the `FlatParameter`
                self.flat_param = self._flat_param_handle.flat_param
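
    # Illustrative effect of the context manager above on `named_parameters()`
    # for a hypothetical wrapper `fpw` around an `nn.Linear` (assumes the
    # flattened parameter is unsharded; not part of the original module):
    #
    #   with fpw.unflatten_as_params():
    #       [n for n, _ in fpw.named_parameters()]
    #       # -> ["_fpw_module.weight", "_fpw_module.bias"]
    #   [n for n, _ in fpw.named_parameters()]  # -> ["flat_param"]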

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes of this wrapper to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to `nn.Module`'s logic
        except AttributeError:
            return getattr(self.module, name)  # fall back to the wrapped module

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls to the wrapped module in case the wrapped
        module is an ``nn.Sequential``."""
        return self.module.__getitem__(key)
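
    # Attribute and index forwarding in practice (illustrative only; `device`
    # and `config` are assumed to be available):
    #
    #   seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    #   seq.some_flag = True
    #   fpw = FlattenParamsWrapper(seq, list(seq.parameters()), device, config)
    #   fpw[0] is seq[0]    # __getitem__ forwards indexing to the wrapped module
    #   fpw.some_flag       # __getattr__ falls back to the wrapped module -> True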

    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
        if self.flat_param is not None:
            self._flat_param_handle._unflatten(as_params=False)
        return self.module(*inputs, **kwinputs)
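
# Minimal end-to-end usage sketch (illustrative only; constructing a
# `HandleConfig` is version-specific, so a pre-built `config` is assumed):
#
#   linear = nn.Linear(4, 4)
#   fpw = FlattenParamsWrapper(
#       linear, list(linear.parameters()), torch.device("cpu"), config
#   )
#   assert fpw.has_params
#   assert fpw.flat_param is fpw.handle.flat_param
#   with fpw.unflatten_as_params():
#       pass  # original parameters are exposed as nn.Parameter views here
#   out = fpw(torch.randn(2, 4))  # forward() unflattens views, then calls linear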