File: test_online_em.py

package info (click to toggle)
python-scipy 0.5.2-0.1
  • links: PTS
  • area: main
  • in suites: etch, etch-m68k
  • size: 33,888 kB
  • ctags: 44,231
  • sloc: ansic: 156,256; cpp: 90,347; python: 89,604; fortran: 73,083; sh: 1,318; objc: 424; makefile: 342
file content (234 lines) | stat: -rw-r--r-- 7,840 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#! /usr/bin/env python
# Last Change: Wed Dec 06 09:00 PM 2006 J

import copy

import sys
from numpy.testing import *

import numpy as N
from numpy.random import seed

set_package_path()
from pyem import GM, GMM
from pyem.online_em import OnGMM, OnGMM1d
restore_path()

# #Optional:
# set_local_path()
# # import modules that are located in the same directory as this file.
# restore_path()

# Number of decimals to which compared model parameters must agree: used as
# the acceptance threshold in the offline/online equivalence check and as the
# `decimal` argument when comparing the two online implementations.
AR_AS_PREC  = 12
# Value passed as `niter` to the `init` method of the 'kmean'-initialized
# trainers (presumably the number of k-means iterations -- confirm in pyem).
KM_ITER     = 5

class OnlineEmTest(NumpyTestCase):
    """Common base for the online EM tests.

    Provides `_create_model`, which samples data from a random mixture and
    trains an offline (classical EM) reference model on it.
    """
    def _create_model(self, d, k, mode, nframes, emiter):
        """Sample data from a random mixture and fit it with classical EM.

        Arguments:
            d       -- number of dimensions of the mixture
            k       -- number of components of the mixture
            mode    -- covariance mode handed to GM (callers use 'diag')
            nframes -- number of frames to sample from the generating model
            emiter  -- number of classical EM iterations to run

        Sets on self:
            data -- the sampled frames
            gm0  -- shallow copy of the trainer's model taken right after
                    initialization, before any EM iteration
            gm   -- the GM object that was handed to the offline trainer
        """
        # Random generating model with k components in d dimensions, and
        # nframes frames sampled from it.
        weights, means, variances = GM.gen_param(d, k, mode, spread = 1.5)
        samples = GM.fromvalues(weights, means, variances).sample(nframes)

        # Offline trainer, initialized with the 'kmean' method.
        offline_gm = GM(d, k, mode)
        trainer = GMM(offline_gm, 'kmean')
        trainer.init(samples, niter = KM_ITER)

        # Snapshot of the initial state, taken before EM touches the model.
        self.gm0 = copy.copy(trainer.gm)

        # Classical EM: sufficient statistics, then parameter update. The
        # second value returned by sufficient_statistics is not needed here.
        for _ in range(emiter):
            resp, _unused = trainer.sufficient_statistics(samples)
            trainer.update_em(samples, resp)

        self.data = samples
        self.gm = offline_gm
    
class test_on_off_eq(OnlineEmTest):
    """Check that online EM, updated once per epoch, matches offline EM.

    Each check_* method seeds the random generator, builds the offline
    reference through `_create_model`, then runs `_check`.
    """
    def check_1d(self, level = 1):
        """Online/offline equivalence on 1-dimensional data."""
        # `level` is not used in the body; presumably read by the numpy test
        # runner to select tests -- confirm against numpy.testing.
        self._run_case(1)

    def check_2d(self, level = 1):
        """Online/offline equivalence on 2-dimensional data."""
        self._run_case(2)

    def check_5d(self, level = 5):
        """Online/offline equivalence on 5-dimensional data."""
        self._run_case(5)

    def _run_case(self, d, k = 2, mode = 'diag', nframes = int(1e2), emiter = 3):
        """Seed the generator, build the offline reference and compare.

        The defaults reproduce the settings shared by all check_* methods:
        2 components, 'diag' mode, 100 frames and 3 EM iterations.
        """
        # Fixed seed so that every run works on the same data.
        seed(1)
        self._create_model(d, k, mode, nframes, emiter)
        self._check(d, k, mode, nframes, emiter)

    def _check(self, d, k, mode, nframes, emiter):
        """Train with online EM and compare against the offline reference.

        Arguments:
            d, k, mode -- model dimensions, number of components and mode,
                          forwarded to GM
            nframes    -- number of frames of self.data processed per epoch
            emiter     -- number of epochs (full passes over the data)

        Raises AssertionError if the initializations differ, or if the
        trained parameters first disagree at a decimal below AR_AS_PREC.
        """
        #++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        #++++++++++++++++++++++++++++++++++++++++
        ogm = GM(d, k, mode)
        ogmm = OnGMM(ogm, 'kmean')
        ogmm.init(self.data, niter = KM_ITER)

        # The online k-means init must give exactly the offline init.
        ogm0 = copy.copy(ogm)
        assert_array_equal(ogm0.w, self.gm0.w)
        assert_array_equal(ogm0.mu, self.gm0.mu)
        assert_array_equal(ogm0.va, self.gm0.va)

        # Forgetting parameters. With lamb == 1 for every frame after the
        # first and nu[0] == 1, the recursion below yields nu[t] == 1/(t+1).
        lamb = N.ones((nframes, 1))
        lamb[0] = 0
        nu0 = 1.0
        nu = N.zeros((len(lamb), 1))
        nu[0] = nu0
        for i in range(1, len(lamb)):
            nu[i] = 1./(1 + lamb[i] / nu[i-1])

        # Object version of online EM: the p* attributes are refreshed from
        # the c* attributes only at the end of each epoch, which is meant to
        # make one epoch equivalent to one full iteration of classical EM.
        ogmm.pw = ogmm.cw.copy()
        ogmm.pmu = ogmm.cmu.copy()
        ogmm.pva = ogmm.cva.copy()
        for e in range(emiter):
            for t in range(nframes):
                ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
                ogmm.update_em_frame()

            # Change the p* attributes only at the end of each epoch.
            ogmm.pw = ogmm.cw.copy()
            ogmm.pmu = ogmm.cmu.copy()
            ogmm.pva = ogmm.cva.copy()

        # For equivalence between offline and online, a margin of error is
        # allowed because of round-off: compare with an increasing number of
        # decimals and look at where the comparison first fails.
        print(" Checking precision of equivalence with offline EM trainer ")
        maxtestprec = 18
        try:
            for i in range(maxtestprec):
                assert_array_almost_equal(self.gm.w, ogmm.pw, decimal = i)
                assert_array_almost_equal(self.gm.mu, ogmm.pmu, decimal = i)
                assert_array_almost_equal(self.gm.va, ogmm.pva, decimal = i)
            print("\t !! Precision up to %d decimals !! " % i)
        except AssertionError:
            # NOTE(review): i is the first decimal count at which the
            # comparison FAILED, so the agreement actually verified is i - 1
            # decimals; the threshold and messages are kept as they were.
            if i < AR_AS_PREC:
                print(("\t !!NOT OK: Precision up to %d decimals only, \n"
                       "                    outside the allowed range (%d) !! ")
                      % (i, AR_AS_PREC))
                # Re-raise the original error so that the array-comparison
                # diagnostic is not lost.
                raise
            else:
                print("\t !!OK: Precision up to %d decimals !! " % i)

