File: test_discrete_basic.py

package info (click to toggle)
python-scipy 1.1.0-7
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 93,828 kB
  • sloc: python: 156,854; ansic: 82,925; fortran: 80,777; cpp: 7,505; makefile: 427; sh: 294
file content (234 lines) | stat: -rw-r--r-- 8,295 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from __future__ import division, print_function, absolute_import

import numpy.testing as npt
import numpy as np
from scipy._lib.six import xrange
import pytest

from scipy import stats
from .common_tests import (check_normalization, check_moment, check_mean_expect,
                           check_var_expect, check_skew_expect,
                           check_kurt_expect, check_entropy,
                           check_private_entropy, check_edge_support,
                           check_named_args, check_random_state_property,
                           check_pickling, check_rvs_broadcast)
from scipy.stats._distr_params import distdiscrete

# Finite distribution given by explicit (values, probabilities).  It is added
# in place to the imported parameter list, so the parametrized tests in this
# module also run against an rv_discrete instance built from `values`.
vals = ([1, 2, 3, 4], [0.1, 0.2, 0.3, 0.4])
distdiscrete.append([stats.rv_discrete(values=vals), ()])


def cases_test_discrete_basic():
    """Yield ``(distname, arg, first_case)`` for every entry of `distdiscrete`.

    ``first_case`` is True only the first time a given ``distname`` is
    produced; later entries with the same name get False.
    """
    already_produced = set()
    for entry in distdiscrete:
        name, shape_args = entry
        is_first = name not in already_produced
        already_produced.add(name)
        yield name, shape_args, is_first


@pytest.mark.parametrize('distname,arg,first_case', cases_test_discrete_basic())
def test_discrete_basic(distname, arg, first_case):
    """Run the basic consistency checks for one discrete distribution.

    ``distname`` is either an attribute name of ``scipy.stats`` or a
    distribution object that is used directly.  The checks in the second
    half run only when ``first_case`` is true.
    """
    try:
        distfn = getattr(stats, distname)
    except TypeError:
        # getattr rejects a non-string name: distname is the distribution
        # object itself
        distfn, distname = distname, 'sample distribution'

    # fixed seed so the sample used below is reproducible
    np.random.seed(9765456)
    sample = distfn.rvs(*arg, size=2000)
    observed_support = np.unique(sample)
    # the values are not used further; `stats` must return a (mean, var) pair
    _mean, _var = distfn.stats(*arg)

    check_cdf_ppf(distfn, arg, observed_support, distname + ' cdf_ppf')
    check_pmf_cdf(distfn, arg, distname)
    check_oth(distfn, arg, observed_support, distname + ' oth')
    check_edge_support(distfn, arg)
    check_discrete_chisquare(distfn, arg, sample, 0.01,
                             distname + ' chisquare')

    if not first_case:
        return

    loc_default = (0,)
    methods = [distfn.pmf, distfn.logpmf, distfn.cdf, distfn.logcdf,
               distfn.logsf]
    # evaluation point, chosen per distribution so that it lies within
    # the support
    k_by_dist = {'randint': 11, 'hypergeom': 4, 'bernoulli': 0}
    check_named_args(distfn, k_by_dist.get(distname, 1), arg, loc_default,
                     methods)
    if distname != 'sample distribution':
        check_scale_docstring(distfn)
    check_random_state_property(distfn, arg)
    check_pickling(distfn, arg)

    # Entropy: the private implementation is only checked when the
    # distribution class overrides the generic rv_discrete one
    check_entropy(distfn, arg, distname)
    if distfn.__class__._entropy != stats.rv_discrete._entropy:
        check_private_entropy(distfn, arg, stats.rv_discrete)


