1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
|
import argparse
import os.path as osp
import time
import torch
from sklearn.metrics import f1_score
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
# Command-line interface: choose which of the two architectures to train.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    type=str,
    default='GeniePathLazy',
    # Validate through argparse so an unknown name yields a usage error.
    # A bare ``assert`` would be stripped when running under ``python -O``.
    choices=['GeniePath', 'GeniePathLazy'],
)
args = parser.parse_args()

# Dataset root: a ``data/PPI`` directory located next to this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI')
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')

# Training draws one graph per step in shuffled order; evaluation walks the
# validation / test graphs in fixed order, two per batch.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Hyperparameters read as module-level globals by the model classes.
dim = 256  # feature width produced by the linear / GAT layers
lstm_hidden = 256  # hidden size of the LSTM state
layer_num = 4  # number of stacked breadth/depth steps
class Breadth(torch.nn.Module):
    """Breadth step: a single-head ``GATConv`` followed by ``tanh``.

    Args:
        in_dim: Input feature size passed to ``GATConv``.
        out_dim: Output feature size passed to ``GATConv``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.gatconv = GATConv(in_dim, out_dim, heads=1)

    def forward(self, x, edge_index):
        """Return ``tanh`` of the graph-attention output for ``x``."""
        attended = self.gatconv(x, edge_index)
        return attended.tanh()
class Depth(torch.nn.Module):
    """Depth step: a single-layer, bias-free LSTM.

    Args:
        in_dim: Feature size of the LSTM input.
        hidden: Size of the LSTM hidden and cell states.
    """

    def __init__(self, in_dim, hidden):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden,
            num_layers=1,
            bias=False,
        )

    def forward(self, x, h, c):
        """Run the LSTM on ``x`` starting from state ``(h, c)``.

        Returns:
            ``(output, (h, c))`` with the updated hidden and cell states.
        """
        output, state = self.lstm(x, (h, c))
        next_h, next_c = state
        return output, (next_h, next_c)
class GeniePathLayer(torch.nn.Module):
    """One GeniePath layer: a breadth step whose output feeds a depth step.

    Args:
        in_dim: Feature size of the node features entering the breadth step.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.breadth_func = Breadth(in_dim, dim)
        self.depth_func = Depth(dim, lstm_hidden)

    def forward(self, x, edge_index, h, c):
        """Apply the breadth step, then advance the LSTM state by one step.

        Returns:
            ``(x, (h, c))``: the LSTM output with its leading axis removed,
            and the updated hidden / cell states.
        """
        neighbourhood = self.breadth_func(x, edge_index)
        # The LSTM expects a leading axis, so add one of length one here
        # and strip it again from the output.
        sequence = neighbourhood.unsqueeze(0)
        output, (h, c) = self.depth_func(sequence, h, c)
        return output[0], (h, c)
class GeniePath(torch.nn.Module):
    """GeniePath model: linear projection, stacked layers, linear read-out.

    The LSTM state ``(h, c)`` starts at zero and is threaded through every
    ``GeniePathLayer`` in order.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the per-node output (one logit per class).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, dim)
        # Only the first layer sees ``lin1``'s ``dim``-wide output. Every
        # later layer consumes the previous layer's LSTM output, which is
        # ``lstm_hidden`` wide. The sizes are unchanged while the two
        # constants are equal, but no longer break when they differ.
        self.gplayers = torch.nn.ModuleList([
            GeniePathLayer(dim if i == 0 else lstm_hidden)
            for i in range(layer_num)
        ])
        # The read-out is applied to LSTM output, hence ``lstm_hidden``.
        self.lin2 = torch.nn.Linear(lstm_hidden, out_dim)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` on graph ``edge_index``."""
        x = self.lin1(x)
        num_nodes = x.shape[0]
        # Zero initial LSTM state with one entry per node.
        h = torch.zeros(1, num_nodes, lstm_hidden, device=x.device)
        c = torch.zeros(1, num_nodes, lstm_hidden, device=x.device)
        for layer in self.gplayers:
            x, (h, c) = layer(x, edge_index, h, c)
        return self.lin2(x)
