1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476
|
import math
from contextlib import contextmanager
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.lightning import (
LightningDataset,
LightningLinkData,
LightningNodeData,
)
from torch_geometric.nn import global_mean_pool
from torch_geometric.sampler import BaseSampler, NeighborSampler
from torch_geometric.testing import (
MyFeatureStore,
MyGraphStore,
get_random_edge_index,
onlyCUDA,
onlyFullTest,
onlyNeighborSampler,
onlyOnline,
withPackage,
)
try:
from pytorch_lightning import LightningModule
except ImportError:
LightningModule = torch.nn.Module
class LinearGraphModule(LightningModule):
    """Two-layer linear graph classifier used by the `LightningDataset` test.

    Node features pass through a linear layer and a ReLU, are pooled with
    `global_mean_pool`, and are projected to `out_channels` logits. One
    multiclass `Accuracy` metric is kept per stage (train/val/test).

    Args:
        in_channels: Input size of the first linear layer.
        hidden_channels: Output size of the first linear layer.
        out_channels: Output size of the second linear layer, also used as
            the number of classes of the accuracy metrics.
    """
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int):
        super().__init__()
        # NOTE(review): imported locally, presumably so that this file can be
        # imported without `torchmetrics` installed -- confirm.
        from torchmetrics import Accuracy

        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)

        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)
        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        """Return the output of `lin2` after pooling `x` with `batch`.

        Args:
            x: Node features fed to the first linear layer.
            batch: The `data.batch` attribute of the mini-batch, forwarded
                unchanged to `global_mean_pool`.
        """
        # Basic test to ensure that the dataset is not replicated:
        # the in-place increment is later observed by the test through the
        # original dataset object.
        self.trainer.datamodule.train_dataset._data.x.add_(1)

        x = self.lin1(x).relu()
        x = global_mean_pool(x, batch)
        x = self.lin2(x)
        return x

    def training_step(self, data: Data, batch_idx: int):
        """Compute and log the cross-entropy loss and training accuracy."""
        y_hat = self(data.x, data.batch)
        loss = F.cross_entropy(y_hat, data.y)
        self.train_acc(y_hat.softmax(dim=-1), data.y)
        self.log('loss', loss, batch_size=data.num_graphs)
        self.log('train_acc', self.train_acc, batch_size=data.num_graphs)
        return loss

    def validation_step(self, data: Data, batch_idx: int):
        """Update and log the validation accuracy for one mini-batch."""
        y_hat = self(data.x, data.batch)
        self.val_acc(y_hat.softmax(dim=-1), data.y)
        self.log('val_acc', self.val_acc, batch_size=data.num_graphs)

    def test_step(self, data: Data, batch_idx: int):
        """Update and log the test accuracy for one mini-batch."""
        y_hat = self(data.x, data.batch)
        self.test_acc(y_hat.softmax(dim=-1), data.y)
        self.log('test_acc', self.test_acc, batch_size=data.num_graphs)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (`lr=0.01`)."""
        return torch.optim.Adam(self.parameters(), lr=0.01)
@onlyCUDA
@onlyOnline
@onlyFullTest
@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0')
@pytest.mark.parametrize('strategy_type', [None, 'ddp'])
def test_lightning_dataset(get_dataset, strategy_type):
    """Fit and test `LinearGraphModule` on MUTAG via `LightningDataset`.

    Runs on a single CUDA device (`strategy_type=None`) or under DDP. Checks
    the `shuffle=True` warning, the datamodule repr, that training mutated
    the original dataset tensor in place, and that the trainer sees
    validation/test data. In the single-device case it then fits again with
    a training dataset only and expects no validation/test data source.
    """
    import pytorch_lightning as pl
    from pytorch_lightning.utilities import rank_zero_only

    @contextmanager
    def expect_rank_zero_user_warning(match: str):
        # Non-zero ranks run the body without any warning expectation.
        if rank_zero_only.rank != 0:
            yield
            return
        with pytest.warns(UserWarning, match=match):
            yield

    def make_trainer():
        # Both phases of the test use an identically configured trainer.
        return pl.Trainer(strategy=strategy, devices=devices, max_epochs=1,
                          log_every_n_steps=1)

    dataset = get_dataset(name='MUTAG').shuffle()
    train_dataset, val_dataset = dataset[:50], dataset[50:80]
    test_dataset, pred_dataset = dataset[80:90], dataset[90:]

    devices = 1 if strategy_type is None else torch.cuda.device_count()
    strategy = (pl.strategies.DDPStrategy(accelerator='gpu')
                if strategy_type == 'ddp' else
                pl.strategies.SingleDeviceStrategy(device='cuda:0'))

    model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)

    trainer = make_trainer()
    with pytest.warns(UserWarning, match="'shuffle=True' option is ignored"):
        datamodule = LightningDataset(
            train_dataset,
            val_dataset,
            test_dataset,
            pred_dataset,
            batch_size=5,
            num_workers=3,
            shuffle=True,
        )
        assert 'shuffle' not in datamodule.kwargs

    x_before = train_dataset._data.x.clone()
    expected_repr = ('LightningDataset(train_dataset=MUTAG(50), '
                     'val_dataset=MUTAG(30), '
                     'test_dataset=MUTAG(10), '
                     'pred_dataset=MUTAG(98), batch_size=5, '
                     'num_workers=3, pin_memory=True, '
                     'persistent_workers=True)')
    assert str(datamodule) == expected_repr

    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)

    x_after = train_dataset._data.x
    assert torch.all(x_after > x_before)  # Ensure shared data.
    assert trainer.validate_loop._data_source.is_defined()
    assert trainer.test_loop._data_source.is_defined()

    # The `val_dataset=None` / `test_dataset=None` scenario is only run
    # without a distributed strategy:
    if strategy_type is not None:
        return

    trainer = make_trainer()
    datamodule = LightningDataset(train_dataset, batch_size=5)
    expected_repr = ('LightningDataset(train_dataset=MUTAG(50), '
                     'batch_size=5, num_workers=0, '
                     'pin_memory=True, '
                     'persistent_workers=False)')
    assert str(datamodule) == expected_repr

    with expect_rank_zero_user_warning("defined a `validation_step`"):
        trainer.fit(model, datamodule)

    assert not trainer.validate_loop._data_source.is_defined()
    assert not trainer.test_loop._data_source.is_defined()
class LinearNodeModule(LightningModule):
    """Single linear layer node classifier used by the node-data test.

    Each step applies the layer to `data.x` and restricts predictions and
    labels to the mask of the corresponding stage before updating that
    stage's multiclass `Accuracy` metric.
    """
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        from torchmetrics import Accuracy

        self.lin = torch.nn.Linear(in_channels, out_channels)

        metric_kwargs = dict(task='multiclass', num_classes=out_channels)
        self.train_acc = Accuracy(**metric_kwargs)
        self.val_acc = Accuracy(**metric_kwargs)
        self.test_acc = Accuracy(**metric_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Increment the datamodule's `data.x` in place and apply `lin`."""
        # The in-place increment lets the test detect whether the datamodule
        # holds the original data object rather than a replica:
        shared_x = self.trainer.datamodule.data.x
        shared_x.add_(1)

        return self.lin(x)

    def training_step(self, data: Data, batch_idx: int):
        """Return the cross-entropy loss over the training nodes."""
        logits = self(data.x)[data.train_mask]
        target = data.y[data.train_mask]
        num_targets = target.size(0)

        loss = F.cross_entropy(logits, target)
        self.train_acc(logits.softmax(dim=-1), target)
        self.log('loss', loss, batch_size=num_targets)
        self.log('train_acc', self.train_acc, batch_size=num_targets)
        return loss

    def validation_step(self, data: Data, batch_idx: int):
        """Update and log the accuracy over the validation nodes."""
        logits = self(data.x)[data.val_mask]
        target = data.y[data.val_mask]

        self.val_acc(logits.softmax(dim=-1), target)
        self.log('val_acc', self.val_acc, batch_size=target.size(0))

    def test_step(self, data: Data, batch_idx: int):
        """Update and log the accuracy over the test nodes."""
        logits = self(data.x)[data.test_mask]
        target = data.y[data.test_mask]

        self.test_acc(logits.softmax(dim=-1), target)
        self.log('test_acc', self.test_acc, batch_size=target.size(0))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (`lr=0.01`)."""
        return torch.optim.Adam(self.parameters(), lr=0.01)
@onlyCUDA
@onlyOnline
@onlyFullTest
@onlyNeighborSampler
@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0', 'scipy')
@pytest.mark.parametrize('loader', ['full', 'neighbor'])
@pytest.mark.parametrize('strategy_type', [None, 'ddp'])
def test_lightning_node_data(get_dataset, strategy_type, loader):
    """Fit and test `LinearNodeModule` on Cora via `LightningNodeData`.

    Covers full-batch and neighbor loading, each on a single device and with
    DDP. Checks the datamodule repr, that `data.x` was modified in place by
    the model, and that the trainer sees validation/test data.
    """
    import pytorch_lightning as pl

    dataset = get_dataset(name='Cora')
    data = dataset[0]
    data_repr = ('Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], '
                 'train_mask=[2708], val_mask=[2708], test_mask=[2708])')

    model = LinearNodeModule(dataset.num_features, dataset.num_classes)

    # Full-batch loading always uses one device, whatever the strategy:
    single_device = strategy_type is None or loader == 'full'
    devices = 1 if single_device else torch.cuda.device_count()

    strategy = (pl.strategies.DDPStrategy(accelerator='gpu')
                if strategy_type == 'ddp' else
                pl.strategies.SingleDeviceStrategy(device='cuda:0'))

    # Set reasonable defaults for full-batch training:
    batch_size, num_workers = (1, 0) if loader == 'full' else (32, 3)

    if loader == 'neighbor':
        kwargs, kwargs_repr = {'num_neighbors': [5]}, 'num_neighbors=[5], '
    else:
        kwargs, kwargs_repr = {}, ''

    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=5,
                         log_every_n_steps=1)
    datamodule = LightningNodeData(
        data,
        loader=loader,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )

    x_before = data.x.clone().cpu()
    expected_repr = (f'LightningNodeData(data={data_repr}, '
                     f'loader={loader}, batch_size={batch_size}, '
                     f'num_workers={num_workers}, {kwargs_repr}'
                     f'pin_memory={loader != "full"}, '
                     f'persistent_workers={loader != "full"})')
    assert str(datamodule) == expected_repr

    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)

    x_after = data.x.cpu()
    assert torch.all(x_after > x_before)  # Ensure shared data.
    assert trainer.validate_loop._data_source.is_defined()
    assert trainer.test_loop._data_source.is_defined()
class LinearHeteroNodeModule(LightningModule):
    """Single linear layer classifier over the `'paper'` nodes.

    Each step applies the layer to `data['paper'].x` and restricts
    predictions and labels to the `'paper'` mask of the corresponding stage
    before updating that stage's multiclass `Accuracy` metric.
    """
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        from torchmetrics import Accuracy

        self.lin = torch.nn.Linear(in_channels, out_channels)

        metric_kwargs = dict(task='multiclass', num_classes=out_channels)
        self.train_acc = Accuracy(**metric_kwargs)
        self.val_acc = Accuracy(**metric_kwargs)
        self.test_acc = Accuracy(**metric_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Increment the datamodule's `'paper'` features and apply `lin`."""
        # The in-place increment lets the test detect whether the datamodule
        # holds the original data object rather than a replica:
        shared_x = self.trainer.datamodule.data['paper'].x
        shared_x.add_(1)

        return self.lin(x)

    def training_step(self, data: HeteroData, batch_idx: int):
        """Return the cross-entropy loss over the training `'paper'` nodes."""
        logits = self(data['paper'].x)[data['paper'].train_mask]
        target = data['paper'].y[data['paper'].train_mask]
        num_targets = target.size(0)

        loss = F.cross_entropy(logits, target)
        self.train_acc(logits.softmax(dim=-1), target)
        self.log('loss', loss, batch_size=num_targets)
        self.log('train_acc', self.train_acc, batch_size=num_targets)
        return loss

    def validation_step(self, data: HeteroData, batch_idx: int):
        """Update and log the accuracy over the validation `'paper'` nodes."""
        logits = self(data['paper'].x)[data['paper'].val_mask]
        target = data['paper'].y[data['paper'].val_mask]

        self.val_acc(logits.softmax(dim=-1), target)
        self.log('val_acc', self.val_acc, batch_size=target.size(0))

    def test_step(self, data: HeteroData, batch_idx: int):
        """Update and log the accuracy over the test `'paper'` nodes."""
        logits = self(data['paper'].x)[data['paper'].test_mask]
        target = data['paper'].y[data['paper'].test_mask]

        self.test_acc(logits.softmax(dim=-1), target)
        self.log('test_acc', self.test_acc, batch_size=target.size(0))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (`lr=0.01`)."""
        return torch.optim.Adam(self.parameters(), lr=0.01)
@pytest.fixture
def preserve_context():
    """Restore global torch state that a test using this fixture may change.

    Records the value of `torch.get_num_threads()` before the test runs.
    After the test, destroys the default process group if one has been
    initialized and resets the thread count to the recorded value.
    """
    num_threads = torch.get_num_threads()
    yield
    try:
        # `is_initialized` is only exposed when the distributed package is
        # available in this build of torch, so check availability first.
        if (torch.distributed.is_available()
                and torch.distributed.is_initialized()):
            torch.distributed.destroy_process_group()
    finally:
        # Always restore the thread count, even if the teardown above raises.
        torch.set_num_threads(num_threads)
@onlyCUDA
@onlyFullTest
@onlyNeighborSampler
@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0')
def test_lightning_hetero_node_data(preserve_context, get_dataset):
    """Fit and test `LinearHeteroNodeModule` via DDP and neighbor loading.

    Checks that the datamodule builds a `NeighborSampler`, that the
    `'paper'` features were modified in place by the model, and that the
    trainer sees validation/test data.
    """
    import pytorch_lightning as pl

    data = get_dataset(name='hetero')[0]

    in_channels = data['paper'].num_features
    num_classes = int(data['paper'].y.max()) + 1
    model = LinearHeteroNodeModule(in_channels, num_classes)

    trainer = pl.Trainer(
        strategy=pl.strategies.DDPStrategy(accelerator='gpu'),
        devices=torch.cuda.device_count(),
        max_epochs=5,
        log_every_n_steps=1,
    )
    datamodule = LightningNodeData(
        data,
        loader='neighbor',
        num_neighbors=[5],
        batch_size=32,
        num_workers=3,
    )
    assert isinstance(datamodule.graph_sampler, NeighborSampler)

    x_before = data['paper'].x.clone()

    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)

    x_after = data['paper'].x
    assert torch.all(x_after > x_before)  # Ensure shared data.
    assert trainer.validate_loop._data_source.is_defined()
    assert trainer.test_loop._data_source.is_defined()
@withPackage('pytorch_lightning')
def test_lightning_data_custom_sampler():
    """Check that user-supplied samplers are used as `graph_sampler`.

    A do-nothing `BaseSampler` subclass is passed as `node_sampler` to
    `LightningNodeData` and as `link_sampler` to `LightningLinkData`.
    """
    class NoOpSampler(BaseSampler):
        def sample_from_edges(self, *args, **kwargs):
            return None

        def sample_from_nodes(self, *args, **kwargs):
            return None

    edge_index = torch.tensor([[0, 1], [1, 0]])
    data = Data(num_nodes=2, edge_index=edge_index)

    node_datamodule = LightningNodeData(
        data,
        node_sampler=NoOpSampler(),
        input_train_nodes=torch.arange(2),
    )
    assert isinstance(node_datamodule.graph_sampler, NoOpSampler)

    link_datamodule = LightningLinkData(
        data,
        link_sampler=NoOpSampler(),
        input_train_edges=torch.tensor([[0, 1], [0, 1]]),
    )
    assert isinstance(link_datamodule.graph_sampler, NoOpSampler)
@onlyCUDA
@onlyFullTest
@onlyNeighborSampler
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data():
    """Check the loaders of `LightningLinkData` on a random `HeteroData`.

    Each stage (train/val/test/predict) is given a different edge type, and
    every batch of the matching loader must carry `edge_label_index` for that
    edge type. A second datamodule adds node times and `input_train_time`
    and must additionally expose `edge_label_time` in its training batches.
    """
    torch.manual_seed(12345)

    data = HeteroData()
    for node_type in ('paper', 'author', 'term'):
        data[node_type].x = torch.arange(10)

    edge_types = [
        ('paper', 'author'),
        ('author', 'paper'),
        ('paper', 'term'),
        ('author', 'term'),
    ]
    for edge_type in edge_types:
        data[edge_type].edge_index = get_random_edge_index(10, 10, 10)

    datamodule = LightningLinkData(
        data,
        input_train_edges=('author', 'paper'),
        input_val_edges=('paper', 'author'),
        input_test_edges=('paper', 'term'),
        input_pred_edges=('author', 'term'),
        loader='neighbor',
        num_neighbors=[5],
        batch_size=32,
        num_workers=0,
    )

    assert isinstance(datamodule.graph_sampler, NeighborSampler)
    assert isinstance(datamodule.eval_graph_sampler, NeighborSampler)

    # Pair each loader factory with the edge type it was configured for:
    stage_checks = [
        (datamodule.train_dataloader, ('author', 'paper')),
        (datamodule.val_dataloader, ('paper', 'author')),
        (datamodule.test_dataloader, ('paper', 'term')),
        (datamodule.predict_dataloader, ('author', 'term')),
    ]
    for make_loader, edge_type in stage_checks:
        for batch in make_loader():
            assert 'edge_label_index' in batch[edge_type]

    # Attach a time attribute to every node type for the temporal variant:
    for node_type in ('author', 'paper', 'term'):
        data[node_type].time = torch.arange(data[node_type].num_nodes)

    train_time = torch.arange(data['author', 'paper'].num_edges)
    datamodule = LightningLinkData(
        data,
        input_train_edges=('author', 'paper'),
        input_train_time=train_time,
        loader='neighbor',
        num_neighbors=[5],
        batch_size=32,
        num_workers=0,
        time_attr='time',
    )

    for batch in datamodule.train_dataloader():
        edge_store = batch['author', 'paper']
        assert 'edge_label_index' in edge_store
        assert 'edge_label_time' in edge_store
@onlyNeighborSampler
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data_custom_store():
    """Check `LightningLinkData` on a `(feature_store, graph_store)` pair.

    Fills the stores with three node types and three edge types that share
    one random edge index, then verifies that the first training batch
    carries `edge_label_index` for the `('author', 'paper')` edges.
    """
    torch.manual_seed(12345)

    feature_store = MyFeatureStore()
    graph_store = MyGraphStore()

    x = torch.arange(10)
    for group_name in ('paper', 'author', 'term'):
        feature_store.put_tensor(x, group_name=group_name, attr_name='x',
                                 index=None)

    # Every edge type is registered with the same random connectivity:
    edge_index = get_random_edge_index(10, 10, 10)
    edge_types = [
        ('paper', 'to', 'author'),
        ('author', 'to', 'paper'),
        ('paper', 'to', 'term'),
    ]
    for edge_type in edge_types:
        graph_store.put_edge_index(
            edge_index=(edge_index[0], edge_index[1]),
            edge_type=edge_type,
            layout='coo',
            size=(10, 10),
        )

    datamodule = LightningLinkData(
        (feature_store, graph_store),
        input_train_edges=('author', 'to', 'paper'),
        loader='neighbor',
        num_neighbors=[5],
        batch_size=32,
        num_workers=0,
    )

    train_iter = iter(datamodule.train_dataloader())
    batch = next(train_iter)
    assert 'edge_label_index' in batch['author', 'paper']
@onlyOnline
@onlyNeighborSampler
@withPackage('pytorch_lightning', 'scipy')
def test_eval_loader_kwargs(get_dataset):
    """Check that `eval_loader_kwargs` overrides the evaluation loaders only.

    The training loader must keep `batch_size=32` and `num_neighbors=[5]`,
    while the validation, test and predict loaders must use `batch_size=64`
    and `num_neighbors=[-1]`. Loader lengths are compared against the number
    of input nodes divided by the batch size, rounded up.
    """
    def expected_len(num_inputs, batch_size):
        return math.ceil(num_inputs / batch_size)

    data = get_dataset(name='Cora')[0]

    datamodule = LightningNodeData(
        data,
        node_sampler=NeighborSampler(data, num_neighbors=[5]),
        batch_size=32,
        eval_loader_kwargs=dict(num_neighbors=[-1], batch_size=64),
    )

    assert datamodule.loader_kwargs['batch_size'] == 32
    assert datamodule.graph_sampler.num_neighbors.values == [5]
    assert datamodule.eval_loader_kwargs['batch_size'] == 64
    assert datamodule.eval_graph_sampler.num_neighbors.values == [-1]

    train_loader = datamodule.train_dataloader()
    assert expected_len(int(data.train_mask.sum()), 32) == len(train_loader)

    val_loader = datamodule.val_dataloader()
    assert expected_len(int(data.val_mask.sum()), 64) == len(val_loader)

    test_loader = datamodule.test_dataloader()
    assert expected_len(int(data.test_mask.sum()), 64) == len(test_loader)

    pred_loader = datamodule.predict_dataloader()
    assert expected_len(data.num_nodes, 64) == len(pred_loader)
|