class test_on(OnlineEmTest):
    """Tests of the purely online (frame by frame) EM trainers.

    NOTE(review): unlike test_on_off_eq, these checks do not call seed(), so
    the data differs from run to run -- consider seeding for reproducibility.
    """
    def check_consistency(self):
        """Run the generic online trainer, including its aliasing asserts."""
        d       = 1
        k       = 2
        mode    = 'diag'
        nframes = int(5e2)
        emiter  = 4

        self._create_model(d, k, mode, nframes, emiter)
        self._run_pure_online(d, k, mode, nframes)

    def check_1d_imp(self):
        """Check that OnGMM1d gives the same parameters as OnGMM on 1d data."""
        d       = 1
        k       = 2
        mode    = 'diag'
        nframes = int(5e2)
        emiter  = 4

        self._create_model(d, k, mode, nframes, emiter)
        gmref   = self._run_pure_online(d, k, mode, nframes)
        gmtest  = self._run_pure_online_1d(d, k, mode, nframes)

        assert_array_almost_equal(gmref.w, gmtest.w, AR_AS_PREC)
        assert_array_almost_equal(gmref.mu, gmtest.mu, AR_AS_PREC)
        assert_array_almost_equal(gmref.va, gmtest.va, AR_AS_PREC)

    def _forgetting_schedule(self, nframes):
        """Return the (nframes, 1) array of nu values shared by both runs.

        lamb[t] = 1 - 1/((t - 1) * ku + t0), for t = 0 .. nframes - 1
        nu[0]   = nu0
        nu[t]   = 1 / (1 + lamb[t] / nu[t-1])
        """
        ku      = 0.005
        t0      = 200
        lamb    = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
        nu0     = 0.2
        nu      = N.zeros((len(lamb), 1))
        nu[0]   = nu0
        for i in range(1, len(lamb)):
            nu[i]   = 1./(1 + lamb[i] / nu[i-1])
        return nu

    def _run_pure_online_1d(self, d, k, mode, nframes):
        """Train an OnGMM1d frame by frame on self.data and return its model.

        Only the first column of self.data is used. The first nframes // 20
        frames serve for the initialization.
        """
        #++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        #++++++++++++++++++++++++++++++++++++++++
        ogm     = GM(d, k, mode)
        ogmm    = OnGMM1d(ogm, 'kmean')
        # Floor division keeps the slice bound an integer.
        init_data   = self.data[0:nframes // 20, :]
        ogmm.init(init_data[:, 0])

        # Forgetting param
        nu      = self._forgetting_schedule(nframes)

        # object version of online EM: here the sufficient statistics are
        # returned by one call and handed explicitly to the update.
        for t in range(nframes):
            a, b, c = ogmm.compute_sufficient_statistics_frame(self.data[t, 0], nu[t])
            ogmm.update_em_frame(a, b, c)

        # A trailing axis is added to cmu and cva before storing them in the
        # model (presumably because the 1d trainer keeps them flat -- confirm).
        ogmm.gm.set_param(ogmm.cw, ogmm.cmu[:, N.newaxis], ogmm.cva[:, N.newaxis])

        return ogmm.gm

    def _run_pure_online(self, d, k, mode, nframes):
        """Train an OnGMM frame by frame on self.data and return its model.

        The first nframes // 20 frames serve for the initialization.
        """
        #++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        #++++++++++++++++++++++++++++++++++++++++
        ogm     = GM(d, k, mode)
        ogmm    = OnGMM(ogm, 'kmean')
        # Floor division keeps the slice bound an integer.
        init_data   = self.data[0:nframes // 20, :]
        ogmm.init(init_data)

        # Forgetting param
        nu      = self._forgetting_schedule(nframes)

        # object version of online EM
        for t in range(nframes):
            # These asserts check that the p* and c* attributes stay the very
            # same objects, i.e. that no unintended copy of the parameters
            # is created along the way.
            assert ogmm.pw is ogmm.cw
            assert ogmm.pmu is ogmm.cmu
            assert ogmm.pva is ogmm.cva
            ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
            ogmm.update_em_frame()

        ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)

        return ogmm.gm
# When executed as a script, run the tests of this module through NumpyTest
# (expected to be provided by the star import of numpy.testing above).
if __name__ == "__main__":
    NumpyTest().run()