File: kmean.py

package info (click to toggle)
python-scipy 0.5.2-0.1
  • links: PTS
  • area: main
  • in suites: etch, etch-m68k
  • size: 33,888 kB
  • ctags: 44,231
  • sloc: ansic: 156,256; cpp: 90,347; python: 89,604; fortran: 73,083; sh: 1,318; objc: 424; makefile: 342
file content (76 lines) | stat: -rw-r--r-- 2,163 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# /usr/bin/python
# Last Change: Thu Sep 28 01:00 PM 2006 J

#TODO:
#   - a demo for kmeans

import numpy as N

def _py_vq(data, code):
    """Label each row of data with the index of its nearest code vector.

    Pure numpy fallback used when no compiled vq is available. Please do not
    use directly. Use kmean instead.

    data holds one observation per row and code one code vector per row; both
    must be 2-D arrays (the shape unpacking below raises ValueError otherwise).
    Returns an integer array with one label per row of data. Ties go to the
    lowest code index, because N.argmin returns the first minimum.
    """
    # Unpack both shapes: this rejects anything that is not exactly 2-D.
    nobs, _ = data.shape
    _, _ = code.shape

    # One squared-Euclidean distance vector per observation. Looping over rows
    # keeps the temporary at (k, d) instead of a broadcast (n, k, d) array.
    labels = N.zeros(nobs, int)
    for row in range(nobs):
        sqdist = ((data[row, :] - code) ** 2).sum(axis=1)
        labels[row] = sqdist.argmin()
    return labels
    
# Select the vector quantization backend used by kmean, in order of preference:
#   1. scipy.cluster.vq.vq (only its label output is kept),
#   2. the pyrex extension c_gmm._vq,
#   3. the pure python _py_vq defined above (REALLY slow).
# The chosen backend is announced on stdout when the module is imported.
# NOTE(review): the marker below may be a placeholder consumed by a build or
# substitution script; it is kept verbatim.
#%KMEANIMPORT%
try:
    from scipy.cluster.vq import vq
except ImportError:
    try:
        from c_gmm import _vq
    except ImportError:
        print("c_gmm._vq not found, using pure python implementation instead. \n"
              "        Kmean will be REALLY slow")
        _vq = _py_vq
    else:
        print("using pyrex vq")
else:
    print("using scipy.cluster.vq")

    def _vq(*args, **kw):
        """Forward to scipy's vq and keep only the labels (drop distortions)."""
        return vq(*args, **kw)[0]

def kmean(data, init, iter=10):
    """Simple kmean implementation for EM. Runs iter iterations.

    Parameters
    ----------
    data : array-like
        Observations, one per row. N.atleast_2d is applied first, so a 1-D
        input is treated as a single row.
    init : array-like
        Initial centroids, one per row. Must have as many columns as data.
    iter : int
        Number of assignment/update rounds. With iter <= 0 no update is done,
        and the initial centroids are returned with their labels.

    Returns
    -------
    (code, label) : tuple
        code are the final centroids. They are a copy of init (init is never
        modified), promoted to float when init has an integer dtype.
        label are the class label indexes for each frame (ie row) of data.
        For iter >= 1, label is the assignment computed against the centroids
        *before* the last update step.

    Raises
    ------
    ValueError
        If data and init do not have the same number of columns.
    """
    data    = N.atleast_2d(data)
    init    = N.atleast_2d(init)

    # Tuple unpacking also rejects inputs with more than two dimensions.
    (_, d)  = data.shape
    (k, d1) = init.shape

    if not d == d1:
        msg = "data and init centers do not have same dimensions..."
        # GmmParamError is neither defined nor imported in this module, so
        # raising it used to produce a NameError; use a standard exception.
        raise ValueError(msg)

    # Work on a copy. Promote integer centroids to float so that the mean
    # update below is not silently truncated.
    code = N.array(init, copy=True)
    if code.dtype.kind not in 'fc':
        code = code.astype(float)

    label = None
    for i in range(iter):
        # Compute the nearest neighbour for each obs
        # using the current code book
        label = _vq(data, code)
        # Update each centroid to the mean of its members. An empty cluster
        # keeps its previous centroid: the mean of zero rows would be NaN, and
        # a NaN centroid corrupts every later assignment.
        for j in range(k):
            members = data[label == j]
            if members.shape[0] > 0:
                code[j, :] = N.mean(members, axis=0)

    if label is None:
        # iter <= 0: no round ran, so label against the initial centroids
        # instead of failing on an unbound name.
        label = _vq(data, code)

    return code, label

if __name__ == '__main__':
    # Running this module directly does nothing; it only provides kmean.
    pass