File: train_eval.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (130 lines) | stat: -rw-r--r-- 4,068 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import time

import torch
import torch.nn.functional as F
from torch.optim import Adam

from torch_geometric.loader import DataLoader
from torch_geometric.profile import timeit, torch_profile

# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available() else
    'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu')


def run_train(train_dataset, test_dataset, model, epochs, batch_size,
              use_compile, lr, lr_decay_factor, lr_decay_step_size,
              weight_decay):
    """Train ``model`` on ``train_dataset``, printing the test accuracy and
    wall-clock duration of every epoch.

    The learning rate is multiplied by ``lr_decay_factor`` every
    ``lr_decay_step_size`` epochs.
    """
    model = model.to(device)
    if use_compile:
        model = torch.compile(model)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    def _sync():
        # Flush pending kernels so the epoch timings are meaningful on
        # asynchronous backends (CUDA / MPS); a no-op on CPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif (hasattr(torch.backends, 'mps')
              and torch.backends.mps.is_available()):
            torch.mps.synchronize()

    for epoch in range(1, epochs + 1):
        _sync()
        t_start = time.perf_counter()

        train(model, optimizer, train_loader, device)
        test_acc = test(model, test_loader, device)

        _sync()
        t_end = time.perf_counter()

        print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}, '
              f'Duration: {t_end - t_start:.2f}')

        # Step-wise learning-rate decay.
        if epoch % lr_decay_step_size == 0:
            for group in optimizer.param_groups:
                group['lr'] = lr_decay_factor * group['lr']


@torch.no_grad()
def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,
                  use_compile):
    """Run ``epochs`` inference passes over ``test_dataset``, timing the
    final pass and optionally emitting a profiler trace afterwards."""
    model = model.to(device)
    if use_compile:
        model = torch.compile(model)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    # Autocast stays disabled on CUDA; on CPU it is driven by the bf16 flag.
    if torch.cuda.is_available():
        amp = torch.amp.autocast('cuda', enabled=False)
    else:
        amp = torch.cpu.amp.autocast(enabled=bf16)

    with amp:
        for epoch in range(1, epochs + 1):
            print("Epoch: ", epoch)
            if epoch < epochs:
                inference(model, test_loader, device, bf16)
            else:
                # Only the final pass is timed; earlier ones act as warm-up.
                with timeit():
                    inference(model, test_loader, device, bf16)

        if profiling:
            with torch_profile():
                inference(model, test_loader, device, bf16)


def run(train_dataset, test_dataset, model, epochs, batch_size, lr,
        lr_decay_factor, lr_decay_step_size, weight_decay, inference,
        profiling, bf16, use_compile):
    """Dispatch to training or inference-only benchmarking.

    NOTE(review): the boolean ``inference`` parameter shadows the
    module-level ``inference()`` helper inside this function; it is only
    read as a flag here, so this is harmless but worth knowing.
    """
    if inference:
        run_inference(test_dataset, model, epochs, batch_size, profiling,
                      bf16, use_compile)
    else:
        run_train(train_dataset, test_dataset, model, epochs, batch_size,
                  use_compile, lr, lr_decay_factor, lr_decay_step_size,
                  weight_decay)


def train(model, optimizer, train_loader, device):
    model.train()

    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.pos, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()


@torch.no_grad()
def test(model, test_loader, device):
    model.eval()

    correct = 0
    for data in test_loader:
        data = data.to(device)
        pred = model(data.pos, data.batch).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    test_acc = correct / len(test_loader.dataset)

    return test_acc


@torch.no_grad()
def inference(model, test_loader, device, bf16):
    """Run a single forward pass over every batch in ``test_loader``.

    Args:
        model: Module called as ``model(pos, batch)``.
        test_loader: Iterable of batches exposing ``.pos`` and ``.batch``
            and supporting ``.to(device)``.
        device: Target device for each batch.
        bf16 (bool): If ``True``, cast the model and each batch's ``pos``
            to ``torch.bfloat16`` before the forward pass.
    """
    model.eval()
    if bf16:
        # Fix: the original re-ran `model.to(torch.bfloat16)` inside the
        # per-batch loop; the cast is loop-invariant, so do it once here.
        model = model.to(torch.bfloat16)
    for data in test_loader:
        data = data.to(device)
        if bf16:
            data.pos = data.pos.to(torch.bfloat16)
        model(data.pos, data.batch)