1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
|
# fmt: off
import numpy as np
class DisjointSet:
def __init__(self, n):
self.sizes = np.ones(n, dtype=int)
self.parents = np.arange(n)
self.nc = n
def _compress(self):
a = self.parents
b = a[a]
while (a != b).any():
a = b
b = a[a]
self.parents = a
def union(self, a, b):
a = self.find(a)
b = self.find(b)
if a == b:
return False
sizes = self.sizes
parents = self.parents
if sizes[a] < sizes[b]:
parents[a] = b
sizes[b] += sizes[a]
else:
parents[b] = a
sizes[a] += sizes[b]
self.nc -= 1
return True
def find(self, index):
parents = self.parents
parent = parents[index]
while parent != parents[parent]:
parent = parents[parent]
parents[index] = parent
return parent
def find_all(self, relabel=False):
self._compress()
if not relabel:
return self.parents
# order elements by frequency
_unique, inverse, counts = np.unique(self.parents,
return_inverse=True,
return_counts=True)
indices = np.argsort(counts, kind='merge')[::-1]
return np.argsort(indices)[inverse]
|