1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
from typing import Optional
import torch
def graclus_cluster(
    row: torch.Tensor,
    col: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
) -> torch.Tensor:
    """A greedy clustering algorithm of picking an unmarked vertex and
    matching it with one of its unmarked neighbors (that maximizes its edge
    weight).

    Self-loops are dropped, the remaining edges are sorted by source node
    and converted to a CSR row pointer, and the actual matching is delegated
    to the compiled ``torch.ops.torch_cluster.graclus`` operator.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        weight (Tensor, optional): Edge weights. If omitted, the edges are
            randomly shuffled before being sorted. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes. If omitted, it is
            inferred as the largest index in :obj:`row`/:obj:`col` plus one;
            an empty edge list then yields an empty result.
            (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import graclus_cluster

        row = torch.tensor([0, 1, 1, 2])
        col = torch.tensor([1, 0, 2, 1])
        weight = torch.Tensor([1, 1, 1, 1])
        cluster = graclus_cluster(row, col, weight)
    """
    if num_nodes is None:
        if row.numel() == 0 and col.numel() == 0:
            # No edges and no explicit node count means there is nothing to
            # cluster. ``row.max()`` would raise on an empty tensor, so
            # return an empty assignment with the index dtype/device instead.
            return row.new_empty((0,))
        num_nodes = max(int(row.max()), int(col.max())) + 1

    # Remove self-loops.
    mask = row != col
    row, col = row[mask], col[mask]
    if weight is not None:
        weight = weight[mask]

    # Randomly shuffle edges (only in the unweighted case).
    if weight is None:
        perm = torch.randperm(row.size(0), dtype=torch.long, device=row.device)
        row, col = row[perm], col[perm]

    # To CSR: sort edges by source node ...
    perm = torch.argsort(row)
    row, col = row[perm], col[perm]
    if weight is not None:
        weight = weight[perm]

    # ... count the out-degree of every node ...
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))

    # ... and accumulate the degrees into the row pointer. ``rowptr[0]``
    # stays zero; the cumulative sum is written into the remaining slots.
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    return torch.ops.torch_cluster.graclus(rowptr, col, weight)
|