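"""Point cloud classification on ModelNet10 with a Point Transformer.

The network alternates attention-based `TransformerBlock` stages (built on
`PointTransformerConv`) with `TransitionDown` stages that downsample the
cloud via farthest point sampling and k-NN max-pooling, followed by global
average pooling and an MLP classifier.
"""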
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Linear as Lin

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    MLP,
    PointTransformerConv,
    fps,
    global_mean_pool,
    knn,
    knn_graph,
)
from torch_geometric.typing import WITH_TORCH_CLUSTER
from torch_geometric.utils import scatter

if not WITH_TORCH_CLUSTER:
    quit("This example requires 'torch-cluster'")

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data/ModelNet10')
pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)
train_dataset = ModelNet(path, '10', True, transform, pre_transform)
test_dataset = ModelNet(path, '10', False, transform, pre_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


class TransformerBlock(torch.nn.Module):
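    """A `PointTransformerConv` layer sandwiched between two linear
    projections with ReLU activations.
    """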
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin_in = Lin(in_channels, in_channels)
        self.lin_out = Lin(out_channels, out_channels)

        # MLPs that produce the positional encodings and attention weights
        # used inside the point transformer layer:
        self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False)
        self.attn_nn = MLP([out_channels, 64, out_channels], norm=None,
                           plain_last=False)

        self.transformer = PointTransformerConv(in_channels, out_channels,
                                                pos_nn=self.pos_nn,
                                                attn_nn=self.attn_nn)

    def forward(self, x, pos, edge_index):
        x = self.lin_in(x).relu()
        x = self.transformer(x, pos, edge_index)
        x = self.lin_out(x).relu()
        return x


class TransitionDown(torch.nn.Module):
"""Samples the input point cloud by a ratio percentage to reduce
cardinality and uses an mlp to augment features dimensionnality.
"""
    def __init__(self, in_channels, out_channels, ratio=0.25, k=16):
        super().__init__()
        self.k = k
        self.ratio = ratio
        self.mlp = MLP([in_channels, out_channels], plain_last=False)

    def forward(self, x, pos, batch):
        # Sample a subset of points via farthest point sampling (FPS):
        id_clusters = fps(pos, ratio=self.ratio, batch=batch)

        # For each sampled point, find its k nearest neighbors in the
        # original cloud (note that each sampled point is returned as its
        # own nearest neighbor):
        sub_batch = batch[id_clusters] if batch is not None else None
        id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch,
                            batch_y=sub_batch)

        # Transform features through a simple MLP:
        x = self.mlp(x)

        # Max-pool the features of the k neighbors onto each sampled point:
        x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0,
                        dim_size=id_clusters.size(0), reduce='max')

        # Keep only the sampled points and their max-pooled features:
        sub_pos, out = pos[id_clusters], x_out

        return out, sub_pos, sub_batch


class Net(torch.nn.Module):
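    """Hierarchical Point Transformer encoder for shape classification.

    Alternates `TransitionDown` downsampling stages with `TransformerBlock`
    attention stages, then predicts class scores via global average pooling
    and an output MLP.
    """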
    def __init__(self, in_channels, out_channels, dim_model, k=16):
        super().__init__()
        self.k = k

        # A dummy feature is created if none is given:
        in_channels = max(in_channels, 1)

        # First block:
        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
        self.transformer_input = TransformerBlock(in_channels=dim_model[0],
                                                  out_channels=dim_model[0])

        # Backbone layers:
        self.transformers_down = torch.nn.ModuleList()
        self.transition_down = torch.nn.ModuleList()
        for i in range(len(dim_model) - 1):
            # Add a TransitionDown block followed by a TransformerBlock:
            self.transition_down.append(
                TransitionDown(in_channels=dim_model[i],
                               out_channels=dim_model[i + 1], k=self.k))
            self.transformers_down.append(
                TransformerBlock(in_channels=dim_model[i + 1],
                                 out_channels=dim_model[i + 1]))

        # Class score computation:
        self.mlp_output = MLP([dim_model[-1], 64, out_channels], norm=None)

    def forward(self, x, pos, batch=None):
        # Add dummy features in case there are none
        # (`pos.device` instead of `pos.get_device()`, which fails on CPU):
        if x is None:
            x = torch.ones((pos.shape[0], 1), device=pos.device)

        # First block:
        x = self.mlp_input(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_input(x, pos, edge_index)

        # Backbone:
        for i in range(len(self.transformers_down)):
            x, pos, batch = self.transition_down[i](x, pos, batch=batch)
            edge_index = knn_graph(pos, k=self.k, batch=batch)
            x = self.transformers_down[i](x, pos, edge_index)

        # Global average pooling:
        x = global_mean_pool(x, batch)

        # Class scores:
        out = self.mlp_output(x)
        return F.log_softmax(out, dim=-1)


def train():
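    """Trains the model for one epoch and returns the mean loss per graph."""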
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.pos, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
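    """Evaluates the classification accuracy of the model over `loader`."""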
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.pos, data.batch).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(0, train_dataset.num_classes,
                dim_model=[32, 64, 128, 256, 512], k=16).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                                gamma=0.5)

    for epoch in range(1, 201):
        loss = train()
        test_acc = test(test_loader)
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}')
        scheduler.step()