File: outliers.py

package: python-bumps 0.7.11-2
"""
Chain outlier tests.
"""

__all__ = ["identify_outliers"]

from numpy import mean, std, sqrt, where, argmin, arange, array
from numpy import sort
from scipy.stats import t as student_t
from scipy.stats import scoreatpercentile

from .mahal import mahalanobis
from .acr import ACR

tinv = student_t.ppf
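# tinv(p, df) is the inverse CDF (quantile function) of Student's t
# distribution with df degrees of freedom.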
# from scipy.stats import scoreatpercentile as prctile
# CRUFT: scoreatpercentile not accepting array arguments in older scipy


def prctile(v, Q):
    """Return the percentiles *Q* (in percent) of the values in *v*."""
    v = sort(v)
    return [scoreatpercentile(v, Qi) for Qi in Q]
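# For illustration (values assume scipy's default linear interpolation):
#     prctile([3., 1., 4., 1., 5.], [25., 75.])  ->  [1.0, 4.0]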


def identify_outliers(test, chains, x):
    """
    Determine which chains have converged on a local maximum much lower than
    the maximum likelihood.

    *test* is the name of the test to use (one of 'IQR', 'Grubbs', 'Mahal'
    or 'none', case insensitive).
    *chains* is a set of log likelihood values of shape (chain len, num chains).
    *x* is the current population of shape (num chains, num vars).

    Returns an integer array of outlier indices.
    """
    # Determine the mean log density of the active chains
    v = mean(chains, axis=0)

    # Check whether any of these active chains are outlier chains
    test = test.lower()
    if test == 'iqr':
        # Derive the upper and lower quartile of the chain averages
        q1, q3 = prctile(v, [25., 75.])
        # Derive the Inter Quartile Range (IQR)
        iqr = q3 - q1
        # Flag chains whose mean log density falls more than 2*IQR below the
        # lower quartile (a stricter cutoff than the usual 1.5*IQR fence)
        outliers = where(v < q1 - 2*iqr)[0]

    elif test == 'grubbs':
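        # One-sided Grubbs' test: a chain is declared an outlier when its
        # z-score exceeds G = ((N-1)/sqrt(N)) * sqrt(t^2 / (N-2 + t^2)),
        # where t is the upper critical value of Student's t with N-2
        # degrees of freedom at significance level alpha/N (alpha = 0.01).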
        # Compute zscore for chain averages
        zscore = (mean(v) - v) / std(v, ddof=1)
        # Determine t-value of one-sided interval
        n = len(v)
        t2 = tinv(1 - 0.01/n, n-2)**2  # one-sided test at significance 0.01
        # Determine the critical value
        gcrit = ((n - 1)/sqrt(n)) * sqrt(t2/(n-2 + t2))
        # Then check against this
        outliers = where(zscore > gcrit)[0]

    elif test == 'mahal':
        # Use the Mahalanobis distance to find outliers in the population
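        # Only the chain with the lowest mean log density is tested: if its
        # Mahalanobis distance to the rest of the population exceeds the
        # critical value from ACR(nvar, npop-1, alpha), it is flagged as an
        # outlier (see .acr for how the critical value is computed).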
        alpha = 0.01
        npop, nvar = x.shape
        gcrit = ACR(nvar, npop-1, alpha)
        #print "alpha", alpha, "nvar", nvar, "npop", npop, "gcrit", gcrit
        # Find which chain has minimum log_density
        minidx = argmin(v)
        # check the Mahalanobis distance of the current point to other chains
        d1 = mahalanobis(x[minidx, :], x[minidx != arange(npop), :])
        #print "d1", d1, "minidx", minidx
        # and see if it is an outlier
        outliers = array([minidx]) if d1 > gcrit else array([], dtype=int)

    elif test == 'none':
        outliers = array([], dtype=int)

    else:
        raise ValueError("Unknown outlier test "+test)

    return outliers
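

# Example usage (an illustrative sketch, not part of the bumps API): feed
# identify_outliers an array of per-sample log likelihoods with one column
# per chain.  The synthetic data below plants chain 0 at a much lower
# density than the rest, so the IQR test should flag it.  The helper name
# _example_identify_outliers is hypothetical.
def _example_identify_outliers():
    from numpy.random import randn
    logp = randn(500, 10)        # 10 well-mixed chains, 500 samples each
    logp[:, 0] -= 50.0           # chain 0 is stuck at a much lower density
    print(identify_outliers('IQR', logp, None))   # expected output: [0]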


def test_outliers():
    from .walk import walk
    from numpy.random import multivariate_normal, seed
    from numpy import vstack, ones, eye
    seed(2)  # Fix the random seed so the test is reproducible
    # Set a number of good and bad chains
    ngood, nbad = 25, 2

    # Make mean-reverting chains with widely separated values for the bad
    # and good chains; put the bad chains first.
    chains = walk(1000, mu=[1]*nbad+[5]*ngood, sigma=0.45, alpha=0.1)

    # Check IQR and Grubbs
    assert (identify_outliers('IQR', chains, None) == arange(nbad)).all()
    assert (identify_outliers('Grubbs', chains, None) == arange(nbad)).all()

    # Put points for 'bad' chains at [-1,...,-1] and 'good' chains at [1,...,1]
    x = vstack((multivariate_normal(-ones(4), 0.1*eye(4), size=nbad),
                multivariate_normal(ones(4), 0.1*eye(4), size=ngood)))
    assert identify_outliers('Mahal', chains, x)[0] in range(nbad)

    # Put points for _all_ chains at [1,...,1] and check that mahal returns []
    xsame = multivariate_normal(ones(4), 0.2*eye(4), size=ngood+nbad)
    assert len(identify_outliers('Mahal', chains, xsame)) == 0

    # Check again with large variance
    x = vstack((multivariate_normal(-3*ones(4), eye(4), size=nbad),
                multivariate_normal(ones(4), 10*eye(4), size=ngood)))
    assert len(identify_outliers('Mahal', chains, x)) == 0

    # =====================================================================
    # Test replacement

    # Construct a state object
    from numpy.linalg import norm
    from .state import MCMCDraw
    ngen, npop = chains.shape
    npop, nvar = x.shape
    state = MCMCDraw(Ngen=ngen, Nthin=ngen, Nupdate=0,
                     Nvar=nvar, Npop=npop, Ncr=0, thinning=0)
    # Fill it with chains
    for i in range(ngen):
        state._generation(new_draws=npop, x=x, logp=chains[i], accept=npop)

    # Make a copy of the current state so we can check it was updated
    nx, nlogp = x+0, chains[-1]+0
    # Remove outliers
    state.remove_outliers(nx, nlogp, test='IQR', portion=0.5)
    # Check that the outliers were removed
    outliers = state.outliers()
    assert outliers.shape[0] == nbad
    for i in range(nbad):
        assert nlogp[outliers[i, 1]] == chains[-1][outliers[i, 2]]
        assert norm(nx[outliers[i, 1], :] - x[outliers[i, 2], :]) == 0


if __name__ == "__main__":
    test_outliers()