File: dataloader_benchmark.py

#!/usr/bin/env python3
"""
Benchmark script for PyTorch DataLoader with different worker methods.

This script measures:
1. Dataloader initialization time
2. Dataloading speed (time per batch)
3. CPU memory utilization

Usage:
    python dataloader_benchmark.py --data_path /path/to/dataset --batch_size 32 --num_workers 4
"""

import argparse
import copy
import gc
import time

import psutil
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset


def get_memory_usage():
    """
    Get current memory usage in MB. This includes all child processes.

    Returns:
        Total memory usage in MB
    """
    process = psutil.Process()

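    # PSS (proportional set size) splits each shared page evenly among the
    # processes that map it, so copy-on-write memory shared with fork()ed
    # workers is not double-counted the way summing RSS values would be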
    main_memory = process.memory_full_info().pss

    # Add memory usage of all child processes
    for child in process.children(recursive=True):
        try:
            main_memory += child.memory_full_info().pss
        except (psutil.NoSuchProcess, psutil.AccessDenied, AttributeError):
            # Process might have terminated or the platform might not report
            # PSS; fall back to USS, which also requires memory_full_info()
            # (plain memory_info() exposes no uss field)
            print(f"Failed to get PSS for {child}, falling back to USS")
            try:
                main_memory += child.memory_full_info().uss
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass

    return main_memory / (1024 * 1024)


def print_detailed_memory():
    """Print detailed memory information."""
    process = psutil.Process()
    print("\nDetailed memory information:")
    try:
        print(
            f"  USS (Unique Set Size): {process.memory_full_info().uss / (1024 * 1024):.2f} MB"
        )
        print(
            f"  PSS (Proportional Set Size): {process.memory_full_info().pss / (1024 * 1024):.2f} MB"
        )
        print(
            f"  RSS (Resident Set Size): {process.memory_info().rss / (1024 * 1024):.2f} MB"
        )
    except Exception:
        print("  Detailed memory info not available")


def create_model():
    """Create a simple model for benchmarking."""
    model = resnet18()
    return model


def benchmark_dataloader(
    dataset,
    batch_size,
    num_workers,
    num_epochs=1,
    max_batches=10,
    multiprocessing_context=None,
    logging_freq=10,
):
    """Benchmark a dataloader with specific configuration."""
    print("\n--- Benchmarking DataLoader ---")

    # Clear memory before starting
    gc.collect()
    torch.cuda.empty_cache()
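    # (empty_cache is a safe no-op when CUDA has never been initialized)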

    # Create model
    model = create_model()

    # Measure memory before dataloader creation
    memory_before = get_memory_usage()
    print(f"Memory before DataLoader creation: {memory_before:.2f} MB")
    print_detailed_memory()

    # Measure dataloader initialization time
    start = time.perf_counter()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        prefetch_factor=2 if num_workers > 0 else None,
        # DataLoader rejects a multiprocessing context when num_workers == 0
        multiprocessing_context=multiprocessing_context if num_workers > 0 else None,
    )
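    # Worker processes (and prefetching) start when the iterator is created,
    # not when the DataLoader is constructed, so iter() is part of init time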
    it = iter(dataloader)
    dataloader_init_time = time.perf_counter() - start

    # Measure memory after dataloader creation
    memory_after = get_memory_usage()
    print(f"Memory after DataLoader creation: {memory_after:.2f} MB")
    print(f"Memory increase: {memory_after - memory_before:.2f} MB")

    # Create model and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Benchmark dataloading speed
    model.train()
    total_batches = 0
    total_samples = 0
    total_time = 0
    total_data_load_time = 0

    # Track peak memory during training; current_memory is refreshed every
    # 5 batches below and reused by the periodic log line
    peak_memory = memory_after
    current_memory = memory_after

    print(
        f"\nStarting training loop with {num_epochs} epochs (max {max_batches} batches per epoch)"
    )

    for epoch in range(num_epochs):
        # Epochs after the first need a fresh iterator; the initial iter()
        # above was deliberately timed as part of DataLoader initialization
        if epoch > 0:
            it = iter(dataloader)
        epoch_batches = 0
        while epoch_batches < max_batches:
            batch_start = time.perf_counter()

            try:
                inputs, labels = next(it)
            except StopIteration:
                break

            # Move data to device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Capture data fetch time (including sending to device)
            data_load_time = time.perf_counter() - batch_start

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Capture batch time
            batch_time = time.perf_counter() - batch_start

            total_batches += 1
            epoch_batches += 1
            total_samples += inputs.size(0)
            total_data_load_time += data_load_time
            total_time += batch_time

            # Update peak memory and log memory usage periodically
            if total_batches % 5 == 0:
                # Force garbage collection before measuring memory
                gc.collect()
                current_memory = get_memory_usage()

                if current_memory > peak_memory:
                    peak_memory = current_memory

            if total_batches % logging_freq == 0:
                print(
                    f"Epoch {epoch + 1}, Batch {total_batches}, "
                    f"Time: {batch_time:.4f}s, "
                    f"Memory: {current_memory:.2f} MB"
                )

    # Calculate statistics
    avg_data_load_time = (
        total_data_load_time / total_batches if total_batches > 0 else 0
    )
    avg_batch_time = total_time / total_batches if total_batches > 0 else 0
    samples_per_second = total_samples / total_time if total_time > 0 else 0

    results = {
        "dataloader_init_time": dataloader_init_time,
        "num_workers": num_workers,
        "batch_size": batch_size,
        "total_batches": total_batches,
        "avg_batch_time": avg_batch_time,
        "avg_data_load_time": avg_data_load_time,
        "samples_per_second": samples_per_second,
        "peak_memory_mb": peak_memory,
        "memory_increase_mb": peak_memory - memory_before,
    }

    print("\nResults:")
    print(f"  DataLoader init time: {dataloader_init_time:.4f} seconds")
    print(f"  Average data loading time: {avg_data_load_time:.4f} seconds")
    print(f"  Average batch time: {avg_batch_time:.4f} seconds")
    print(f"  Samples per second: {samples_per_second:.2f}")
    print(f"  Peak memory usage: {peak_memory:.2f} MB")
    print(f"  Memory increase: {peak_memory - memory_before:.2f} MB")

    # Clean up; dropping the iterator reference lets worker processes exit
    del it
    del model, optimizer
    del dataloader

    # Force garbage collection
    gc.collect()
    torch.cuda.empty_cache()

    return results

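# A quick smoke test without a real image directory could use torchvision's
# FakeData dataset in place of ImageFolder (a sketch, not wired into the CLI
# below; the parameter values are illustrative):
#
#     from torchvision.datasets import FakeData
#
#     fake = FakeData(
#         size=256,
#         image_size=(3, 224, 224),
#         num_classes=1000,
#         transform=transforms.ToTensor(),
#     )
#     benchmark_dataloader(fake, batch_size=32, num_workers=2)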

def main():
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch DataLoader with different worker methods"
    )
    parser.add_argument("--data_path", required=True, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument(
        "--max_batches",
        type=int,
        default=100,
        help="Maximum number of batches per epoch",
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument(
        "--multiprocessing_context",
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context to use (fork, spawn, forkserver)",
    )
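    # "fork" inherits the parent's address space copy-on-write (fast startup,
    # but unsafe once CUDA has been initialized), while "spawn" and
    # "forkserver" launch fresh interpreters (slower startup, cleaner
    # isolation)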
    parser.add_argument(
        "--dataset_copies",
        type=int,
        default=1,
        help="Number of copies of the dataset to concatenate (for testing memory usage)",
    )
    parser.add_argument(
        "--logging_freq",
        type=int,
        default=10,
        help="Frequency of logging memory usage during training",
    )
    args = parser.parse_args()

    # Print system info
    print("System Information:")
    # Versions and install locations are handy for checking that a
    # from-source build is the one actually being imported
    print(f"  PyTorch version: {torch.__version__}")
    print(f"  PyTorch location: {torch.__file__}")
    print(f"  Torchvision version: {torchvision.__version__}")
    print(f"  Torchvision location: {torchvision.__file__}")
    print(f"  CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"  CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"  CPU count: {psutil.cpu_count(logical=True)}")
    print(f"  Physical CPU cores: {psutil.cpu_count(logical=False)}")
    print(f"  Total system memory: {psutil.virtual_memory().total / (1024**3):.2f} GB")

    # Standard ImageNet-style preprocessing (the usual ILSVRC mean/std)
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load dataset
    print(f"\nLoading dataset from {args.data_path} ({args.dataset_copies} copies)")

    # Load the dataset as an ImageFolder. The directory is scanned once and
    # each copy is a deep copy, so every copy owns independent Python objects
    # and the extra copies exert real memory pressure when concatenated.
    base_dataset = torchvision.datasets.ImageFolder(
        args.data_path, transform=transform
    )
    datasets = [copy.deepcopy(base_dataset) for _ in range(args.dataset_copies)]
    del base_dataset
    dataset = ConcatDataset(datasets)

    print(f"Dataset size: {len(dataset)}")

    # Run benchmark with specified worker method
    benchmark_dataloader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        multiprocessing_context=args.multiprocessing_context,
        num_epochs=args.num_epochs,
        max_batches=args.max_batches,
        logging_freq=args.logging_freq,
    )


if __name__ == "__main__":
    main()