File: ward_clustering.py

package info (click to toggle)
nipy 0.1.2%2B20100526-2
  • links: PTS, VCS
  • area: main
  • in suites: squeeze
  • size: 11,992 kB
  • ctags: 13,434
  • sloc: python: 47,720; ansic: 41,334; makefile: 197
file content (76 lines) | stat: -rw-r--r-- 1,725 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""
Demo ward clustering on a graph:
various ways of forming clusters and dendrogram
"""
# Echo the module docstring as a banner.  The parenthesised single-argument
# form runs unchanged under both the Python 2 print statement and the
# Python 3 print function.
print(__doc__)

import numpy as np
from numpy.random import randn, rand
import matplotlib.pylab as mp

from nipy.neurospin import graph
from nipy.neurospin.clustering.hierarchical_clustering import ward

# Demo parameters:
#   n       -- number of 2D sample points
#   k       -- number of nearest neighbours (the script also passes it to
#              tree.split() further down)
#   verbose -- flag forwarded to ward(); a truthy value also enables the
#              sub-tree listing printed at the end of the script
n = 100
k = 5
verbose = 0

# Draw n standard-normal points in the plane, then shift the first third of
# them by +3 on both coordinates.
X = randn(n, 2)
# np.ceil returns a float, and NumPy rejects non-integer slice bounds, so the
# bound is converted explicitly.  Dividing by 3.0 makes the count the same
# (34 for n = 100) whether `/` is integer or true division.
n_shifted = int(np.ceil(n / 3.0))
X[:n_shifted] += 3
# Build a graph over the n points and connect each point to its nearest
# neighbours.  (An alternative construction, G.mst(X), was left disabled by
# the original author.)
G = graph.WeightedGraph(n)
# Use the neighbour count declared as `k` instead of repeating the literal 5,
# so the parameter and the call cannot drift apart.
G.knn(X, k)
# Ward clustering of the points X on the graph G.
tree = ward(G, X, verbose)

# Cut the tree with a threshold proportional to the number of points; the
# plot title below describes it as an inertia bound.  `u` is used later as a
# per-point cluster label (it is compared against cluster indices to mask X).
threshold = .5 * n
u = tree.partition(threshold)

# Left panel: scatter plot of the thresholded partition, one random colour
# per cluster label found in `u`.
mp.figure()
mp.subplot(1, 2, 1)
n_clusters = u.max() + 1
for label in range(n_clusters):
    members = X[u == label]
    colour = (rand(), rand(), rand())
    mp.plot(members[:, 0], members[:, 1], 'o', color=colour)

mp.axis('tight')
mp.axis('off')
mp.title('clustering into clusters of inertia<%f' % threshold)

# Right panel: split the tree using k and draw the resulting clusters on top
# of the graph edges.
u = tree.split(k)
mp.subplot(1, 2, 2)
# Draw every graph edge as a black segment between its two end points.
for e in range(G.E):
    start, end = G.edges[e, 0], G.edges[e, 1]
    mp.plot([X[start, 0], X[end, 0]],
            [X[start, 1], X[end, 1]], 'k')
# One random colour per cluster label.
for i in range(u.max() + 1):
    mp.plot(X[u == i, 0], X[u == i, 1], 'o', color=(rand(), rand(), rand()))
mp.axis('tight')
mp.axis('off')
# Derive the count from k so the title cannot disagree with the split() call
# (for k = 5 the text is identical to the former hard-coded title).
mp.title('clustering into %d clusters' % k)



# Flag a subset of the tree nodes as "valid": start from the first quarter of
# the leaves and propagate the flag upwards to their ancestors.  The result
# is only consumed by the disabled fancy_plot_ call below.
nl = np.sum(tree.isleaf())  # number of leaves; not used further in this script
validleaves = np.zeros(n)
# np.ceil returns a float, which NumPy does not accept as a slice bound, so
# the bound is converted to int explicitly.
validleaves[:int(np.ceil(n / 4.0))] = 1
valid = np.zeros(tree.V, 'bool')
# NOTE(review): assumes the tree has exactly n leaves so that the mask and
# validleaves have matching lengths -- confirm against ward().
valid[tree.isleaf()] = validleaves.astype('bool')
nv = np.sum(validleaves)
nv0 = 0
# Repeat full sweeps over the nodes until a sweep marks no new node.
while nv > nv0:
    nv0 = nv
    for v in range(tree.V):
        if valid[v]:
            # A valid node makes its parent valid as well.
            valid[tree.parents[v]] = True
    nv = np.sum(valid)
    
#ax = tree.fancy_plot_(valid)
#ax.axis('off')

# Plot the tree (presumably the dendrogram announced in the module
# docstring), make sure the returned axes are visible, and display the
# figures.
ax = tree.plot()
ax.set_visible(True)
mp.show()

if verbose:
    # Parenthesised single-argument print behaves the same under the
    # Python 2 print statement and the Python 3 print function.
    print('List of sub trees')
    print(tree.list_of_subtrees())