@pytest.mark.parametrize('distname,arg', distdiscrete)
def test_moments(distname, arg):
    """Cross-check the moments reported by ``stats`` for one distribution.

    ``distname`` is either an attribute name of ``scipy.stats`` or a
    distribution object that is used directly.
    """
    try:
        distfn = getattr(stats, distname)
    except TypeError:
        # not a string, so it is the distribution object itself
        distfn, distname = distname, 'sample distribution'

    mean, var, skew, kurt = distfn.stats(*arg, moments='mvsk')
    check_normalization(distfn, arg, distname)

    # compare the `stats` output with the `moment` method and the other
    # moment checks
    check_moment(distfn, arg, mean, var, distname)
    check_mean_expect(distfn, arg, mean, distname)
    check_var_expect(distfn, arg, mean, var, distname)
    check_skew_expect(distfn, arg, mean, var, skew, distname)
    if distname != 'zipf':
        # the kurtosis check is skipped for zipf
        check_kurt_expect(distfn, arg, mean, var, kurt, distname)

    # moments of the frozen distribution: first moment is the mean,
    # second moment is var + mean**2
    check_moment_frozen(distfn, arg, mean, 1)
    check_moment_frozen(distfn, arg, var + mean*mean, 2)


@pytest.mark.parametrize('dist,shape_args', distdiscrete)
def test_rvs_broadcast(dist, shape_args):
    """Check ``rvs`` when the shape parameters and ``loc`` are broadcast.

    ``dist`` is either an attribute name of ``scipy.stats`` or a
    distribution object that is used directly.
    """
    # Distributions listed here have an _rvs method that uses more than one
    # random number per variate.  For those, rvs with broadcasting (or with a
    # nontrivial size) need not reproduce the numpy.vectorize'd version of
    # rvs(), so only the shapes of the results can be compared, not the
    # values.  Membership in this list is an implementation detail of the
    # distribution, not a requirement: if the implementation of the rvs()
    # method of a distribution changes, this test might have to change too.
    shape_only_dists = ('skellam',)
    shape_only = dist in shape_only_dists

    try:
        distfunc = getattr(stats, dist)
    except TypeError:
        distfunc = dist
        dist = 'rv_discrete(values=(%r, %r))' % (dist.xk, dist.pk)

    loc = np.zeros(2)
    nargs = distfunc.numargs

    # Shape parameter k becomes an array of shape (k + 3, 1, ..., 1) with
    # k + 1 trailing ones, so every argument contributes its own axis.
    allargs = []
    for k in range(nargs):
        param_val = shape_args[k]
        shp = (k + 3,) + (1,)*(k + 1)
        allargs.append(param_val*np.ones(shp, dtype=np.array(param_val).dtype))
    allargs.append(loc)

    # Expected shape when loc and the shape parameters are all broadcast
    # together: leading sizes of the parameters in reverse order, then loc.
    bshape = [k + 3 for k in reversed(range(nargs))]
    bshape.append(loc.size)

    check_rvs_broadcast(distfunc, dist, allargs, bshape, shape_only, [np.int_])


def check_cdf_ppf(distfn, arg, supp, msg):
    """Check that ``ppf`` inverts ``cdf`` on the points in `supp`.

    cdf is a step function, and ppf(q) = min{k : cdf(k) >= q, k integer},
    so ppf must map cdf(k), and a value slightly below it, back to k.
    """
    cdf_at_supp = distfn.cdf(supp, *arg)
    roundtrip_msg = msg + '-roundtrip'
    npt.assert_array_equal(distfn.ppf(cdf_at_supp, *arg), supp,
                           roundtrip_msg)
    # note: subtracting 1e-8 could cause an error if pmf < 1e-8
    npt.assert_array_equal(distfn.ppf(cdf_at_supp - 1e-8, *arg), supp,
                           roundtrip_msg)

    if hasattr(distfn, 'xk'):
        # the next-point check is skipped for distributions exposing `xk`
        return

    # slightly above cdf(k), ppf must move on by `distfn.inc`; only points
    # below the upper bound `distfn.b` are used
    below_upper = supp[supp < distfn.b]
    stepped = distfn.ppf(distfn.cdf(below_upper, *arg) + 1e-8, *arg)
    npt.assert_array_equal(stepped, below_upper + distfn.inc,
                           msg + ' ppf-cdf-next')


