1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""
Example of a demo that fits a Bayesian Gaussian Mixture Model (GMM)
to a dataset.
Variational Bayes and Gibbs sampling estimation are run successively on the
same dataset.
Author : Bertrand Thirion, 2008-2009
"""
# Echo the module docstring. The parenthesised single-argument form prints
# identically under Python 2 and Python 3 (the bare statement is Py2-only).
print(__doc__)
import numpy as np
import numpy.random as nr
import pylab as pl
import nipy.neurospin.clustering.bgmm as bgmm
from nipy.neurospin.clustering.gmm import plot2D
dim = 2
################################################################################
# 1. Generate a synthetic 3-component mixture in `dim` (= 2) dimensions:
#    - 100 standard-normal samples centred on the origin,
#    - 50 samples centred on (3, 3) with standard deviation 2,
#    - 30 samples centred on (-2, 2) with standard deviation 0.5.
#    The three groups are drawn in this order from the global numpy RNG and
#    stacked row-wise into `x`.
x1 = nr.randn(100, dim)
x2 = 2 * nr.randn(50, dim) + 3
x3 = 0.5 * nr.randn(30, dim) + np.array([-2, 2])  # row offset broadcast
x = np.vstack([x1, x2, x3])
################################################################################
# 2. Fit the mixture with a range of candidate numbers of components using
#    Variational Bayes, keeping the model whose `evidence` is highest.
krange = range(1, 10)
be = -np.inf  # np.infty was removed in NumPy 2.0; np.inf is the same value
bestb = None  # stays None if no model ever beats -inf (e.g. all-NaN evidence)
bestk = None
for k in krange:
    b = bgmm.VBGMM(k, dim)
    b.guess_priors(x)
    b.initialize(x)
    b.estimate(x)
    ek = float(b.evidence(x))
    if ek > be:
        be = ek
        bestb = b
        bestk = k
    # A single pre-formatted argument prints identically under Python 2 and 3.
    print('%s classes, free energy: %s' % (k, ek))
if bestb is None:
    # Without this guard the failure would surface later as a NameError.
    raise RuntimeError('Variational Bayes: no candidate model produced a '
                       'usable evidence value')
print('Variational Bayes selected %s classes' % bestk)
################################################################################
# 3. Plot the selected model, colouring each sample by the label returned by
#    `map_label` (presumably its most probable component -- confirm in bgmm).
z = bestb.map_label(x)
plot2D(x, bestb, z, verbose=0)
pl.title('Variational Bayes')
################################################################################
# 4. The same model selection, this time with the Gibbs-sampling BGMM.
#    For each candidate k, sample from the model, build a "plug-in" BGMM from
#    the sampled parameters (w, cent, prec), and score it with its Bayes
#    factor; the highest-scoring model is kept and plotted.
niter = 1000
krange = range(2, 6)
bbf = -np.inf  # np.infty was removed in NumPy 2.0; np.inf is the same value
bbgmm = None   # stays None if no model ever beats -inf (e.g. all NaN)
bestk = None
for k in krange:
    b = bgmm.BGMM(k, dim)
    b.guess_priors(x)
    b.initialize(x)
    b.sample(x, 100)  # output discarded -- presumably a burn-in run
    w, cent, prec, pz = b.sample(x, niter=niter, mem=1)
    bplugin = bgmm.BGMM(k, dim, cent, prec, w)
    bplugin.guess_priors(x)
    # np.int (an alias of the builtin int) was removed in NumPy 1.24;
    # astype(int) yields exactly the same dtype.
    bfk = bplugin.bayes_factor(x, pz.astype(int), nperm=40)
    # A single pre-formatted argument prints identically under Python 2 and 3.
    print('%s classes, evidence: %s' % (k, bfk))
    if bfk > bbf:
        bestk = k
        bbf = bfk
        bbgmm = bplugin
if bbgmm is None:
    # Without this guard the failure would surface later as a NameError.
    raise RuntimeError('Gibbs sampling: no candidate model produced a '
                       'usable Bayes factor')
print('Gibbs sampling selected %s classes' % bestk)
z = bbgmm.map_label(x)
plot2D(x, bbgmm, z, verbose=0)
pl.title('Gibbs sampling')
pl.show()
|