class GeniePathLazy(torch.nn.Module):
    """Lazy GeniePath variant: all breadth steps first, then all depth steps.

    Every breadth step is applied to the same projected input features.
    The depth LSTMs then consume those outputs one after another, each
    concatenated with the running LSTM output.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the per-node output (one logit per class).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, dim)
        self.breadths = torch.nn.ModuleList(
            [Breadth(dim, dim) for _ in range(layer_num)])
        # Each depth step receives a ``dim``-wide breadth output concatenated
        # with the running features. Those are ``lin1``'s ``dim``-wide output
        # for the first step and the previous LSTM's ``lstm_hidden``-wide
        # output afterwards. The sizes are unchanged while the two constants
        # are equal, but no longer break when they differ.
        self.depths = torch.nn.ModuleList([
            Depth(dim + (dim if i == 0 else lstm_hidden), lstm_hidden)
            for i in range(layer_num)
        ])
        # The read-out is applied to LSTM output, hence ``lstm_hidden``.
        self.lin2 = torch.nn.Linear(lstm_hidden, out_dim)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` on graph ``edge_index``."""
        x = self.lin1(x)
        num_nodes = x.shape[0]
        # Zero initial LSTM state with one entry per node.
        h = torch.zeros(1, num_nodes, lstm_hidden, device=x.device)
        c = torch.zeros(1, num_nodes, lstm_hidden, device=x.device)
        # All breadth outputs are computed up front from the same ``x``.
        h_tmps = [breadth(x, edge_index) for breadth in self.breadths]
        # Add the leading axis the LSTM expects.
        x = x[None, :]
        for h_tmp, depth in zip(h_tmps, self.depths):
            in_cat = torch.cat((h_tmp[None, :], x), -1)
            x, (h, c) = depth(in_cat, h, c)
        return self.lin2(x[0])
# Run on the GPU whenever CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Lookup table from the ``--model`` value to the class implementing it.
kwargs = {
    'GeniePath': GeniePath,
    'GeniePathLazy': GeniePathLazy,
}
model = kwargs[args.model](
    train_dataset.num_features,
    train_dataset.num_classes,
).to(device)

# The model emits raw logits, so the sigmoid is folded into the loss.
loss_op = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        The sum of each batch loss weighted by its number of graphs,
        divided by the number of graphs in the training dataset.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        # Read the graph count before the batch vector is cleared.
        graphs_in_batch = batch.num_graphs
        # NOTE(review): the batch assignment vector is dropped before the
        # transfer, presumably because the model never reads it. Confirm.
        batch.batch = None
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index)
        batch_loss = loss_op(logits, batch.y)
        running_loss += batch_loss.item() * graphs_in_batch
        batch_loss.backward()
        optimizer.step()
    return running_loss / len(train_loader.dataset)
def test(loader):
    """Evaluate ``model`` on every batch of ``loader``.

    Logits greater than zero count as positive predictions.

    Returns:
        The micro-averaged F1 score, or ``0`` when nothing was predicted
        positive.
    """
    model.eval()
    targets, predictions = [], []
    with torch.no_grad():
        for batch in loader:
            targets.append(batch.y)
            logits = model(batch.x.to(device), batch.edge_index.to(device))
            predictions.append((logits > 0).float().cpu())

    y = torch.cat(targets, dim=0).numpy()
    pred = torch.cat(predictions, dim=0).numpy()
    if pred.sum() > 0:
        return f1_score(y, pred, average='micro')
    return 0
# Train for 100 epochs, reporting loss and validation / test micro-F1 after
# each one, then print the median epoch duration.
times = []
for epoch in range(1, 101):
    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be skewed by system clock adjustments the way
    # time.time() differences can.
    start = time.perf_counter()
    loss = train()
    val_f1 = test(val_loader)
    test_f1 = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '
          f'Test: {test_f1:.4f}')
    # The timed span covers training, both evaluations and the print.
    times.append(time.perf_counter() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
|