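"""Benchmark-style training and inference loops for point-cloud
classification models in PyTorch Geometric, with optional torch.compile,
CPU bfloat16 autocast, and profiling support.

Models are expected to take ``(pos, batch)`` and return log-probabilities.
"""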
import time
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.loader import DataLoader
from torch_geometric.profile import timeit, torch_profile

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')


def run_train(train_dataset, test_dataset, model, epochs, batch_size,
              use_compile, lr, lr_decay_factor, lr_decay_step_size,
              weight_decay):
    model = model.to(device)
    if use_compile:
        model = torch.compile(model)

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    for epoch in range(1, epochs + 1):
        # Synchronize before and after each epoch so that asynchronously
        # queued device kernels do not distort the timing.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif (hasattr(torch.backends, 'mps')
              and torch.backends.mps.is_available()):
            torch.mps.synchronize()
        t_start = time.perf_counter()

        train(model, optimizer, train_loader, device)
        test_acc = test(model, test_loader, device)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif (hasattr(torch.backends, 'mps')
              and torch.backends.mps.is_available()):
            torch.mps.synchronize()
        t_end = time.perf_counter()

        print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}, '
              f'Duration: {t_end - t_start:.2f}')

        # Step-wise learning-rate decay.
        if epoch % lr_decay_step_size == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_decay_factor * param_group['lr']


@torch.no_grad()
def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,
                  use_compile):
    model = model.to(device)
    if use_compile:
        model = torch.compile(model)

    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    # Autocast stays disabled on CUDA; on CPU it is enabled only when
    # bfloat16 inference is requested.
    if torch.cuda.is_available():
        amp = torch.amp.autocast('cuda', enabled=False)
    else:
        amp = torch.amp.autocast('cpu', enabled=bf16)

    with amp:
        for epoch in range(1, epochs + 1):
            print(f'Epoch: {epoch}')
            if epoch == epochs:
                # Time only the last epoch, after warm-up (and after
                # compilation, if requested).
                with timeit():
                    inference(model, test_loader, device, bf16)
            else:
                inference(model, test_loader, device, bf16)

        if profiling:
            with torch_profile():
                inference(model, test_loader, device, bf16)


def run(train_dataset, test_dataset, model, epochs, batch_size, lr,
        lr_decay_factor, lr_decay_step_size, weight_decay, inference,
        profiling, bf16, use_compile):
    if not inference:
        run_train(train_dataset, test_dataset, model, epochs, batch_size,
                  use_compile, lr, lr_decay_factor, lr_decay_step_size,
                  weight_decay)
    else:
        run_inference(test_dataset, model, epochs, batch_size, profiling,
                      bf16, use_compile)


def train(model, optimizer, train_loader, device):
    model.train()
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.pos, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()


@torch.no_grad()
def test(model, test_loader, device):
    model.eval()
    correct = 0
    for data in test_loader:
        data = data.to(device)
        pred = model(data.pos, data.batch).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    test_acc = correct / len(test_loader.dataset)
    return test_acc


@torch.no_grad()
def inference(model, test_loader, device, bf16):
    model.eval()
    if bf16:
        # Cast the model once up front instead of redundantly on every batch.
        model = model.to(torch.bfloat16)
    for data in test_loader:
        data = data.to(device)
        if bf16:
            data.pos = data.pos.to(torch.bfloat16)
        model(data.pos, data.batch)
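

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module).
# It assumes a PyG point-cloud dataset (ModelNet10 here) and works with any
# model whose forward signature is (pos, batch) and which returns
# log-probabilities; `SimplePointModel` is a hypothetical stand-in for such
# a model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch_geometric.transforms as T
    from torch_geometric.datasets import ModelNet
    from torch_geometric.nn import global_max_pool

    class SimplePointModel(torch.nn.Module):
        def __init__(self, num_classes):
            super().__init__()
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(3, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, num_classes),
            )

        def forward(self, pos, batch):
            x = self.mlp(pos)  # per-point features
            x = global_max_pool(x, batch)  # one vector per point cloud
            return F.log_softmax(x, dim=-1)  # log-probs, as F.nll_loss expects

    # Sample fixed-size point clouds from the mesh faces.
    transform = T.SamplePoints(num=1024)
    train_dataset = ModelNet('data/ModelNet10', name='10', train=True,
                             transform=transform)
    test_dataset = ModelNet('data/ModelNet10', name='10', train=False,
                            transform=transform)

    run(train_dataset, test_dataset,
        SimplePointModel(train_dataset.num_classes), epochs=5, batch_size=32,
        lr=0.001, lr_decay_factor=0.5, lr_decay_step_size=25, weight_decay=0,
        inference=False, profiling=False, bf16=False, use_compile=False)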