1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
|
import argparse
from itertools import product
from asap import ASAP
from datasets import get_dataset
from diff_pool import DiffPool
from edge_pool import EdgePool
from gcn import GCN, GCNWithJK
from gin import GIN, GIN0, GIN0WithJK, GINWithJK
from global_attention import GlobalAttentionNet
from graclus import Graclus
from graph_sage import GraphSAGE, GraphSAGEWithJK
from sag_pool import SAGPool
from set2set import Set2SetNet
from sort_pool import SortPool
from top_k import TopK
from train_eval import cross_validation_with_val_set
def _build_parser():
    """Return the command-line parser for the benchmark sweep.

    Every option is a scalar training setting shared by all runs of the
    sweep; the declaration order below is the order shown by ``--help``.
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default)
    for flag, kind, default in (
        ('--epochs', int, 100),
        ('--batch_size', int, 128),
        ('--lr', float, 0.01),
        ('--lr_decay_factor', float, 0.5),
        ('--lr_decay_step_size', int, 50),
    ):
        cli.add_argument(flag, type=kind, default=default)
    return cli


parser = _build_parser()
args = parser.parse_args()
# Hyper-parameter grid tried for every (dataset, model) pair:
# number of layers x hidden size.
layers = list(range(1, 6))  # 1, 2, 3, 4, 5
hiddens = [16 * 2**k for k in range(4)]  # 16, 32, 64, 128

# Datasets in the sweep. Append 'COLLAB' to include it as well.
datasets = ['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY']

# Model classes, evaluated in this order for each dataset.
nets = [
    GCNWithJK, GraphSAGEWithJK, GIN0WithJK, GINWithJK,
    Graclus, TopK, SAGPool, DiffPool, EdgePool,
    GCN, GraphSAGE, GIN0, GIN,
    GlobalAttentionNet, Set2SetNet, SortPool, ASAP,
]
def logger(info):
    """Print a one-line progress summary from a training-info mapping.

    Args:
        info: mapping with keys ``'fold'``, ``'epoch'``, ``'val_loss'`` and
            ``'test_acc'``. The fold is printed as ``info['fold'] + 1``, so a
            0-based fold index is displayed 1-based.
    """
    fold = info['fold'] + 1
    epoch = info['epoch']
    print(
        f'{fold:02d}/{epoch:03d}: '
        f'Val Loss: {info["val_loss"]:.4f}, '
        f'Test Accuracy: {info["test_acc"]:.3f}'
    )
# Grid search: for every (dataset, model) pair, run 10-fold
# cross_validation_with_val_set for each (num_layers, hidden) configuration
# and keep the configuration with the lowest returned loss.
results = []
for dataset_name, Net in product(datasets, nets):
    best_result = (float('inf'), 0, 0)  # (loss, acc, std)
    print(f'--\n{dataset_name} - {Net.__name__}')

    # The dataset depends only on (dataset_name, Net), so load it once per
    # pair rather than once per hyper-parameter configuration. Only DiffPool
    # requests it with sparse=False.
    dataset = get_dataset(dataset_name, sparse=Net != DiffPool)

    for num_layers, hidden in product(layers, hiddens):
        model = Net(dataset, num_layers, hidden)
        loss, acc, std = cross_validation_with_val_set(
            dataset,
            model,
            folds=10,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            lr_decay_factor=args.lr_decay_factor,
            lr_decay_step_size=args.lr_decay_step_size,
            weight_decay=0,
            logger=None,  # the `logger` helper in this file can go here
        )
        if loss < best_result[0]:
            best_result = (loss, acc, std)

    desc = f'{best_result[1]:.3f} ± {best_result[2]:.3f}'
    print(f'Best result - {desc}')
    # Report the model class name (as in the header above), not `model`:
    # that is merely the last configuration trained, and its str() depends
    # on whatever __repr__ the model class provides.
    results.append(f'{dataset_name} - {Net.__name__}: {desc}')

# Keep `results` a list; build the printable summary separately.
summary = '\n'.join(results)
print(f'--\n{summary}')
|