File: git_mol.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (135 lines) | stat: -rw-r--r-- 4,726 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""This example implements the GIT-Mol model
(https://arxiv.org/abs/2308.06911) using PyG.
"""
import argparse
import os.path as osp

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import GitMolDataset
from torch_geometric.llm.models import GITMol
from torch_geometric.loader import DataLoader


@torch.no_grad()
def eval(model, data_loader):
    """Compute the average per-batch loss of ``model`` over ``data_loader``.

    The model is switched to evaluation mode and called once per batch with
    the batch's graph, SMILES, image and caption fields. Gradient tracking is
    disabled for the whole call by the ``torch.no_grad`` decorator.

    Args:
        model: Callable module whose output exposes ``.item()`` (the loss).
        data_loader: Sized iterable of batches carrying the attributes
            ``x``, ``edge_index``, ``batch``, ``edge_attr``, ``smiles``,
            ``image`` and ``caption``.

    Returns:
        The sum of ``loss / len(data_loader)`` over all batches, i.e. the
        unweighted mean of the batch losses; ``0`` for an empty loader.
    """
    model.eval()

    mean_loss = 0
    for graph_batch in data_loader:
        # Positional order must match the model's forward signature.
        inputs = (
            graph_batch.x,
            graph_batch.edge_index,
            graph_batch.batch,
            graph_batch.edge_attr,
            graph_batch.smiles,
            graph_batch.image,
            graph_batch.caption,
        )
        # Each batch adds its share of the mean directly, so no separate
        # normalisation step is needed after the loop.
        mean_loss += model(*inputs).item() / len(data_loader)
    return mean_loss


def train(
    num_epochs: int,
    lr: float,
    weight_decay: float,
    batch_size: int,
    checkpointing: bool,
) -> None:
    """Pre-train ``GITMol`` on ``GitMolDataset`` and print the test loss.

    Builds train/validation/test loaders, wraps model, optimizer and loaders
    with ``accelerate``, runs ``num_epochs`` epochs of training with a
    validation pass after each one, and finally evaluates on the test
    loader using the weights from the last epoch.

    Args:
        num_epochs: Number of passes over the training loader.
        lr: Learning rate passed to ``AdamW``.
        weight_decay: Weight decay passed to ``AdamW``.
        batch_size: Batch size shared by all three data loaders.
        checkpointing: If true, write a checkpoint to the current working
            directory every time the validation loss improves on the best
            value seen so far.
    """
    # Load dataset ================================================
    path = osp.dirname(osp.realpath(__file__))
    path = osp.join(path, '..', '..', 'data', 'GITMol')
    train_dataset = GitMolDataset(path, split=0)
    val_dataset = GitMolDataset(path, split=1)
    test_dataset = GitMolDataset(path, split=2)

    seed_everything(42)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              drop_last=True, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             drop_last=False, pin_memory=True, shuffle=False)

    # Create model ===============================================
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    device = accelerator.device
    model = GITMol().to(device)
    # Only parameters that still require gradients are optimized.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr,
        weight_decay=weight_decay)
    # NOTE(review): the scheduler is created and prepared but `step()` is
    # never called in this function, so the learning rate stays at `lr`.
    # Confirm whether a per-epoch decay was intended before enabling it.
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler)
    val_loader = accelerator.prepare_data_loader(val_loader,
                                                 device_placement=True)
    test_loader = accelerator.prepare_data_loader(test_loader,
                                                  device_placement=True)

    # Train and eval ============================================
    best_epoch = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Train
        model.train()
        epoch_loss = 0
        if epoch == 0:
            print("Training beginning...")
        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'

        for batch in tqdm(train_loader, desc=epoch_str):
            optimizer.zero_grad()
            loss = model(batch.x, batch.edge_index, batch.batch,
                         batch.edge_attr, batch.smiles, batch.image,
                         batch.caption)
            accelerator.backward(loss)

            optimizer.step()
            epoch_loss += loss.item()

        # `drop_last=True` can leave the loader without a single batch when
        # the dataset is smaller than `batch_size`; avoid dividing by zero.
        train_loss = epoch_loss / max(len(train_loader), 1)

        # Eval
        val_loss = eval(model, val_loader)
        print(
            f'{epoch_str}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}'  # noqa: E501
        )

        if checkpointing and val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            # In a multi-process launch every rank reaches this point; let
            # only the main process write so ranks do not race on the file.
            if accelerator.is_main_process:
                # The file name format is kept unchanged so the on-disk
                # artifact naming stays the same. `best_epoch` is 0-based.
                torch.save(
                    {
                        'model_state_dict':
                        accelerator.unwrap_model(model).state_dict(),
                        'best_loss':
                        best_val_loss
                    },
                    f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt'  # noqa: E501
                )

    # The CUDA memory helpers are only meaningful (and only safe to call)
    # when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    # Test
    test_loss = eval(model, test_loader)
    print(f'Test loss: {test_loss:.4f}')


def str2bool(value) -> bool:
    """Parse a textual command-line value into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty value, including ``"False"``, becomes ``True``. This
    parser maps the usual spellings explicitly instead.

    Args:
        value: The raw argument; a ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0 or an
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a known spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what `bool('')` returned.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (e.g. true/false), got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    # `nargs='?'` with `const=True` also accepts a bare `--checkpointing`;
    # `--checkpointing false` now really disables checkpointing.
    parser.add_argument('--checkpointing', type=str2bool, nargs='?',
                        const=True, default=True)
    args = parser.parse_args()

    train(
        num_epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        batch_size=args.batch_size,
        checkpointing=args.checkpointing,
    )