File: rbcd_attack_poisoning.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (144 lines) | stat: -rw-r--r-- 4,796 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import copy
import os.path as osp
import sys
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from rbcd_attack import GCN, metric, test, train
from torch import Tensor
from torch.optim import Adam

import torch_geometric.transforms as T
from torch_geometric.contrib.nn import PRBCDAttack
from torch_geometric.datasets import Planetoid

# `higher` is an optional dependency; it provides the differentiable inner
# training loop (`higher.innerloop_ctx`) used by the poisoning attack below.
# Exit with an installation hint instead of a traceback if it is missing.
try:
    import higher
except ImportError:
    sys.exit('Install `higher` via `pip install higher` for poisoning example')

# Dataset cache directory, located next to this script:
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
# Use the GPU when one is available, otherwise fall back to the CPU:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# IMPORTANT: Edge weights are being ignored later and most adjacency matrix
# preprocessing should be part of the model (part of backpropagation):
dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())
data = dataset[0].to(device)

# Train a GCN with 16 hidden channels on the clean (unperturbed) graph. It
# serves as the accuracy baseline and is the model handed to the attack.
gcn = GCN(dataset.num_features, 16, dataset.num_classes).to(device)
train(gcn, data)

print('------------- GCN: Global Poisoning -------------')

clean_acc = test(gcn, data)
print(f'Clean accuracy: {clean_acc:.3f}')

# Training hyperparameters used whenever the attack retrains the model
# (read as module-level globals by `PoisoningPRBCDAttack`):
n_epochs = 50
lr = 0.04
weight_decay = 5e-4


class PoisoningPRBCDAttack(PRBCDAttack):
    """PRBCD attack variant for the poisoning (training-time) setting.

    Both overridden hooks reset ``self.model`` and retrain it on the
    (perturbed) graph before predicting, so the attack objective is evaluated
    on a model that was trained on the poisoned data rather than on a fixed,
    pre-trained model.

    The hooks read the module-level ``data`` (for training masks and labels)
    and the training hyperparameters ``n_epochs``, ``lr`` and
    ``weight_decay``.
    """
    def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                 **kwargs) -> Tensor:
        """Retrain the model on the given graph and return its predictions.

        Args:
            x: Node feature matrix.
            edge_index: Edge indices of the (perturbed) graph.
            edge_weight: Edge weights belonging to ``edge_index``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The output of the retrained model (in eval mode) on the given
            graph.
        """
        self.model.reset_parameters()

        # Training needs gradients even if the caller disabled them
        # (presumably the base class evaluates under `torch.no_grad()`).
        with torch.enable_grad():
            # Shallow copy: keeps masks/labels of the global `data` object
            # and only swaps in the features and the perturbed graph.
            ped = copy.copy(data)
            ped.x, ped.edge_index, ped.edge_weight = x, edge_index, edge_weight
            train(self.model, ped, n_epochs, lr, weight_decay)

        self.model.eval()
        return self.model(x, edge_index, edge_weight)

    def _forward_and_gradient(self, x: Tensor, labels: Tensor,
                              idx_attack: Optional[Tensor] = None,
                              **kwargs) -> Tuple[Tensor, Tensor]:
        """Retrain differentiably and return attack loss and its gradient.

        The model is trained for ``n_epochs`` steps inside a ``higher``
        inner loop so that the attack loss can be differentiated through the
        whole training procedure with respect to ``self.block_edge_weight``.

        Args:
            x: Node feature matrix.
            labels: Labels passed to the attack loss.
            idx_attack: Optional index/mask of the nodes to attack; forwarded
                to ``self.loss``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A tuple ``(loss, gradient)`` with the attack loss and the
            norm-clipped gradient of that loss w.r.t.
            ``self.block_edge_weight``.
        """
        self.block_edge_weight.requires_grad = True

        self.model.reset_parameters()

        self.model.train()
        opt = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        with higher.innerloop_ctx(self.model, opt) as (fmodel, diffopt):
            edge_index, edge_weight = self._get_modified_adj(
                self.edge_index, self.edge_weight, self.block_edge_index,
                self.block_edge_weight)

            # Normalize only once (only relevant if model normalizes adj)
            if hasattr(fmodel, 'norm'):
                edge_index, edge_weight = fmodel.norm(
                    edge_index,
                    edge_weight,
                    num_nodes=x.size(0),
                    add_self_loops=True,
                )

            # Unrolled, differentiable training on the module-level training
            # split of `data`:
            for _ in range(n_epochs):
                pred = fmodel.forward(x, edge_index, edge_weight,
                                      skip_norm=True)
                loss = F.cross_entropy(pred[data.train_mask],
                                       data.y[data.train_mask])
                diffopt.step(loss)

            # The adjacency was already normalized above, so the final
            # prediction must skip normalization exactly like the training
            # steps do; otherwise the attack loss would be computed on a
            # twice-normalized graph.
            # NOTE(review): the model was put into train mode above, so
            # train-only layers may still be active here - confirm intended.
            pred = fmodel(x, edge_index, edge_weight, skip_norm=True)
            loss = self.loss(pred, labels, idx_attack)

            gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]

        # Clip gradient for stability: rescale to L2 norm `clip_norm` whenever
        # the norm (not the squared norm) exceeds the threshold, so the
        # threshold and the rescaling target agree.
        clip_norm = 0.5
        grad_norm = gradient.square().sum().sqrt()
        if grad_norm > clip_norm:
            gradient *= clip_norm / grad_norm

        self.model.eval()

        return loss, gradient


