1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
import fire
import torch
from apex import amp
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision.models import wide_resnet50_2
from utils import get_train_eval_loaders
from ignite.engine import convert_tensor, create_supervised_evaluator, Engine, Events
from ignite.handlers import ProgressBar, Timer
from ignite.metrics import Accuracy, Loss
def main(dataset_path, batch_size=256, max_epochs=10, opt="O1"):
    """Train ``wide_resnet50_2`` with NVIDIA Apex AMP and print timing and metrics.

    The model is built with 100 output classes and optimised with plain SGD
    (lr=0.01) against a cross-entropy loss. Once training has completed, the
    mean elapsed time per epoch is printed, followed by accuracy and loss
    computed on ``eval_train_loader`` and ``test_loader``.

    Args:
        dataset_path: Forwarded to ``get_train_eval_loaders`` to build the loaders.
        batch_size: Batch size forwarded to ``get_train_eval_loaders``.
        max_epochs: Number of epochs the trainer is run for.
        opt: Apex AMP optimisation level, passed as ``opt_level`` to
            ``amp.initialize``.

    Raises:
        AssertionError: If CUDA is unavailable or the cuDNN backend is disabled.
    """
    assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
    torch.backends.cudnn.benchmark = True

    device = "cuda"

    loaders = get_train_eval_loaders(dataset_path, batch_size=batch_size)
    train_loader, test_loader, eval_train_loader = loaders

    model = wide_resnet50_2(num_classes=100).to(device)
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss().to(device)

    # Apex hands back the wrapped model/optimizer pair; everything below
    # (including the evaluator) uses these patched objects.
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt)

    def to_device(tensor):
        # Single place for the host-to-device transfer settings.
        return convert_tensor(tensor, device, non_blocking=True)

    def train_step(engine, batch):
        inputs = to_device(batch[0])
        targets = to_device(batch[1])

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)

        # Backpropagate through the AMP-scaled loss rather than the raw one.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        # The unscaled loss value is what gets reported as the engine output.
        return loss.item()

    trainer = Engine(train_step)

    # Averaged timer stepped once per epoch, so value() is a per-epoch mean.
    timer = Timer(average=True)
    timer.attach(trainer, step=Events.EPOCH_COMPLETED)

    def as_batch_loss(out):
        return {"batch loss": out}

    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(trainer, output_transform=as_batch_loss)

    metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}
    evaluator = create_supervised_evaluator(
        model,
        metrics=metrics,
        device=device,
        non_blocking=True,
    )

    def log_metrics(engine, title):
        computed = engine.state.metrics
        for name in metrics:
            value = computed[name]
            print(f"\t{title} {name}: {value:.2f}")

    @trainer.on(Events.COMPLETED)
    def run_validation(_):
        print(f"- Mean elapsed time for 1 epoch: {timer.value()}")
        print("- Metrics:")

        evaluation_passes = (("Train", eval_train_loader), ("Test", test_loader))
        for title, loader in evaluation_passes:
            # The handler is registered only for the duration of this run so
            # each pass is logged under its own title.
            with evaluator.add_event_handler(Events.COMPLETED, log_metrics, title):
                evaluator.run(loader)

    trainer.run(train_loader, max_epochs=max_epochs)
# Script entry point: hand `main` to python-fire so it can be invoked from the command line.
if __name__ == "__main__":
    fire.Fire(main)
|