File: main.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (153 lines) | stat: -rw-r--r-- 5,179 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from model import Net
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop, resize, to_tensor

from ignite.engine import Engine, Events

from ignite.handlers import BasicTimeProfiler, ProgressBar
from ignite.metrics import PSNR

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training")
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
parser.add_argument("--debug", action="store_true", help="use debug")

opt = parser.parse_args()

print(opt)

# An explicitly requested CUDA device that does not exist is a hard error.
if opt.cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

# MPS (macOS GPU) is opt-in via --mps. Previously an available-but-unrequested
# MPS device aborted the run, which made the flag mandatory on macOS and CPU
# runs impossible there. Now it only prints a hint. A requested-but-missing
# MPS device falls back to CPU with a warning instead of doing so silently.
mps_available = torch.backends.mps.is_available()
if opt.mps and not mps_available:
    print("Warning: --mps was given but no mps device is available, falling back to CPU")
elif not opt.mps and mps_available:
    print("Found mps device, run with --mps to enable macOS GPU")

torch.manual_seed(opt.seed)
use_mps = opt.mps and mps_available

# Device priority: CUDA (if requested) > MPS (if requested and present) > CPU.
if opt.cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("===> Loading datasets")


class SRDataset(torch.utils.data.Dataset):
    """Turn an (image, label) dataset into (low-res, high-res) tensor pairs.

    For every sample, the image is converted to YCbCr and only its first
    (Y / luminance) band is kept. That band is center-cropped to
    ``crop_size`` to form the high-resolution target. The low-resolution
    input is a copy of the crop, downsized to ``crop_size // scale_factor``
    on each side unless ``scale_factor`` is 1. Labels are discarded.

    Args:
        dataset: indexable dataset whose items unpack to ``(image, label)``;
            ``image`` must support ``.convert("YCbCr")`` (presumably a PIL
            image — confirm for other datasets).
        scale_factor: integer downscaling factor for the low-res input.
        crop_size: side length of the square center crop.
    """

    def __init__(self, dataset, scale_factor, crop_size=256):
        self.dataset = dataset
        self.scale_factor = scale_factor
        self.crop_size = crop_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        source_image, _label = self.dataset[index]
        luma, _cb, _cr = source_image.convert("YCbCr").split()
        target = center_crop(luma, self.crop_size)

        if self.scale_factor == 1:
            source = target.copy()
        else:
            side = self.crop_size // self.scale_factor
            source = resize(target.copy(), [side, side])

        target_tensor = to_tensor(target)
        source_tensor = to_tensor(source)
        return source_tensor, target_tensor


try:
    # NOTE(review): both objects are built with identical arguments, so the
    # "test" set is the same Caltech101 data used for training and the
    # reported validation PSNR is measured on training images. Consider a
    # held-out split (e.g. torch.utils.data.random_split).
    trainset = torchvision.datasets.Caltech101(root="./data", download=True)
    testset = torchvision.datasets.Caltech101(root="./data", download=False)
except RuntimeError:
    # A failed download is deliberately treated as a soft failure (exit code 0).
    print("Dataset download problem, exiting without error code")
    # ``raise SystemExit`` is what ``exit()`` does, without relying on the
    # ``site`` module having injected ``exit`` into builtins (it is missing
    # under ``python -S`` and in some frozen/embedded interpreters).
    raise SystemExit(0)

# Wrap both splits so each sample becomes a (low-res, high-res) luminance pair.
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)

training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)

print("===> Building model")
model = Net(upscale_factor=opt.upscale_factor).to(device)
# Pixel-wise mean squared error between model output and high-res target.
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=opt.lr)


def train_step(engine, batch):
    """Run one optimization step on a single training batch.

    Args:
        engine: the ignite ``Engine`` invoking this step (not used here).
        batch: ``(low_res, high_res)`` pair from ``training_data_loader``;
            both tensors are moved to ``device`` before use.

    Returns:
        The batch's ``criterion`` (MSE) loss as a Python float, which the
        progress bar displays.
    """
    model.train()
    # Named explicitly instead of ``input`` so the builtin is not shadowed.
    low_res, high_res = batch[0].to(device), batch[1].to(device)

    optimizer.zero_grad()
    loss = criterion(model(low_res), high_res)
    loss.backward()
    optimizer.step()

    return loss.item()


def validation_step(engine, batch):
    """Produce predictions for one evaluation batch without autograd.

    Returns ``(prediction, target)`` so the PSNR metric attached to the
    evaluator can compare the model output with the high-res target.
    """
    model.eval()
    with torch.no_grad():
        low_res = batch[0].to(device)
        high_res = batch[1].to(device)
        prediction = model(low_res)
    return prediction, high_res


# One engine drives optimization, the other runs inference for metrics.
trainer = Engine(train_step)
evaluator = Engine(validation_step)

# data_range=1 presumably because to_tensor yields pixel values in [0, 1] — confirm.
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")
validate_every = 1

# Debug mode caps both loops to a few iterations so a run finishes quickly.
if not opt.debug:
    epoch_length, validate_epoch_length = len(training_data_loader), len(testing_data_loader)
else:
    epoch_length, validate_epoch_length = 10, 1


@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
    """Evaluate on the test loader and print the average PSNR for this epoch."""
    evaluator.run(testing_data_loader, epoch_length=validate_epoch_length)
    psnr_db = evaluator.state.metrics["psnr"]
    print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {psnr_db} dB")


@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
    """Save the entire model object (not just its state_dict) after every epoch."""
    epoch = trainer.state.epoch
    out_path = "model_epoch_{}.pth".format(epoch)
    torch.save(model, out_path)
    print("Checkpoint saved to {}".format(out_path))


# Attach basic profiler
basic_profiler = BasicTimeProfiler()
basic_profiler.attach(trainer)

# train_step returns a float loss; wrap it in a dict for the progress bar display.
ProgressBar().attach(trainer, output_transform=lambda batch_loss: {"loss": batch_loss})

trainer.run(training_data_loader, max_epochs=opt.n_epochs, epoch_length=epoch_length)

# Print the timing statistics the profiler collected during the run.
results = basic_profiler.get_results()
basic_profiler.print_results(results)