1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
|
import os.path as osp
import torch
import torch.nn.functional as F
from point_transformer_classification import TransformerBlock, TransitionDown
from torchmetrics.functional import jaccard_index
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, knn_graph, knn_interpolate
from torch_geometric.typing import WITH_TORCH_CLUSTER
from torch_geometric.utils import scatter
# Abort early when the optional 'torch-cluster' dependency is unavailable.
if not WITH_TORCH_CLUSTER:
    # `quit` is a helper injected by the `site` module for interactive
    # sessions; raising SystemExit directly prints the same message and exits
    # with the same status, and also works when `site` is not loaded.
    raise SystemExit("This example requires 'torch-cluster'")

category = 'Airplane'  # Pass in `None` to train on all categories.

# Dataset root: ../data/ShapeNet relative to the directory of this script.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')

# Augmentation pipeline handed to the training split only: positional jitter
# followed by a random rotation (15 degree range) about each of the three axes.
transform = T.Compose([
    T.RandomJitter(0.01),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2),
])

# Scale normalization is passed as `pre_transform` to both splits.
pre_transform = T.NormalizeScale()

train_dataset = ShapeNet(path, category, split='trainval',
                         transform=transform, pre_transform=pre_transform)
# The test split gets no random augmentation.
test_dataset = ShapeNet(path, category, split='test',
                        pre_transform=pre_transform)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
