File: geweke.py

package info (click to toggle)
python-bumps 1.0.0b2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,144 kB
  • sloc: python: 23,941; xml: 493; ansic: 373; makefile: 209; sh: 91; javascript: 90
file content (62 lines) | stat: -rw-r--r-- 2,118 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Convergence test statistic from Geweke, 1992.
"""

__all__ = ["geweke"]

from numpy import var, mean, ones, sqrt, reshape, log10, abs


def geweke(sequences, portion=0.25):
    """
    Calculates the Geweke convergence diagnostic

    Compares the mean of the first part of the chains against the mean of
    the last part, separately for each variable.  The samples from all
    chains are pooled within each part, and the score is::

        z = (mean(front) - mean(back)) / sqrt(var(front) + var(back))

    *sequences* is an array with shape (chain_len, nchains, nvar).

    *portion* is the fraction of the chain length used for each part.  It
    is either a single number applied to both ends, or a pair
    (front_portion, back_portion).  The resulting window lengths are
    clamped so that each part holds at least one and at most *chain_len*
    steps of the chain.

    Returns a list of *nvar* floats.  An entry is -2 when no score could be
    computed for that variable: the chain has fewer than two steps, or the
    summed variance of the two parts is not positive.

    Refer to:

        pymc-devs.github.com/pymc/modelchecking.html#informal-methods
        support.sas.com/documentation/cdl/en/statug/63033/HTML/default/viewer.htm#statug_introbayes_sect008.html

    """
    chain_len, nchains, nvar = sequences.shape

    # Sentinel value; it is only overwritten where a score can be computed.
    z_stat = -2 * ones(nvar)
    if chain_len < 2:
        # Too short to split into a front and a back part.
        return z_stat.tolist()

    # portion may be a scalar (same fraction at both ends) or a pair.
    try:
        front_portion, back_portion = portion
    except TypeError:
        front_portion = back_portion = portion

    # Clamp the window lengths to [1, chain_len].  A length of zero would
    # leave an empty front window, and a length beyond chain_len would not
    # match the number of rows actually available to reshape.
    front_len = min(max(int(chain_len * front_portion), 1), chain_len)
    back_len = min(max(int(chain_len * back_portion), 1), chain_len)

    # Pool the chains: (steps, nchains, nvar) -> (steps * nchains, nvar).
    # The tail is sliced from an explicit start index because a negative
    # index of zero (sequences[-0:]) would select the whole chain.
    seq1 = reshape(sequences[:front_len, :, :], (front_len * nchains, nvar))
    seq2 = reshape(
        sequences[chain_len - back_len :, :, :], (back_len * nchains, nvar)
    )

    # Per-variable mean and variance of each part.
    meanseq1 = mean(seq1, axis=0)
    meanseq2 = mean(seq2, axis=0)
    var1 = var(seq1, axis=0)
    var2 = var(seq2, axis=0)

    # Only divide where the denominator is positive; other entries keep
    # the sentinel value.
    denom = sqrt(var1 + var2)
    index = denom > 0
    z_stat[index] = (meanseq1 - meanseq2)[index] / denom[index]

    # z_stat holds one score per variable, shape (nvar,).
    return z_stat.tolist()