def check_pmf_cdf(distfn, arg, distname):
    """Check that cdf increments equal the cumulative sums of the pmf."""
    if hasattr(distfn, 'xk'):
        # the distribution exposes its points directly
        points = distfn.xk
    else:
        # ten consecutive integers, starting one below the 1% quantile
        first = int(distfn.ppf(0.01, *arg) - 1)
        points = list(range(first, first + 10))

    cdf_vals = distfn.cdf(points, *arg)
    cum_pmf = distfn.pmf(points, *arg).cumsum()

    # skellam gets a looser tolerance (ncx2 accuracy)
    tol = 1e-5 if distname == 'skellam' else 1e-10
    # both sides are taken relative to their first entry, so mass lying
    # before the first point does not matter
    npt.assert_allclose(cdf_vals - cdf_vals[0], cum_pmf - cum_pmf[0],
                        atol=tol, rtol=tol)


def check_moment_frozen(distfn, arg, m, k):
    """Check that the k-th moment of the frozen distribution equals `m`."""
    frozen = distfn(*arg)
    npt.assert_allclose(frozen.moment(k), m, atol=1e-10, rtol=1e-10)


def check_oth(distfn, arg, supp, msg):
    """Check ``sf`` and ``isf`` against ``cdf`` and ``ppf``.

    `msg` is accepted for symmetry with the other checks but is not used.
    """
    tol = dict(atol=1e-10, rtol=1e-10)

    # sf is the complement of cdf on the given points
    npt.assert_allclose(distfn.sf(supp, *arg), 1. - distfn.cdf(supp, *arg),
                        **tol)

    # isf(q) agrees with ppf(1 - q) on a grid of probabilities
    q = np.linspace(0.01, 0.99, 20)
    npt.assert_allclose(distfn.isf(q, *arg), distfn.ppf(1. - q, *arg),
                        **tol)

    # one step below isf(0.5) the sf exceeds 0.5, and one step above it
    # the cdf exceeds 0.5
    median_sf = distfn.isf(0.5, *arg)
    npt.assert_(distfn.sf(median_sf - 1, *arg) > 0.5)
    npt.assert_(distfn.cdf(median_sf + 1, *arg) > 0.5)


def check_discrete_chisquare(distfn, arg, rvs, alpha, msg):
    """Perform chisquare test for random sample of a discrete distribution

    The sample is binned into intervals that each carry at least a minimum
    probability mass under the distribution, and the observed bin counts are
    compared with the expected ones.

    Parameters
    ----------
    distfn : distribution object
        distribution providing ``a``, ``b`` and ``cdf``
    arg : sequence
        parameters of distribution
    rvs : array_like
        random sample to be compared with the distribution
    alpha : float
        significance level, threshold for p-value
    msg : string
        label of the test case, included in the failure message

    Raises
    ------
    AssertionError
        if the p-value of the chisquare test is not larger than `alpha`

    """
    wsupp = 0.05

    # construct intervals with minimum mass `wsupp`.
    # intervals are left-half-open as in a cdf difference
    # the scanned range is clipped to [-1000, 1000] so it stays finite
    lo = int(max(distfn.a, -1000))
    hi = int(min(distfn.b, 1000))
    last = 0
    distsupp = [lo]
    distmass = []
    # builtin range is enough here: at most 2001 integers are scanned
    for ii in range(lo, hi + 1):
        current = distfn.cdf(ii, *arg)
        # 1e-14 slack guards against round-off in the cdf difference
        if current - last >= wsupp - 1e-14:
            distsupp.append(ii)
            distmass.append(current - last)
            last = current
            if current > (1 - wsupp):
                break
    if distsupp[-1] < distfn.b:
        # collect the remaining upper tail in one final interval
        distsupp.append(distfn.b)
        distmass.append(1 - last)
    distsupp = np.array(distsupp)
    distmass = np.array(distmass)

    # convert intervals to right-half-open as required by histogram
    histsupp = distsupp + 1e-8
    histsupp[0] = distfn.a

    # find sample frequencies and perform chisquare test
    freq, _ = np.histogram(rvs, histsupp)
    _, pval = stats.chisquare(freq, len(rvs)*distmass)

    npt.assert_(pval > alpha,
                'chisquare - test for %s at arg = %s with pval = %s' %
                (msg, str(arg), str(pval)))


def check_scale_docstring(distfn):
    """Check that the docstring of `distfn` does not mention 'scale'."""
    doc = distfn.__doc__
    if doc is None:
        # Docstrings can be stripped if interpreter is run with -OO
        return
    npt.assert_('scale' not in doc)