class TransitionUp(torch.nn.Module):
    """Upsampling stage of the decoder.

    Projects the low-resolution features down to ``out_channels``,
    interpolates them onto the high-resolution points and adds them to a
    transformed copy of the high-resolution features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Applied to the low-resolution features: in_channels -> out_channels.
        self.mlp_sub = MLP([in_channels, out_channels], plain_last=False)
        # Applied to the high-resolution features: width stays out_channels.
        self.mlp = MLP([out_channels, out_channels], plain_last=False)

    def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
        """Merge low-resolution features into the high-resolution point set.

        Args:
            x: features of the high-resolution points.
            x_sub: features of the low-resolution points.
            pos: positions of the high-resolution points.
            pos_sub: positions of the low-resolution points.
            batch: optional batch assignment for ``pos``.
            batch_sub: optional batch assignment for ``pos_sub``.

        Returns:
            The sum of ``self.mlp(x)`` and the interpolated low-resolution
            features, one row per high-resolution point.
        """
        # Channel reduction happens before interpolation.
        coarse_features = self.mlp_sub(x_sub)

        # Each high-resolution point is interpolated from its 3 nearest
        # low-resolution points.
        upsampled = knn_interpolate(
            coarse_features,
            pos_sub,
            pos,
            k=3,
            batch_x=batch_sub,
            batch_y=batch,
        )

        return self.mlp(x) + upsampled
class Net(torch.nn.Module):
    """Per-point classifier with a down-sampling and an up-sampling path.

    The input is embedded, then repeatedly passed through a ``TransitionDown``
    + ``TransformerBlock`` pair (one pair per step in ``dim_model``). The
    outputs of every level are kept and fed back as skip connections to the
    matching ``TransitionUp`` + ``TransformerBlock`` pair on the way up.

    Args:
        in_channels: number of input feature channels; values below 1 are
            raised to 1 because a constant dummy feature is used when no
            features are given.
        out_channels: number of classes scored per point.
        dim_model: channel width of each level; ``len(dim_model) - 1``
            down/up stages are created.
        k: number of neighbours used for every k-NN graph.
    """

    def __init__(self, in_channels, out_channels, dim_model, k=16):
        super().__init__()
        self.k = k

        # dummy feature is created if there is none given
        in_channels = max(in_channels, 1)

        # first block
        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
        self.transformer_input = TransformerBlock(
            in_channels=dim_model[0],
            out_channels=dim_model[0],
        )

        # backbone layers
        self.transformers_up = torch.nn.ModuleList()
        self.transformers_down = torch.nn.ModuleList()
        self.transition_up = torch.nn.ModuleList()
        self.transition_down = torch.nn.ModuleList()

        # Walk over consecutive (finer, coarser) channel widths.
        for fine_dim, coarse_dim in zip(dim_model[:-1], dim_model[1:]):
            # Add Transition Down block followed by a Point Transformer block
            self.transition_down.append(
                TransitionDown(in_channels=fine_dim, out_channels=coarse_dim,
                               k=self.k))
            self.transformers_down.append(
                TransformerBlock(in_channels=coarse_dim,
                                 out_channels=coarse_dim))

            # Add Transition Up block followed by Point Transformer block
            self.transition_up.append(
                TransitionUp(in_channels=coarse_dim, out_channels=fine_dim))
            self.transformers_up.append(
                TransformerBlock(in_channels=fine_dim,
                                 out_channels=fine_dim))

        # summit layers
        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None,
                              plain_last=False)
        self.transformer_summit = TransformerBlock(
            in_channels=dim_model[-1],
            out_channels=dim_model[-1],
        )

        # class score computation
        self.mlp_output = MLP([dim_model[0], 64, out_channels], norm=None)

    def forward(self, x, pos, batch=None):
        """Compute per-point log-probabilities.

        Args:
            x: point features, or ``None`` to use a constant dummy feature.
            pos: point positions; its first dimension is the point count.
            batch: optional batch assignment of the points.

        Returns:
            ``log_softmax`` over the last dimension of the class scores, one
            row per input point.
        """
        # add dummy features in case there is none
        if x is None:
            # Create the tensor directly on `pos.device`: `get_device()`
            # returns -1 for CPU tensors, which is not a valid `.to()` target.
            x = torch.ones((pos.shape[0], 1), device=pos.device)

        # Per-level outputs, kept for the skip connections of the up path.
        out_x = []
        out_pos = []
        out_batch = []

        # first block
        x = self.mlp_input(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_input(x, pos, edge_index)

        # save outputs for skipping connections
        out_x.append(x)
        out_pos.append(pos)
        out_batch.append(batch)

        # backbone down: reduce cardinality and augment dimensionality
        for transition, transformer in zip(self.transition_down,
                                           self.transformers_down):
            x, pos, batch = transition(x, pos, batch=batch)
            edge_index = knn_graph(pos, k=self.k, batch=batch)
            x = transformer(x, pos, edge_index)

            out_x.append(x)
            out_pos.append(pos)
            out_batch.append(batch)

        # summit
        x = self.mlp_summit(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_summit(x, pos, edge_index)

        # backbone up: augment cardinality and reduce dimensionality
        for i in range(len(self.transformers_down)):
            # Negative indices walk the saved levels from coarsest to finest:
            # `coarse` is the level `x` currently lives on, `fine` the next
            # higher-resolution level it is merged into.
            coarse, fine = -i - 1, -i - 2
            x = self.transition_up[coarse](x=out_x[fine], x_sub=x,
                                           pos=out_pos[fine],
                                           pos_sub=out_pos[coarse],
                                           batch_sub=out_batch[coarse],
                                           batch=out_batch[fine])
            edge_index = knn_graph(out_pos[fine], k=self.k,
                                   batch=out_batch[fine])
            x = self.transformers_up[coarse](x, out_pos[fine], edge_index)

        # Class score
        out = self.mlp_output(x)

        return F.log_softmax(out, dim=-1)
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network with five channel widths (i.e. four down/up stages), 3 input
# channels and one output score per class of the training dataset.
model = Net(
    in_channels=3,
    out_channels=train_dataset.num_classes,
    dim_model=[32, 64, 128, 256, 512],
    k=16,
)
model = model.to(device)

# Adam optimizer; the StepLR scheduler multiplies the learning rate by 0.5
# every 20 calls to `scheduler.step()`.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.5,
)
def train():
    """Run one optimization pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``. After every
    10th batch the mean loss of the last 10 batches and the point-wise
    accuracy accumulated since the previous report are printed, and the
    running counters are reset. Returns nothing.
    """
    model.train()

    # Running statistics for the current reporting window.
    total_loss = 0
    correct_nodes = 0
    total_nodes = 0

    for i, batch in enumerate(train_loader):
        batch = batch.to(device)

        optimizer.zero_grad()
        log_probs = model(batch.x, batch.pos, batch.batch)
        loss = F.nll_loss(log_probs, batch.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predictions = log_probs.argmax(dim=1)
        correct_nodes += (predictions == batch.y).sum().item()
        total_nodes += batch.num_nodes

        # Only report (and reset) on every 10th batch.
        if (i + 1) % 10 != 0:
            continue

        print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '
              f'Train Acc: {correct_nodes / total_nodes:.4f}')
        total_loss = correct_nodes = total_nodes = 0
@torch.no_grad()  # Evaluation only: no autograd graph is needed.
def test(loader):
    """Evaluate the module-level ``model`` on ``loader``.

    For every shape, the prediction is restricted to the part labels of the
    shape's category and an IoU is computed against the ground truth. The
    per-shape IoUs are averaged per category, and the mean of those
    per-category averages is returned.

    Args:
        loader: data loader whose batches provide ``x``, ``pos``, ``batch``,
            ``y``, ``ptr`` and ``category``; ``loader.dataset.num_classes``
            must give the total number of part labels.

    Returns:
        The mean over categories of the per-category mean IoU, as a float.
    """
    model.eval()

    # Category names in the order of `ShapeNet.seg_classes`; built once
    # instead of once per shape.
    category_names = list(ShapeNet.seg_classes.keys())

    ious, categories = [], []
    # Lookup table from global part label to its index inside the part list
    # of the current category. It starts uninitialized; only the entries of
    # the current category's parts are written before being read.
    # NOTE(review): assumes every label in `y` belongs to that category's
    # part list - confirm against the dataset.
    y_map = torch.empty(loader.dataset.num_classes, device=device).long()

    for data in loader:
        data = data.to(device)
        outs = model(data.x, data.pos, data.batch)

        # Number of points of each shape in the batch, used to split the
        # concatenated outputs and labels back into per-shape chunks.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category_idx in zip(outs.split(sizes),
                                        data.y.split(sizes),
                                        data.category.tolist()):
            category_name = category_names[category_idx]
            part = ShapeNet.seg_classes[category_name]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)

            # Score only the columns of this category's parts, so both the
            # prediction and the remapped target lie in [0, len(part)).
            iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
                                num_classes=part.size(0), absent_score=1.0)
            ious.append(iou)

        categories.append(data.category)

    iou = torch.tensor(ious, device=device)
    category = torch.cat(categories, dim=0)

    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.
    return float(mean_iou.mean())  # Global IoU.
# Main loop over epochs 1..99: train, evaluate on the test split, report the
# resulting IoU, then advance the learning-rate schedule.
for epoch in range(1, 100):
    train()

    iou = test(test_loader)
    print(f'Epoch: {epoch:03d}, Test IoU: {iou:.4f}')

    # Stepped once per epoch, after the optimizer updates made in `train()`.
    scheduler.step()
|