import pytest
import torch
from torch.optim.lr_scheduler import ConstantLR, LambdaLR, ReduceLROnPlateau

import torch_geometric
from torch_geometric.nn.resolver import (
    activation_resolver,
    aggregation_resolver,
    lr_scheduler_resolver,
    normalization_resolver,
    optimizer_resolver,
)


def test_activation_resolver():
    # Passing a module instance returns an instance of the same class:
    assert isinstance(activation_resolver(torch.nn.ELU()), torch.nn.ELU)
    assert isinstance(activation_resolver(torch.nn.ReLU()), torch.nn.ReLU)
    assert isinstance(activation_resolver(torch.nn.PReLU()), torch.nn.PReLU)

    # String shorthands resolve to the corresponding activation class:
    assert isinstance(activation_resolver('elu'), torch.nn.ELU)
    assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
    assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)
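

# A hedged sketch (not part of the original tests), assuming a string query
# forwards extra keyword arguments to the resolved activation's constructor:
def test_activation_resolver_with_kwargs():
    act = activation_resolver('leaky_relu', negative_slope=0.2)
    assert isinstance(act, torch.nn.LeakyReLU)
    assert act.negative_slope == 0.2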


@pytest.mark.parametrize('aggr_tuple', [
    (torch_geometric.nn.MeanAggregation, 'mean'),
    (torch_geometric.nn.SumAggregation, 'sum'),
    (torch_geometric.nn.SumAggregation, 'add'),
    (torch_geometric.nn.MaxAggregation, 'max'),
    (torch_geometric.nn.MinAggregation, 'min'),
    (torch_geometric.nn.MulAggregation, 'mul'),
    (torch_geometric.nn.VarAggregation, 'var'),
    (torch_geometric.nn.StdAggregation, 'std'),
    (torch_geometric.nn.SoftmaxAggregation, 'softmax'),
    (torch_geometric.nn.PowerMeanAggregation, 'powermean'),
])
def test_aggregation_resolver(aggr_tuple):
    aggr_module, aggr_repr = aggr_tuple
    # Module instances and string shorthands resolve to the same class:
    assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
    assert isinstance(aggregation_resolver(aggr_repr), aggr_module)


def test_multi_aggregation_resolver():
    # `None` resolves to `None`:
    aggr = aggregation_resolver(None)
    assert aggr is None

    # A list of queries resolves to a `MultiAggregation` holding the
    # resolved entries (including `None`):
    aggr = aggregation_resolver(['sum', 'mean', None])
    assert isinstance(aggr, torch_geometric.nn.MultiAggregation)
    assert len(aggr.aggrs) == 3
    assert isinstance(aggr.aggrs[0], torch_geometric.nn.SumAggregation)
    assert isinstance(aggr.aggrs[1], torch_geometric.nn.MeanAggregation)
    assert aggr.aggrs[2] is None
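

# A hedged sketch (not part of the original tests), assuming keyword
# arguments passed alongside a string query reach the constructor of the
# resolved aggregation:
def test_aggregation_resolver_with_kwargs():
    aggr = aggregation_resolver('softmax', learn=True)
    assert isinstance(aggr, torch_geometric.nn.SoftmaxAggregation)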


@pytest.mark.parametrize('norm_tuple', [
    (torch_geometric.nn.BatchNorm, 'batch', (16, )),
    (torch_geometric.nn.BatchNorm, 'batch_norm', (16, )),
    (torch_geometric.nn.InstanceNorm, 'instance_norm', (16, )),
    (torch_geometric.nn.LayerNorm, 'layer_norm', (16, )),
    (torch_geometric.nn.GraphNorm, 'graph_norm', (16, )),
    (torch_geometric.nn.GraphSizeNorm, 'graphsize_norm', ()),
    (torch_geometric.nn.PairNorm, 'pair_norm', ()),
    (torch_geometric.nn.MessageNorm, 'message_norm', ()),
    (torch_geometric.nn.DiffGroupNorm, 'diffgroup_norm', (16, 4)),
])
def test_normalization_resolver(norm_tuple):
    norm_module, norm_repr, norm_args = norm_tuple
    # Module instances and string shorthands resolve to the same class;
    # positional arguments are forwarded to the constructor:
    assert isinstance(normalization_resolver(norm_module(*norm_args)),
                      norm_module)
    assert isinstance(normalization_resolver(norm_repr, *norm_args),
                      norm_module)
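

# A hedged usage sketch (not part of the original tests): a resolved
# normalization layer is a regular module and can be applied to node
# features directly.
def test_normalization_resolver_forward():
    norm = normalization_resolver('layer_norm', 16)
    out = norm(torch.randn(4, 16))
    assert out.size() == (4, 16)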


def test_optimizer_resolver():
    params = [torch.nn.Parameter(torch.randn(1))]

    # Passing an optimizer instance yields an instance of the same class:
    assert isinstance(optimizer_resolver(torch.optim.SGD(params, lr=0.01)),
                      torch.optim.SGD)
    assert isinstance(optimizer_resolver(torch.optim.Adam(params)),
                      torch.optim.Adam)
    assert isinstance(optimizer_resolver(torch.optim.Rprop(params)),
                      torch.optim.Rprop)

    # String shorthands resolve to the corresponding optimizer class;
    # additional arguments are forwarded to its constructor:
    assert isinstance(optimizer_resolver('sgd', params, lr=0.01),
                      torch.optim.SGD)
    assert isinstance(optimizer_resolver('adam', params), torch.optim.Adam)
    assert isinstance(optimizer_resolver('rprop', params), torch.optim.Rprop)
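

# A hedged usage sketch (not part of the original tests): the resolved
# optimizer is a regular `torch.optim.Optimizer` and can drive an update step.
def test_optimizer_resolver_step():
    param = torch.nn.Parameter(torch.randn(1))
    optimizer = optimizer_resolver('sgd', [param], lr=0.1)
    optimizer.zero_grad()
    (param ** 2).sum().backward()
    optimizer.step()
    assert param.grad is not None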


@pytest.mark.parametrize('scheduler_args', [
    ('constant_with_warmup', LambdaLR),
    ('linear_with_warmup', LambdaLR),
    ('cosine_with_warmup', LambdaLR),
    ('cosine_with_warmup_restarts', LambdaLR),
    ('polynomial_with_warmup', LambdaLR),
    ('constant', ConstantLR),
    ('ReduceLROnPlateau', ReduceLROnPlateau),
])
def test_lr_scheduler_resolver(scheduler_args):
    scheduler_name, scheduler_cls = scheduler_args

    model = torch.nn.Linear(10, 5)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # String queries resolve to the corresponding scheduler class; warmup-based
    # schedulers additionally require the total number of training steps:
    lr_scheduler = lr_scheduler_resolver(
        scheduler_name,
        optimizer,
        num_training_steps=100,
    )
    assert isinstance(lr_scheduler, scheduler_cls)
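

# A hedged usage sketch (not part of the original tests): a resolved warmup
# scheduler is stepped after each optimizer step, like any other PyTorch
# scheduler (`ReduceLROnPlateau` instead expects a metric in `step()`).
def test_lr_scheduler_resolver_step():
    model = torch.nn.Linear(10, 5)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    lr_scheduler = lr_scheduler_resolver(
        'linear_with_warmup',
        optimizer,
        num_training_steps=100,
    )
    for _ in range(3):
        optimizer.zero_grad()
        model(torch.randn(4, 10)).sum().backward()
        optimizer.step()
        lr_scheduler.step()
    assert optimizer.param_groups[0]['lr'] > 0.0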