# Point-cloud part segmentation on ShapeNet using dynamic edge convolutions.
import os.path as osp
import torch
import torch.nn.functional as F
from torchmetrics.functional import jaccard_index
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv
from torch_geometric.utils import scatter
# Category to train on.
category = 'Airplane'  # Pass in `None` to train on all categories.

# Dataset root: `../data/ShapeNet`, resolved relative to this script's real
# location on disk.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')

# Augmentation passed as `transform` to the training split only: a jitter
# followed by one `RandomRotate(15, ...)` per axis, in axis order 0, 1, 2.
transform = T.Compose(
    [T.RandomJitter(0.01)]
    + [T.RandomRotate(15, axis=axis) for axis in range(3)]
)

# Passed as `pre_transform` to both splits.
pre_transform = T.NormalizeScale()

train_dataset = ShapeNet(
    path,
    category,
    split='trainval',
    transform=transform,
    pre_transform=pre_transform,
)
# The test split receives no `transform`, so no augmentation is applied to it.
test_dataset = ShapeNet(
    path,
    category,
    split='test',
    pre_transform=pre_transform,
)

# Both loaders use batches of 10 and 6 worker processes; only the training
# loader shuffles.
train_loader = DataLoader(
    train_dataset, batch_size=10, shuffle=True, num_workers=6,
)
test_loader = DataLoader(
    test_dataset, batch_size=10, shuffle=False, num_workers=6,
)
class Net(torch.nn.Module):
    """Per-node classifier built from three stacked ``DynamicEdgeConv`` layers.

    The outputs of all three convolutions are concatenated and fed to a final
    MLP whose last layer has ``out_channels`` units.

    Args:
        out_channels: Width of the final MLP layer (one score per class).
        k: Forwarded to every ``DynamicEdgeConv`` as its second argument.
        aggr: Forwarded to every ``DynamicEdgeConv`` as its third argument.
    """

    def __init__(self, out_channels, k=30, aggr='max'):
        super().__init__()

        # The first edge MLP takes 2 * 6 inputs, i.e. it assumes that `x` and
        # `pos` concatenated give 6 features per node -- TODO confirm against
        # the dataset configuration.
        first_edge_mlp = MLP([2 * 6, 64, 64])
        self.conv1 = DynamicEdgeConv(first_edge_mlp, k, aggr)

        second_edge_mlp = MLP([2 * 64, 64, 64])
        self.conv2 = DynamicEdgeConv(second_edge_mlp, k, aggr)

        third_edge_mlp = MLP([2 * 64, 64, 64])
        self.conv3 = DynamicEdgeConv(third_edge_mlp, k, aggr)

        # Head over the concatenation of the three 64-wide conv outputs.
        self.mlp = MLP(
            [3 * 64, 1024, 256, 128, out_channels],
            dropout=0.5,
            norm=None,
        )

    def forward(self, data):
        """Return per-node log-probabilities for the batch ``data``.

        Reads ``data.x``, ``data.pos`` and ``data.batch``.
        """
        features = torch.cat([data.x, data.pos], dim=-1)
        batch = data.batch

        # Run the convolutions in sequence, keeping every intermediate output.
        stage_outputs = []
        for conv in (self.conv1, self.conv2, self.conv3):
            features = conv(features, batch)
            stage_outputs.append(features)

        logits = self.mlp(torch.cat(stage_outputs, dim=1))
        return F.log_softmax(logits, dim=1)
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One output channel per class reported by the training dataset.
model = Net(train_dataset.num_classes, k=30)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# StepLR configuration: step_size=20, gamma=0.8.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.8,
)
def train():
    """Run one pass over ``train_loader``, optimising ``model`` in place.

    After every 10th batch, prints the summed loss of the last 10 batches
    divided by 10 together with the node-level accuracy accumulated over those
    batches, then resets the running counters. Returns nothing.
    """
    model.train()

    total_loss = 0
    correct_nodes = 0
    total_nodes = 0
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)

        # Standard optimisation step on the negative log-likelihood loss.
        optimizer.zero_grad()
        log_probs = model(batch)
        loss = F.nll_loss(log_probs, batch.y)
        loss.backward()
        optimizer.step()

        # Update the running statistics for the current reporting window.
        total_loss += loss.item()
        predictions = log_probs.argmax(dim=1)
        correct_nodes += predictions.eq(batch.y).sum().item()
        total_nodes += batch.num_nodes

        if (i + 1) % 10 != 0:
            continue

        print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '
              f'Train Acc: {correct_nodes / total_nodes:.4f}')
        total_loss = correct_nodes = total_nodes = 0
@torch.no_grad()
def test(loader):
    """Evaluate ``model`` on ``loader`` and return the mean part IoU.

    For every sample, the model output is restricted to the part labels that
    ``ShapeNet.seg_classes`` lists for the sample's category, and a Jaccard
    index is computed over those parts. The per-sample values are averaged
    per category, and the per-category means are averaged again.

    Args:
        loader: Loader yielding batches that expose ``ptr``, ``y`` and
            ``category``; ``loader.dataset.num_classes`` sizes the label
            remapping table.

    Returns:
        float: Mean of the per-category mean IoUs.
    """
    model.eval()

    # Category index -> name lookup. Built once here instead of once per
    # sample, since `ShapeNet.seg_classes` does not change during evaluation.
    category_names = list(ShapeNet.seg_classes.keys())

    ious, categories = [], []
    # Scratch table mapping global part labels to 0..len(part) - 1. It is left
    # uninitialised: the entries for the current sample's parts are written
    # right before being read (assumes every label in `y` belongs to the
    # sample's category -- TODO confirm).
    y_map = torch.empty(loader.dataset.num_classes, dtype=torch.long,
                        device=device)
    for data in loader:
        data = data.to(device)
        outs = model(data)

        # Per-sample split sizes, taken from consecutive `ptr` differences.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        samples = zip(outs.split(sizes), data.y.split(sizes),
                      data.category.tolist())
        # `category_idx` is deliberately not called `category`, to avoid
        # shadowing the module-level setting of that name.
        for out, y, category_idx in samples:
            part = ShapeNet.seg_classes[category_names[category_idx]]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)

            # Only the columns of this category's parts compete in the argmax.
            iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
                                num_classes=part.size(0), absent_score=1.0)
            ious.append(iou)

        categories.append(data.category)

    iou = torch.tensor(ious, device=device)
    category_index = torch.cat(categories, dim=0)

    mean_iou = scatter(iou, category_index, reduce='mean')  # Per-category IoU.
    return float(mean_iou.mean())  # Global IoU.
# Train for 30 epochs, evaluating on the test split after each one.
for epoch in range(1, 31):
    train()
    iou = test(test_loader)
    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')
    # Advance the learning-rate schedule once per epoch. Without this call the
    # StepLR scheduler is constructed but never applied, so the learning rate
    # would stay at its initial value for the whole run.
    scheduler.step()
|