1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645
|
# Code adapted from OGB.
# https://github.com/snap-stanford/ogb/tree/master/examples/lsc/pcqm4m-v2
import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from torch_geometric.data import Data
from torch_geometric.datasets import PCQM4Mv2
from torch_geometric.io import fs
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
GlobalAttention,
MessagePassing,
Set2Set,
global_add_pool,
global_max_pool,
global_mean_pool,
)
from torch_geometric.utils import degree
try:
from ogb.lsc import PCQM4Mv2Evaluator, PygPCQM4Mv2Dataset
except ImportError as e:
raise ImportError(
"`PygPCQM4Mv2Dataset` requires rdkit (`pip install rdkit`)") from e
from ogb.utils import smiles2graph
def ogb_from_smiles_wrapper(smiles, *args, **kwargs):
"""Returns `torch_geometric.data.Data` object from smiles while
`ogb.utils.smiles2graph` returns a dict of np arrays.
"""
data_dict = smiles2graph(smiles, *args, **kwargs)
return Data(
x=torch.from_numpy(data_dict['node_feat']),
edge_index=torch.from_numpy(data_dict['edge_index']),
edge_attr=torch.from_numpy(data_dict['edge_feat']),
smiles=smiles,
)
class GINConv(MessagePassing):
def __init__(self, emb_dim):
r"""GINConv.
Args:
emb_dim (int): node embedding dimensionality
"""
super().__init__(aggr="add")
self.mlp = torch.nn.Sequential(
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.BatchNorm1d(emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(emb_dim, emb_dim),
)
self.eps = torch.nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr)
return self.mlp(
(1 + self.eps) * x +
self.propagate(edge_index, x=x, edge_attr=edge_embedding))
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
class GCNConv(MessagePassing):
def __init__(self, emb_dim):
super().__init__(aggr='add')
self.linear = torch.nn.Linear(emb_dim, emb_dim)
self.root_emb = torch.nn.Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, x, edge_index, edge_attr):
x = self.linear(x)
edge_embedding = self.bond_encoder(edge_attr)
row, col = edge_index
deg = degree(row, x.size(0), dtype=x.dtype) + 1
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(
edge_index, x=x, edge_attr=edge_embedding, norm=norm
) + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)
def message(self, x_j, edge_attr, norm):
return norm.view(-1, 1) * F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
class GNNNode(torch.nn.Module):
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last",
residual=False, gnn_type='gin'):
r"""GNN Node.
Args:
emb_dim (int): node embedding dimensionality.
num_layers (int): number of GNN message passing layers.
residual (bool): whether to add residual connection.
drop_ratio (float): dropout ratio.
JK (str): "last" or "sum" to choose JK concat strat.
residual (bool): Wether or not to add the residual
gnn_type (str): Type of GNN to use.
"""
super().__init__()
if num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.residual = residual
self.atom_encoder = AtomEncoder(emb_dim)
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for _ in range(num_layers):
if gnn_type == 'gin':
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError(f'Undefined GNN type called {gnn_type}')
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x = batched_data.x
edge_index = batched_data.edge_index
edge_attr = batched_data.edge_attr
# compute input node embedding
h_list = [self.atom_encoder(x)]
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio,
training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
class GNNNodeVirtualNode(torch.nn.Module):
"""Outputs node representations."""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last",
residual=False, gnn_type='gin'):
super().__init__()
if num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.residual = residual
self.atom_encoder = AtomEncoder(emb_dim)
# set the initial virtual node embedding to 0.
self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.mlp_virtualnode_list = torch.nn.ModuleList()
for _ in range(num_layers):
if gnn_type == 'gin':
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError('Undefined GNN type called {gnn_type}')
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
for _ in range(num_layers - 1):
self.mlp_virtualnode_list.append(
torch.nn.Sequential(
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.BatchNorm1d(emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.BatchNorm1d(emb_dim),
torch.nn.ReLU(),
))
def forward(self, batched_data):
x = batched_data.x
edge_index = batched_data.edge_index
edge_attr = batched_data.edge_attr
batch = batched_data.batch
# virtual node embeddings for graphs
virtualnode_embedding = self.virtualnode_embedding(
torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(
edge_index.device))
h_list = [self.atom_encoder(x)]
for layer in range(self.num_layers):
# add message from virtual nodes to graph nodes
h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
# Message passing among graph nodes
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio,
training=self.training)
if self.residual:
h = h + h_list[layer]
h_list.append(h)
# update the virtual nodes
if layer < self.num_layers - 1:
# add message from graph nodes to virtual nodes
virtualnode_embedding_temp = global_add_pool(
h_list[layer], batch) + virtualnode_embedding
# transform virtual nodes using MLP
if self.residual:
virtualnode_embedding = virtualnode_embedding + F.dropout(
self.mlp_virtualnode_list[layer]
(virtualnode_embedding_temp), self.drop_ratio,
training=self.training)
else:
virtualnode_embedding = F.dropout(
self.mlp_virtualnode_list[layer](
virtualnode_embedding_temp), self.drop_ratio,
training=self.training)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
class GNN(torch.nn.Module):
def __init__(
self,
num_tasks=1,
num_layers=5,
emb_dim=300,
gnn_type='gin',
virtual_node=True,
residual=False,
drop_ratio=0,
JK="last",
graph_pooling="sum",
):
r"""GNN.
Args:
num_tasks (int): number of labels to be predicted
num_layers (int): number of gnn layers.
emb_dim (int): embedding dim to use.
gnn_type (str): Type of GNN to use.
virtual_node (bool): whether to add virtual node or not.
residual (bool): Wether or not to add the residual
drop_ratio (float): dropout ratio.
JK (str): "last" or "sum" to choose JK concat strat.
graph_pooling (str): Graph pooling strat to use.
"""
super().__init__()
if num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
self.graph_pooling = graph_pooling
if virtual_node:
self.gnn_node = GNNNodeVirtualNode(
num_layers,
emb_dim,
JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
else:
self.gnn_node = GNNNode(
num_layers,
emb_dim,
JK=JK,
drop_ratio=drop_ratio,
residual=residual,
gnn_type=gnn_type,
)
# Pooling function to generate whole-graph embeddings
if self.graph_pooling == "sum":
self.pool = global_add_pool
elif self.graph_pooling == "mean":
self.pool = global_mean_pool
elif self.graph_pooling == "max":
self.pool = global_max_pool
elif self.graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.BatchNorm1d(emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(emb_dim, 1),
))
elif self.graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = torch.nn.Linear(2 * emb_dim, num_tasks)
else:
self.graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks)
def forward(self, batched_data):
h_node = self.gnn_node(batched_data)
h_graph = self.pool(h_node, batched_data.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output
else:
# At inference time, we clamp the value between 0 and 20
return torch.clamp(output, min=0, max=20)
def train(model, rank, device, loader, optimizer):
model.train()
reg_criterion = torch.nn.L1Loss()
loss_accum = 0.0
for step, batch in enumerate(
tqdm(loader, desc="Training", disable=(rank > 0))):
batch = batch.to(device)
pred = model(batch).view(-1, )
optimizer.zero_grad()
loss = reg_criterion(pred, batch.y)
loss.backward()
optimizer.step()
loss_accum += loss.detach().cpu().item()
return loss_accum / (step + 1)
def eval(model, device, loader, evaluator):
model.eval()
y_true = []
y_pred = []
for batch in tqdm(loader, desc="Evaluating"):
batch = batch.to(device)
with torch.no_grad():
pred = model(batch).view(-1, )
y_true.append(batch.y.view(pred.shape).detach().cpu())
y_pred.append(pred.detach().cpu())
y_true = torch.cat(y_true, dim=0)
y_pred = torch.cat(y_pred, dim=0)
input_dict = {"y_true": y_true, "y_pred": y_pred}
return evaluator.eval(input_dict)["mae"]
def test(model, device, loader):
model.eval()
y_pred = []
for batch in tqdm(loader, desc="Testing"):
batch = batch.to(device)
with torch.no_grad():
pred = model(batch).view(-1, )
y_pred.append(pred.detach().cpu())
y_pred = torch.cat(y_pred, dim=0)
return y_pred
def run(rank, dataset, args):
num_devices = args.num_devices
device = torch.device(
"cuda:" + str(rank)) if num_devices > 0 else torch.device("cpu")
if num_devices > 1:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=num_devices)
if args.on_disk_dataset:
train_idx = torch.arange(len(dataset.indices()))
else:
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
if num_devices > 1:
train_idx = train_idx.split(train_idx.size(0) // num_devices)[rank]
if args.train_subset:
subset_ratio = 0.1
n = len(train_idx)
subset_idx = torch.randperm(n)[:int(subset_ratio * n)]
train_dataset = dataset[train_idx[subset_idx]]
else:
train_dataset = dataset[train_idx]
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
)
if rank == 0:
if args.on_disk_dataset:
valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val",
from_smiles_func=ogb_from_smiles_wrapper)
test_dev_dataset = PCQM4Mv2(
root='on_disk_dataset/', split="test",
from_smiles_func=ogb_from_smiles_wrapper)
test_challenge_dataset = PCQM4Mv2(
root='on_disk_dataset/', split="holdout",
from_smiles_func=ogb_from_smiles_wrapper)
else:
valid_dataset = dataset[split_idx["valid"]]
test_dev_dataset = dataset[split_idx["test-dev"]]
test_challenge_dataset = dataset[split_idx["test-challenge"]]
valid_loader = DataLoader(
valid_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
)
if args.save_test_dir != '':
testdev_loader = DataLoader(
test_dev_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
)
testchallenge_loader = DataLoader(
test_challenge_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
)
if args.checkpoint_dir != '':
os.makedirs(args.checkpoint_dir, exist_ok=True)
evaluator = PCQM4Mv2Evaluator()
gnn_type, virtual_node = args.gnn.split('-')
model = GNN(
gnn_type=gnn_type,
virtual_node=virtual_node,
num_layers=args.num_layers,
emb_dim=args.emb_dim,
drop_ratio=args.drop_ratio,
graph_pooling=args.graph_pooling,
)
if num_devices > 0:
model = model.to(rank)
if num_devices > 1:
model = DistributedDataParallel(model, device_ids=[rank])
optimizer = optim.Adam(model.parameters(), lr=0.001)
if args.log_dir != '':
writer = SummaryWriter(log_dir=args.log_dir)
best_valid_mae = 1000
if args.train_subset:
scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
args.epochs = 1000
else:
scheduler = StepLR(optimizer, step_size=30, gamma=0.25)
current_epoch = 1
checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
if os.path.isfile(checkpoint_path):
checkpoint = fs.torch_load(checkpoint_path)
current_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
best_valid_mae = checkpoint['best_val_mae']
print(f"Found checkpoint, resume training at epoch {current_epoch}")
for epoch in range(current_epoch, args.epochs + 1):
train_mae = train(model, rank, device, train_loader, optimizer)
if num_devices > 1:
dist.barrier()
if rank == 0:
valid_mae = eval(
model.module if isinstance(model, DistributedDataParallel) else
model, device, valid_loader, evaluator)
print(f"Epoch {epoch:02d}, "
f"Train MAE: {train_mae:.4f}, "
f"Val MAE: {valid_mae:.4f}")
if args.log_dir != '':
writer.add_scalar('valid/mae', valid_mae, epoch)
writer.add_scalar('train/mae', train_mae, epoch)
if valid_mae < best_valid_mae:
best_valid_mae = valid_mae
if args.checkpoint_dir != '':
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_val_mae': best_valid_mae,
}
torch.save(checkpoint, checkpoint_path)
if args.save_test_dir != '':
test_model = model.module if isinstance(
model, DistributedDataParallel) else model
testdev_pred = test(test_model, device, testdev_loader)
evaluator.save_test_submission(
{'y_pred': testdev_pred.cpu().detach().numpy()},
args.save_test_dir,
mode='test-dev',
)
testchallenge_pred = test(test_model, device,
testchallenge_loader)
evaluator.save_test_submission(
{'y_pred': testchallenge_pred.cpu().detach().numpy()},
args.save_test_dir,
mode='test-challenge',
)
print(f'Best validation MAE so far: {best_valid_mae}')
if num_devices > 1:
dist.barrier()
scheduler.step()
if rank == 0 and args.log_dir != '':
writer.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='GNN baselines on pcqm4m with Pytorch Geometrics',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--gnn', type=str, default='gin-virtual',
choices=['gin', 'gin-virtual', 'gcn',
'gcn-virtual'], help='GNN architecture')
parser.add_argument('--graph_pooling', type=str, default='sum',
help='graph pooling strategy mean or sum')
parser.add_argument('--drop_ratio', type=float, default=0,
help='dropout ratio')
parser.add_argument('--num_layers', type=int, default=5,
help='number of GNN message passing layers')
parser.add_argument('--emb_dim', type=int, default=600,
help='dimensionality of hidden units in GNNs')
parser.add_argument('--train_subset', action='store_true')
parser.add_argument('--batch_size', type=int, default=256,
help='input batch size for training')
parser.add_argument('--epochs', type=int, default=100,
help='number of epochs to train')
parser.add_argument('--num_workers', type=int, default=0,
help='number of workers')
parser.add_argument('--log_dir', type=str, default="",
help='tensorboard log directory')
parser.add_argument('--checkpoint_dir', type=str, default='',
help='directory to save checkpoint')
parser.add_argument('--save_test_dir', type=str, default='',
help='directory to save test submission file')
parser.add_argument('--num_devices', type=int, default='1',
help="Number of GPUs, if 0 runs on the CPU")
parser.add_argument('--on_disk_dataset', action='store_true')
args = parser.parse_args()
available_gpus = torch.cuda.device_count() if torch.cuda.is_available(
) else 0
if args.num_devices > available_gpus:
if available_gpus == 0:
print("No GPUs available, running w/ CPU...")
else:
raise ValueError(f"Cannot train with {args.num_devices} GPUs: "
f"available GPUs count {available_gpus}")
# automatic dataloading and splitting
if args.on_disk_dataset:
dataset = PCQM4Mv2(root='on_disk_dataset/', split='train',
from_smiles_func=ogb_from_smiles_wrapper)
else:
dataset = PygPCQM4Mv2Dataset(root='dataset/')
if args.num_devices > 1:
mp.spawn(run, args=(dataset, args), nprocs=args.num_devices, join=True)
else:
run(0, dataset, args)
|