File: test_apply_optimizer_in_backward.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (162 lines) | stat: -rw-r--r-- 5,674 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)


# TODO (rohan-varma): Add FSDP & DDP tests once supported


def _validate_params(params_list, fn):
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):
    def _run_training_loop_and_validate(self, inp, models, optimizers):
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply different optimizer in backwards for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

    def test_no_register_hook(self):
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

    def test_get_optimizers_in_backward(self):
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)