# Set up the poisoning attack on the GCN.
# NOTE(review): `lr=100` is passed to the attack constructor and is distinct
# from the module-level training `lr` - confirm its meaning against the
# PRBCDAttack documentation.
prbcd = PoisoningPRBCDAttack(gcn, block_size=250_000, metric=metric, lr=100)

# PRBCD: Attack test set:
# (dividing by 2 presumably accounts for each undirected edge being stored in
# both directions - TODO confirm)
global_budget = int(0.05 * data.edge_index.size(1) / 2)  # Perturb 5% of edges

pert_edge_index, perts = prbcd.attack(
    data.x,
    data.edge_index,
    data.y,
    budget=global_budget,
    idx_attack=data.test_mask,
)

# Poisoning evaluation: retrain the GCN from scratch on the perturbed graph
# and measure its accuracy there.
gcn.reset_parameters()
pert_data = copy.copy(data)  # shallow copy; only the edge index is replaced
pert_data.edge_index = pert_edge_index
train(gcn, pert_data)
pert_acc = test(gcn, pert_data)
# Note that the values here are a bit noisier than in the evasion case:
print(f'PRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')

# Plot the attack loss (left y-axis) and the used budget (right y-axis) over
# the attack steps:
fig, ax1 = plt.subplots()
plt.title('Global Poisoning GCN')
color = 'tab:red'
ax1.plot(prbcd.attack_statistics['loss'], color=color, label='Loss')
ax1.tick_params(axis='y', labelcolor=color)
ax1.set_ylabel('Loss')
ax1.set_xlabel('Steps')

# It is best practice to choose the learning rate s.t. the budget is
# exhausted:
ax2 = ax1.twinx()
color = 'tab:blue'
ax2.plot(prbcd.attack_statistics['prob_mass_after_update'], color=color,
         linestyle='--', label='Before projection')
ax2.plot(prbcd.attack_statistics['prob_mass_after_projection'], color=color,
         label='After projection')
ax2.tick_params(axis='y', labelcolor=color)
ax2.set_ylabel('Used budget')

# A legend only collects the artists of a single axes, so merge the handles of
# both axes explicitly; otherwise the loss curve would be missing.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(handles1 + handles2, labels1 + labels2)

# `plt.show()` runs the GUI event loop and blocks until the window is closed;
# `fig.show()` does not, so in a plain script the figure may never be visible.
plt.show()