File: utils.py

package info (click to toggle)
accelerate 1.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,900 kB
  • sloc: python: 40,061; sh: 90; makefile: 79
file content (219 lines) | stat: -rw-r--r-- 7,822 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common utilities for torch-native-parallelism examples.
"""

import time
from contextlib import nullcontext

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator


def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Accelerator | None = None) -> Dataset:
    """
    Build a shuffled, packed dataset from the first half of the TinyStories train split.

    Each story is tokenized (truncated to `seq_len` tokens), the token lists of a map batch are
    concatenated, and the resulting stream is cut into windows of `seq_len + 1` tokens. Every
    window yields the inputs (all tokens but the last) and the next-token targets (all tokens but
    the first). Tokens left over at the end of a map batch that do not fill a window are dropped.

    Args:
        tokenizer (AutoTokenizer): Hugging Face tokenizer
        seq_len (int): Length of each packed sequence
        accelerator (Accelerator, *optional*): When given, both `map` passes run inside
            `accelerator.main_process_first()`; otherwise a `nullcontext` is used.

    Returns:
        Dataset: Packed dataset with `input_ids`, `shift_labels`, `position_ids` and `labels`
        columns, shuffled with seed 42
    """
    if accelerator:
        processing_ctx = accelerator.main_process_first
    else:
        processing_ctx = nullcontext
    raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")

    def tokenize_function(examples):
        encoded = tokenizer(
            examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
        )
        encoded["labels"] = encoded["input_ids"].copy()
        return encoded

    with processing_ctx():
        tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    def create_packed_sequences(examples):
        # One extra token per window so that inputs and shifted targets both have `seq_len` entries
        window = seq_len + 1
        token_stream = []
        for ids in examples["input_ids"]:
            token_stream += ids

        inputs, targets, positions = [], [], []
        for window_idx in range(len(token_stream) // window):
            offset = window_idx * window
            chunk = token_stream[offset : offset + window]
            inputs.append(chunk[:-1])
            targets.append(chunk[1:])
            positions.append(torch.arange(0, seq_len))

        # `shift_labels` and `labels` intentionally carry the same already-shifted targets
        return {
            "input_ids": inputs,
            "shift_labels": targets,
            "position_ids": positions,
            "labels": targets,
        }

    with processing_ctx():
        packed_dataset = tokenized_dataset.map(
            create_packed_sequences,
            batched=True,
            remove_columns=tokenized_dataset.column_names,
            batch_size=1000,
        )

    shuffled = packed_dataset.shuffle(seed=42)
    return shuffled


def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
    """
    Estimate the number of flops per token for the transformer layers of the model.

    The estimate counts 6 flops per weight for every matmul (the usual forward + backward
    approximation) plus the attention score/value products, and ignores embeddings, layernorms
    and other small terms.

    Args:
        model (AutoModelForCausalLM): Model to get the flops for. Its `config` must expose
            `hidden_size`, `num_attention_heads`, `intermediate_size` and `num_hidden_layers`;
            `num_key_value_heads` is optional and defaults to `num_attention_heads`.
        seq_len (int): Sequence length

    Returns:
        float: Estimated flops per token, summed over all hidden layers
    """
    cfg = model.config
    hidden_size = cfg.hidden_size
    num_heads = cfg.num_attention_heads
    # Configs without grouped-query attention may omit `num_key_value_heads` or leave it as None
    num_kv_heads = getattr(cfg, "num_key_value_heads", None) or num_heads
    head_dim = hidden_size // num_heads

    # MLP: 3 matmuls, each with hidden_size * intermediate_size weights
    mlp_flops = 18 * hidden_size * cfg.intermediate_size

    # Attn (w/o dotproduct): q/o projections have hidden_size * num_heads * head_dim weights each,
    # k/v projections hidden_size * num_kv_heads * head_dim each -> 6 flops per weight overall
    attn_flops = 12 * hidden_size * head_dim * (num_heads + num_kv_heads)

    # attn (dotproduct) - this scales quadratically with sequence length
    attn_dotproduct_flops = 12 * num_heads * head_dim * seq_len

    # we also ignore embeddings and layernorms, etc
    return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers


def create_collate_fn():
    """
    Create a collate function for batching.

    Returns:
        Callable: Function mapping a list of samples (each with `input_ids` and `shift_labels`
        sequences) to a dict of `torch.long` tensors with keys `input_ids`, `shift_labels` and
        `labels`; `labels` is the very same tensor object as `shift_labels`.
    """

    def _stack_column(samples, key):
        # One row per sample; rows must share a length for `torch.tensor` to succeed
        rows = []
        for sample in samples:
            rows.append(sample[key])
        return torch.tensor(rows, dtype=torch.long)

    def collate_fn(batch):
        inputs = _stack_column(batch, "input_ids")
        targets = _stack_column(batch, "shift_labels")
        collated = {"input_ids": inputs}
        collated["shift_labels"] = targets
        collated["labels"] = targets
        return collated

    return collate_fn


class PerformanceTracker:
    """Track training performance metrics."""

    def __init__(self, warmup_steps: int = 10):
        self.warmup_steps = warmup_steps
        self.reset()

    def reset(self):
        """Reset all tracking variables."""
        self.start_time = None
        self.num_tokens = 0
        self.is_in_warmup = True
        self.step_count = 0

    def step(self, batch_tokens: int, model_flops_per_token: float | None = None) -> dict:
        """
        Update performance tracking with a new step.

        Args:
            batch_tokens (int): Number of tokens in current batch

        Returns:
            dict: Performance metrics if past warmup, empty dict otherwise
        """
        self.step_count += 1

        if self.step_count == self.warmup_steps:
            self.start_time = time.perf_counter()
            self.num_tokens = 0
            self.is_in_warmup = False
            return {"warmup_completed": True}

        if not self.is_in_warmup and self.start_time is not None:
            dct = {}
            self.num_tokens += batch_tokens
            total_time = time.perf_counter() - self.start_time
            steps_from_warmup = self.step_count - self.warmup_steps

            if total_time > 0 and steps_from_warmup > 0:
                memory_stats = gpu_memory_usage_all()
                dct = {
                    "tokens_per_second": self.num_tokens / total_time,
                    "steps_per_second": steps_from_warmup / total_time,
                    "total_tokens": self.num_tokens,
                    "total_time": total_time,
                    **memory_stats,
                }

            if model_flops_per_token is not None:
                flops = model_flops_per_token * self.num_tokens
                dct["tflops_per_device"] = flops / (total_time * 1e12)

            return dct

        return {}

    def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
        print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f} | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
        if with_memory:
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        return print_msg


def setup_tokenizer(model_id: str) -> AutoTokenizer:
    """
    Load the tokenizer for `model_id` and make sure it has a padding token.

    Args:
        model_id (str): Identifier passed to `AutoTokenizer.from_pretrained`

    Returns:
        AutoTokenizer: The loaded tokenizer; if it had no `pad_token`, its `eos_token` is used
    """
    loaded = AutoTokenizer.from_pretrained(model_id)
    if loaded.pad_token is not None:
        return loaded
    # No dedicated padding token: reuse the end-of-sequence token
    loaded.pad_token = loaded.eos_token
    return loaded


def gpu_memory_usage_all(device=0):
    device_type = torch.accelerator.current_accelerator().type
    device = torch.device(f"{device_type}:{device}")
    torch_device_module = getattr(torch, device_type, torch.cuda)
    _BYTES_IN_GIB = 1024**3
    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
    memory_stats = {
        "peak_memory_active": peak_memory_active,
        "peak_memory_alloc": peak_memory_alloc,
        "peak_memory_reserved": peak_memory_reserved,
    }
    torch_device_module.reset_peak_memory_stats(device)

    return memory_stats