#! /usr/bin/python
# Last Change: Thu Sep 28 01:00 PM 2006 J
#TODO:
# - a demo for kmeans
import numpy as N
def _py_vq(data, code):
    """Label each row of data with the index of its nearest row of code.

    Please do not use directly. Use kmean instead.

    No attempt to be efficient has been made: the observations are visited
    one at a time in a Python loop.

    data -- 2-d array (or array-like), one observation per row.
    code -- 2-d array (or array-like), one code vector per row, with the
            same number of columns as data.

    Returns a 1-d integer array with one entry per row of data: the index
    of the row of code with the smallest squared Euclidean distance to that
    observation.

    Raises ValueError if data and code do not have the same number of
    columns (or are not 2-d).
    """
    data = N.asarray(data)
    code = N.asarray(code)

    n_obs, n_dim = data.shape
    n_codes, code_dim = code.shape
    if n_dim != code_dim:
        # Without this check a (k, 1) code book would broadcast against the
        # observations and yield meaningless labels instead of an error.
        raise ValueError(
            "data has %d columns but code has %d (code shape is (%d, %d))"
            % (n_dim, code_dim, n_codes, code_dim))

    label = N.zeros(n_obs, int)
    for i in range(n_obs):
        # Squared Euclidean distance from observation i to every code vector.
        sq_dist = N.sum((data[i, :] - code) ** 2, 1)
        label[i] = N.argmin(sq_dist)
    return label
# Pick the vector quantization backend: try scipy.cluster.vq first, then the
# pyrex function from c_gmm, and fall back on the pure python implementation
# if neither can be imported.
#%KMEANIMPORT%
#try:
# from scipy.cluster.vq import kmeans as kmean
#except ImportError:
# try:
# from c_gmm import _vq
# except:
# print """c_gmm._vq not found, using pure python implementation instead.
# Kmean will be REALLY slow"""
# _vq = _py_vq
# Bind the private name _vq to the first backend that imports, and print
# which one was chosen. print is always called with one parenthesised string
# so the statement means the same thing under Python 2 and Python 3.
try:
    from scipy.cluster.vq import vq
    print("using scipy.cluster.vq")

    def _vq(*args, **kw):
        """Call scipy's vq and keep only the first element of its result.

        Per the scipy documentation that first element is the array of
        code indices; the second element is discarded.
        """
        return vq(*args, **kw)[0]
except ImportError:
    try:
        from c_gmm import _vq
        print("using pyrex vq")
    except ImportError:
        print("c_gmm._vq not found, using pure python implementation instead.\n"
              "Kmean will be REALLY slow")
        _vq = _py_vq
def kmean(data, init, iter = 10):
    """Simple kmean implementation for EM. Runs iter iterations.

    data -- observations, one frame per row (promoted to 2-d).
    init -- initial centroids, one per row (promoted to 2-d); must have the
            same number of columns as data. It is copied, never modified.
    iter -- number of assign/update iterations to run. The name shadows the
            builtin but is kept because callers may pass it by keyword.

    Returns a tuple (code, label), where code are the final centroids, and
    label are the class label index for each frame (ie row) of data. label
    comes from the assignment step of the last iteration, i.e. it was
    computed before the final centroid update. With iter <= 0 no update is
    done and label is the assignment against the initial centroids.

    A centroid that ends up with no frame assigned to it is left unchanged
    instead of being replaced by the (NaN) mean of an empty selection.

    Raises GmmParamError if data and init do not have the same number of
    columns.
    """
    data = N.atleast_2d(data)
    init = N.atleast_2d(init)

    (_, d) = data.shape
    (k, d1) = init.shape

    if not d == d1:
        msg = "data and init centers do not have same dimensions..."
        # NOTE(review): GmmParamError is neither defined nor imported in the
        # part of this file that is visible here -- confirm it is in scope.
        raise GmmParamError(msg)

    # Work on a private copy. Integer centroids are promoted to float,
    # otherwise assigning the means below would truncate them.
    code = N.array(init)
    if not N.issubdtype(code.dtype, N.inexact):
        code = code.astype(float)

    label = None
    for _ in range(iter):
        # Compute the nearest neighbour for each obs
        # using the current code book
        label = _vq(data, code)
        # Update the code book: each centroid becomes the mean of the
        # frames currently assigned to it.
        for j in range(k):
            members = data[label == j]
            if len(members) > 0:
                code[j, :] = N.mean(members, axis=0)

    if label is None:
        # No iteration ran: still hand back a valid labelling.
        label = _vq(data, code)
    return code, label
# Nothing is executed when this file is run as a script.
if __name__ == "__main__":
    pass
|