File: ward_clustering.py

package info (click to toggle)
nipy 0.6.1-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,392 kB
  • sloc: python: 39,094; ansic: 30,931; makefile: 228; sh: 93
file content (79 lines) | stat: -rwxr-xr-x 1,993 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python3
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
# Description of the demo; it is assigned to __doc__ explicitly and echoed
# right away so that running the script prints what it is about.
__doc__ = """
Demo ward clustering on a graph: various ways of forming clusters and dendrogram

Requires matplotlib
"""
print(__doc__)

import numpy as np
from numpy.random import rand, randn

# The demo cannot run without matplotlib, so turn a missing install into an
# explicit message. The ImportError is chained as the direct cause so the
# traceback still shows which import failed.
try:
    import matplotlib.pyplot as plt
except ImportError as err:
    raise RuntimeError("This script needs the matplotlib library") from err

from nipy.algorithms.clustering.hierarchical_clustering import ward
from nipy.algorithms.graph import knn

# n = number of points
# k = number of nearest neighbours used to build the graph; the same value is
#     reused further down as the argument of tree.split()
n = 100
k = 5

# Set verbose to True to see more printed output
verbose = False

# Draw n 2D standard-normal points, then shift the first ceil(n / 3) of them
# by +3 along both coordinates
X = randn(n, 2)
X[:int(np.ceil(n / 3))] += 3

# Build the nearest-neighbour graph from k rather than from a repeated
# literal 5, so that changing k above really changes the graph
G = knn(X, k)
tree = ward(G, X, verbose)

# Cut the tree at a threshold proportional to the number of points; u holds
# one label per point and is used to colour the first panel
threshold = .5 * n
u = tree.partition(threshold)

# First panel: scatter the points, one random colour per label found in u
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
n_labels = u.max() + 1
for label in range(n_labels):
    # Rows of X that carry the current label
    members = X[u == label]
    colour = (rand(), rand(), rand())
    plt.plot(members[:, 0], members[:, 1], 'o', color=colour)

# Fit the axis limits to the data, then hide the axes themselves
plt.axis('tight')
plt.axis('off')
plt.title(f'clustering into clusters \n of inertia < {threshold:g}')

# Second panel: split the tree with parameter k and draw the resulting labels
# on top of the graph edges
u = tree.split(k)
plt.subplot(1, 3, 2)

# Draw every graph edge as a black segment between its two end points
for e in range(G.E):
    src, dst = G.edges[e, 0], G.edges[e, 1]
    plt.plot([X[src, 0], X[dst, 0]], [X[src, 1], X[dst, 1]], 'k')

# One randomly coloured set of markers per label found in u
for i in range(u.max() + 1):
    plt.plot(X[u == i, 0], X[u == i, 1], 'o', color=(rand(), rand(), rand()))
plt.axis('tight')
plt.axis('off')

# Report the value actually passed to tree.split() instead of a hard-coded 5,
# so the title stays correct when k is changed
plt.title(f'clustering into {k} clusters')

# Number of leaves in the tree (not used further in this script)
nl = np.sum(tree.isleaf())

# Flag the first ceil(n / 4) leaves as valid
validleaves = np.zeros(n)
validleaves[:int(np.ceil(n / 4))] = 1

# One boolean flag per tree vertex, initialised from the leaf flags.
# NOTE(review): the assignment below assumes the tree has exactly n leaves --
# confirm against ward()
valid = np.zeros(tree.V, 'bool')
valid[tree.isleaf()] = validleaves.astype('bool')

# Propagate the flag from every valid vertex to its parent, repeating full
# sweeps over the vertices until a sweep no longer adds any valid vertex.
# The resulting `valid` array is not consumed by the plotting code below.
nv = np.sum(validleaves)
nv0 = 0
while nv > nv0:
    nv0 = nv
    for v in range(tree.V):
        if valid[v]:
            valid[tree.parents[v]] = True
    # Count once per sweep: only the value after the sweep is compared in the
    # loop condition, so recounting after every single vertex was wasted work
    nv = np.sum(valid)

# Third panel: hand a fresh subplot to tree.plot; the object it returns is
# then titled, made visible and the whole figure is displayed
ax = tree.plot(plt.subplot(1, 3, 3))
ax.set_title('Dendrogram')
ax.set_visible(True)
plt.show()

# Optional extra output, enabled through the verbose flag
if verbose:
    print('List of sub trees')
    print(tree.list_of_subtrees())