File: utils.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (50 lines) | stat: -rw-r--r-- 1,889 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import random

from torch.utils.data import DataLoader, Subset
from torchvision.datasets.cifar import CIFAR100
from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomErasing, RandomHorizontalFlip, ToTensor


def get_train_eval_loaders(path, batch_size=256, num_workers=12):
    """Build the CIFAR100 data loaders used for training and evaluation.

    Dataflow:
        - load the CIFAR100 train and test datasets from ``path``
        - train transform: pad by 4, random 32x32 crop, random horizontal
          flip, tensor conversion, normalization, then ``RandomErasing``
        - test transform: tensor conversion and normalization only
        - wrap the datasets in three data loaders

    Args:
        path: root directory passed to ``CIFAR100``. The train split is
            requested with ``download=True``; the test split is opened with
            ``download=False`` (presumably relying on the files fetched by
            the train-split call -- confirm if the splits are ever stored
            separately).
        batch_size: number of examples per mini-batch for all three loaders.
        num_workers: number of worker processes used by each loader.

    Returns:
        train_loader, test_loader, eval_train_loader

        ``train_loader`` is shuffled and drops the last incomplete batch.
        ``test_loader`` and ``eval_train_loader`` keep dataset order and keep
        the last batch. ``eval_train_loader`` iterates over a random subset
        of the *training* dataset that has as many entries as the test
        dataset; because it wraps ``train_dataset`` it also goes through the
        train transform (random augmentations included).
    """
    # NOTE(review): these constants look like the widely used ImageNet
    # statistics rather than CIFAR100-specific ones -- confirm this is intended.
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = Compose(
        [
            Pad(4),
            RandomCrop(32),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
            RandomErasing(),
        ]
    )
    test_transform = Compose([ToTensor(), normalize])

    train_dataset = CIFAR100(root=path, train=True, transform=train_transform, download=True)
    test_dataset = CIFAR100(root=path, train=False, transform=test_transform, download=False)

    # Indices are drawn independently (i.e. with replacement), so the subset
    # may contain the same training example more than once.
    last_train_index = len(train_dataset) - 1
    train_eval_indices = [random.randint(0, last_train_index) for _ in range(len(test_dataset))]
    train_eval_dataset = Subset(train_dataset, train_eval_indices)

    # Settings shared by every loader; only shuffling and drop_last differ.
    common_kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": True}

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **common_kwargs)
    eval_train_loader = DataLoader(train_eval_dataset, shuffle=False, drop_last=False, **common_kwargs)

    return train_loader, test_loader, eval_train_loader