import time

import torch
import torch.nn.functional as F


def train_runtime(model, data, epochs, device):
    """Train `model` on `data` for `epochs` epochs and return the wall-clock time."""
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    model = model.to(device)
    data = data.to(device)
    model.train()

    # Support both mask-based (e.g. Planetoid) and index-based dataset splits.
    mask = data.train_mask if 'train_mask' in data else data.train_idx
    y = data.y[mask] if 'train_mask' in data else data.train_y

    # Synchronize before starting the timer so pending GPU work is not counted.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        torch.mps.synchronize()
    t_start = time.perf_counter()

    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out[mask], y)
        loss.backward()
        optimizer.step()

    # Synchronize again so the timer includes all kernels still queued on the device.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        torch.mps.synchronize()
    t_end = time.perf_counter()

    return t_end - t_start
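

if __name__ == '__main__':
    # Hypothetical usage sketch, not part of the original listing: it assumes
    # PyTorch Geometric is installed and uses the Planetoid 'Cora' dataset with
    # a minimal two-layer GCN whose forward() takes the full `data` object,
    # which is the calling convention `train_runtime` expects.
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv

    class GCN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)

        def forward(self, data):
            x = F.relu(self.conv1(data.x, data.edge_index))
            x = self.conv2(x, data.edge_index)
            return F.log_softmax(x, dim=-1)  # nll_loss expects log-probabilities

    dataset = Planetoid(root='/tmp/Cora', name='Cora')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(dataset.num_features, 16, dataset.num_classes)
    runtime = train_runtime(model, dataset[0], epochs=200, device=device)
    print(f'200 epochs took {runtime:.3f}s on {device}')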