File: main.py

package info (click to toggle)
accelerate 1.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,900 kB
  • sloc: python: 40,061; sh: 90; makefile: 79
file content (122 lines) | stat: -rw-r--r-- 3,721 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

import torch

from accelerate import Accelerator
from utils import parse_args, prepare_accelerate, prepare_torch


# Hugging Face Hub identifier of the model fine-tuned in every benchmark run.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# Learning rate shared by all runs so their loss curves are directly comparable.
LEARNING_RATE = 3e-5

# Run configuration passed to each preparation function (see ``evaluate``).
CONFIG = {
    "model_name": MODEL_NAME,
    "learning_rate": LEARNING_RATE,
}


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_dataloader: torch.utils.data.DataLoader,
    accelerator: Accelerator,
) -> torch.Tensor:
    """Run a single pass over ``train_dataloader`` and collect per-step losses.

    Each step performs the standard cycle: zero gradients, forward pass,
    backward pass (routed through ``accelerator`` so gradient scaling/sync is
    handled), then an optimizer step.

    Returns:
        A 1-D CPU tensor with the scalar loss recorded at every step.
    """
    step_losses: list[float] = []
    for batch in train_dataloader:
        optimizer.zero_grad()
        # ``use_cache=False``: the KV cache only helps generation and would
        # waste memory during training.
        outputs = model(**batch, use_cache=False)

        step_loss = outputs.loss
        step_losses.append(step_loss.item())
        accelerator.backward(step_loss)
        optimizer.step()

    return torch.tensor(step_losses)


def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> torch.Tensor:
    """Build a training setup via ``init_fn``, train once, and print a report.

    Args:
        args: Parsed command-line arguments, forwarded to ``init_fn``.
        config: Run configuration (model name, learning rate, ...).
        init_fn: Factory returning ``(model, optimizer, dataloader,
            accelerator, memory_tracker)``.
        run_name: Human-readable label used in the printed report.

    Returns:
        The per-step loss tensor produced by ``train``.
    """
    model, optimizer, dataloader, accelerator, memory_tracker = init_fn(args, config)

    losses = train(model, optimizer, dataloader, accelerator)

    memory_tracker.stop()
    report = [
        f"Results for {run_name} (rank 0):",
        f"Loss: {losses[-1].item()}",
        f"Peak Allocated Memory: {float(memory_tracker.peak_allocated_memory):.2f} MB",
        f"Peak Reserved Memory: {float(memory_tracker.peak_reserved_memory):.2f} MB",
        "-" * 34,
    ]
    # ``accelerator.print`` emits only on the main process.
    accelerator.print("\n".join(report))
    return losses


def main():
    """Benchmark several FSDP optimizer-preparation strategies and assert
    that they all yield identical per-step losses.

    Compares creating the optimizer before sharding (with and without the
    optimizer fix), after sharding, and via Accelerate's own preparation.
    """
    args = parse_args()

    # Single source of truth per run: label -> initialization function.
    # (The original kept a separate ``labels`` list that duplicated every
    # run name; a typo in either copy would silently break the result
    # lookups below.)
    runs = {
        "Optimizer Before FSDP (w/ fix)": functools.partial(
            prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=True
        ),
        "Optimizer Before FSDP (w/o fix)": functools.partial(
            prepare_torch, post_shard_optimizer=False, apply_optimizer_fix=False
        ),
        "Optimizer After FSDP": functools.partial(prepare_torch, post_shard_optimizer=True),
        "Accelerate": prepare_accelerate,
    }

    # Determinism is required for the exact loss comparisons below.
    torch.use_deterministic_algorithms(True)

    results = {}
    for label, init_fn in runs.items():
        results[label] = evaluate(args, CONFIG, init_fn=init_fn, run_name=label)

    torch.testing.assert_close(
        results["Optimizer After FSDP"],
        results["Optimizer Before FSDP (w/ fix)"],
        msg="Optimizer After FSDP and Optimizer Before FSDP (w/ fix) should be the same",
    )

    torch.testing.assert_close(
        results["Optimizer After FSDP"],
        results["Accelerate"],
        msg="Optimizer After FSDP and Accelerate should be the same",
    )

    torch.testing.assert_close(
        results["Accelerate"],
        results["Optimizer Before FSDP (w/ fix)"],
        msg="Accelerate and Optimizer Before FSDP (w/ fix) should be the same",
    )

    # destroy_process_group raises if no group was ever initialized (e.g. a
    # single-process debug run), so only tear down when one exists.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


# Script entry point: run the full benchmark suite when executed directly.
if __name__ == "__main__":
    main()