#
# Author:  Travis Oliphant  2002-2011 with contributions from
#          SciPy Developers 2004-2011
#
from __future__ import division, print_function, absolute_import

from scipy import special
from scipy.special import gammaln as gamln

from numpy import floor, ceil, log, exp, sqrt, log1p, expm1, tanh, cosh, sinh

import numpy as np
import numpy.random as mtrand

from ._distn_infrastructure import (
        rv_discrete, _lazywhere, _ncx2_pdf, _ncx2_cdf, get_distribution_names)


class binom_gen(rv_discrete):
    """A binomial discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `binom` is::

       binom.pmf(k) = choose(n, k) * p**k * (1-p)**(n-k)

    for ``k`` in ``{0, 1,..., n}``.

    `binom` takes ``n`` and ``p`` as shape parameters.

    %(example)s

    """
    def _rvs(self, n, p):
        # Sample with NumPy's generator; the output shape comes from the
        # legacy ``_size`` attribute set by the public ``rvs``.
        size = self._size
        return mtrand.binomial(n, p, size)

    def _argcheck(self, n, p):
        # The support ends at the number of trials.
        self.b = n
        p_ok = (p >= 0) & (p <= 1)
        return (n >= 0) & p_ok

    def _logpmf(self, x, n, p):
        k = floor(x)
        # log(choose(n, k)) through gammaln so large n does not overflow.
        log_comb = gamln(n+1) - (gamln(k+1) + gamln(n-k+1))
        # xlogy / xlog1py treat 0*log(0) as 0 at the p == 0 / p == 1 edges.
        log_success = special.xlogy(k, p)
        log_failure = special.xlog1py(n-k, -p)
        return log_comb + log_success + log_failure

    def _pmf(self, x, n, p):
        log_mass = self._logpmf(x, n, p)
        return exp(log_mass)

    def _cdf(self, x, n, p):
        return special.bdtr(floor(x), n, p)

    def _sf(self, x, n, p):
        return special.bdtrc(floor(x), n, p)

    def _ppf(self, q, n, p):
        # bdtrik inverts the continuous extension of the cdf; step back by
        # one when the previous integer already reaches probability q.
        upper = ceil(special.bdtrik(q, n, p))
        lower = np.maximum(upper - 1, 0)
        use_lower = special.bdtr(lower, n, p) >= q
        return np.where(use_lower, lower, upper)

    def _stats(self, n, p):
        q = 1.0-p
        npq = n * p * q
        mean = n * p
        skew = (q-p) / sqrt(npq)
        kurt = (1.0-6*p*q)/npq
        return mean, npq, skew, kurt

    def _entropy(self, n, p):
        # Explicit sum over the finite support {0, ..., n}.
        support = np.r_[0:n + 1]
        masses = self._pmf(support, n, p)
        return -np.sum(special.xlogy(masses, masses), axis=0)
binom = binom_gen(name='binom')


class bernoulli_gen(binom_gen):
    """A Bernoulli discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `bernoulli` is::

       bernoulli.pmf(k) = 1-p  if k = 0
                        = p    if k = 1

    for ``k`` in ``{0, 1}``.

    `bernoulli` takes ``p`` as shape parameter.

    %(example)s

    """
    # Every method below is the binomial distribution with a single trial.
    def _rvs(self, p):
        return binom_gen._rvs(self, n=1, p=p)

    def _argcheck(self, p):
        # No trial-count parameter to validate, only the probability.
        return (p <= 1) & (p >= 0)

    def _logpmf(self, x, p):
        return binom._logpmf(x, n=1, p=p)

    def _pmf(self, x, p):
        return binom._pmf(x, n=1, p=p)

    def _cdf(self, x, p):
        return binom._cdf(x, n=1, p=p)

    def _sf(self, x, p):
        return binom._sf(x, n=1, p=p)

    def _ppf(self, q, p):
        return binom._ppf(q, n=1, p=p)

    def _stats(self, p):
        return binom._stats(n=1, p=p)

    def _entropy(self, p):
        # Closed form: -p*log(p) - (1-p)*log(1-p), with 0*log(0) == 0.
        return -special.xlogy(p, p) - special.xlogy(1 - p, 1 - p)
bernoulli = bernoulli_gen(b=1, name='bernoulli')


class nbinom_gen(rv_discrete):
    """A negative binomial discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `nbinom` is::

         nbinom.pmf(k) = choose(k+n-1, n-1) * p**n * (1-p)**k

    for ``k >= 0``.

    `nbinom` takes ``n`` and ``p`` as shape parameters.

    %(example)s

    """
    def _rvs(self, n, p):
        return mtrand.negative_binomial(n, p, self._size)

    def _argcheck(self, n, p):
        return (n >= 0) & (p >= 0) & (p <= 1)

    def _pmf(self, x, n, p):
        return exp(self._logpmf(x, n, p))

    def _logpmf(self, x, n, p):
        # log(choose(x+n-1, n-1)) via gammaln, which also covers non-integer n.
        coeff = gamln(n+x) - gamln(x+1) - gamln(n)
        # xlogy / xlog1py define 0*log(0) as 0.  With plain logs the allowed
        # boundary p == 1 (all mass at x == 0) evaluated 0*log(0) = nan.
        return coeff + special.xlogy(n, p) + special.xlog1py(x, -p)

    def _cdf(self, x, n, p):
        k = floor(x)
        return special.betainc(n, k+1, p)

    def _sf_skip(self, x, n, p):
        # skip because special.nbdtrc doesn't work for 0<n<1
        k = floor(x)
        return special.nbdtrc(k, n, p)

    def _ppf(self, q, n, p):
        # nbdtrik inverts the continuous extension of the cdf; step back by
        # one when the previous integer already reaches probability q.
        vals = ceil(special.nbdtrik(q, n, p))
        vals1 = (vals-1).clip(0.0, np.inf)
        temp = self._cdf(vals1, n, p)
        return np.where(temp >= q, vals1, vals)

    def _stats(self, n, p):
        # Q = 1/p and P = (1-p)/p give the usual moment formulas.
        Q = 1.0 / p
        P = Q - 1.0
        mu = n*P
        var = n*P*Q
        g1 = (Q+P)/sqrt(n*P*Q)
        g2 = (1.0 + 6*P*Q) / (n*P*Q)
        return mu, var, g1, g2
nbinom = nbinom_gen(name='nbinom')


class geom_gen(rv_discrete):
    """A geometric discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `geom` is::

        geom.pmf(k) = (1-p)**(k-1)*p

    for ``k >= 1``.

    `geom` takes ``p`` as shape parameter.

    %(example)s

    """
    def _rvs(self, p):
        return mtrand.geometric(p, size=self._size)

    def _argcheck(self, p):
        return (p <= 1) & (p >= 0)

    def _pmf(self, k, p):
        return np.power(1-p, k-1) * p

    def _logpmf(self, k, p):
        # xlog1py(k-1, -p) == (k-1)*log(1-p), but it is 0 when k == 1.  Plain
        # logs gave 0*log(0) = nan at the allowed boundary p == 1.
        return special.xlog1py(k - 1, -p) + log(p)

    def _cdf(self, x, p):
        k = floor(x)
        return -expm1(log1p(-p)*k)

    def _sf(self, x, p):
        return np.exp(self._logsf(x, p))

    def _logsf(self, x, p):
        k = floor(x)
        return k*log1p(-p)

    def _ppf(self, q, p):
        # log1p keeps the ratio finite when p (or q) is so small that 1 - p
        # rounds to 1; log(1 - p) would then divide by zero.
        vals = ceil(log1p(-q) / log1p(-p))
        temp = self._cdf(vals-1, p)
        return np.where((temp >= q) & (vals > 0), vals-1, vals)

    def _stats(self, p):
        mu = 1.0/p
        qr = 1.0-p
        var = qr / p / p
        g1 = (2.0-p) / sqrt(qr)
        g2 = np.polyval([1, -6, 6], p)/(1.0-p)
        return mu, var, g1, g2
geom = geom_gen(a=1, name='geom', longname="A geometric")


class hypergeom_gen(rv_discrete):
    """A hypergeometric discrete random variable.

    The hypergeometric distribution models drawing objects from a bin.
    M is the total number of objects, n is total number of Type I objects.
    The random variate represents the number of Type I objects in N drawn
    without replacement from the total population.

    %(before_notes)s

    Notes
    -----
    The probability mass function is defined as::

        pmf(k, M, n, N) = choose(n, k) * choose(M - n, N - k) / choose(M, N),
                                       for max(0, N - (M-n)) <= k <= min(n, N)

    Examples
    --------
    >>> from scipy.stats import hypergeom
    >>> import matplotlib.pyplot as plt

    Suppose we have a collection of 20 animals, of which 7 are dogs.  Then if
    we want to know the probability of finding a given number of dogs if we
    choose at random 12 of the 20 animals, we can initialize a frozen
    distribution and plot the probability mass function:

    >>> [M, n, N] = [20, 7, 12]
    >>> rv = hypergeom(M, n, N)
    >>> x = np.arange(0, n+1)
    >>> pmf_dogs = rv.pmf(x)

    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(111)
    >>> ax.plot(x, pmf_dogs, 'bo')
    >>> ax.vlines(x, 0, pmf_dogs, lw=2)
    >>> ax.set_xlabel('# of dogs in our group of chosen animals')
    >>> ax.set_ylabel('hypergeom PMF')
    >>> plt.show()

    Instead of using a frozen distribution we can also use `hypergeom`
    methods directly.  To for example obtain the cumulative distribution
    function, use:

    >>> prb = hypergeom.cdf(x, M, n, N)

    And to generate random numbers:

    >>> R = hypergeom.rvs(M, n, N, size=10)

    """
    def _rvs(self, M, n, N):
        return mtrand.hypergeometric(n, M-n, N, size=self._size)

    def _argcheck(self, M, n, N):
        cond = rv_discrete._argcheck(self, M, n, N)
        cond &= (n <= M) & (N <= M)
        # Support is max(0, N-(M-n)) .. min(n, N).  Element-wise NumPy
        # functions are needed because the shape arguments may be arrays;
        # the builtin max/min raise "truth value is ambiguous" on them.
        self.a = np.maximum(N-(M-n), 0)
        self.b = np.minimum(n, N)
        return cond

    def _logpmf(self, k, M, n, N):
        tot, good = M, n
        bad = tot - good
        return gamln(good+1) - gamln(good-k+1) - gamln(k+1) + gamln(bad+1) \
            - gamln(bad-N+k+1) - gamln(N-k+1) - gamln(tot+1) + gamln(tot-N+1) \
            + gamln(N+1)

    def _pmf(self, k, M, n, N):
        # same as the following but numerically more precise
        # return comb(good, k) * comb(bad, N-k) / comb(tot, N)
        return exp(self._logpmf(k, M, n, N))

    def _stats(self, M, n, N):
        # tot, good, sample_size = M, n, N
        # "wikipedia".replace('N', 'M').replace('n', 'N').replace('K', 'n')
        M, n, N = 1.*M, 1.*n, 1.*N
        m = M - n
        p = n/M
        mu = N*p

        var = m*n*N*(M - N)*1.0/(M*M*(M-1))
        g1 = (m - n)*(M-2*N) / (M-2.0) * sqrt((M-1.0) / (m*n*N*(M-N)))

        g2 = M*(M+1) - 6.*N*(M-N) - 6.*n*m
        g2 *= (M-1)*M*M
        g2 += 6.*n*N*(M-N)*m*(5.*M-6)
        g2 /= n * N * (M-N) * m * (M-2.) * (M-3.)
        return mu, var, g1, g2

    def _entropy(self, M, n, N):
        # Sum over the exact support.  Called with scalar shapes, so the
        # builtin max/min are fine here.  Starting at 0 rather than at a
        # possibly negative N-(M-n) skips points whose pmf is 0 anyway.
        k = np.r_[max(N - (M - n), 0):min(n, N) + 1]
        vals = self.pmf(k, M, n, N)
        h = -np.sum(special.xlogy(vals, vals), axis=0)
        return h

    def _sf(self, k, M, n, N):
        """More precise calculation, 1 - cdf doesn't cut it."""
        # This for loop is needed because `k` can be an array. If that's the
        # case, the sf() method makes M, n and N arrays of the same shape. We
        # therefore unpack all inputs args, so we can do the manual
        # integration.
        res = []
        for quant, tot, good, draw in zip(k, M, n, N):
            # Manual integration over probability mass function. More accurate
            # than integrate.quad.
            k2 = np.arange(quant + 1, draw + 1)
            res.append(np.sum(self._pmf(k2, tot, good, draw)))
        return np.asarray(res)
hypergeom = hypergeom_gen(name='hypergeom')


# FIXME: Fails _cdfvec
class logser_gen(rv_discrete):
    """A Logarithmic (Log-Series, Series) discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `logser` is::

        logser.pmf(k) = - p**k / (k*log(1-p))

    for ``k >= 1``.

    `logser` takes ``p`` as shape parameter.

    %(example)s

    """
    def _rvs(self, p):
        # looks wrong for p>0.5, too few k=1
        # trying to use generic is worse, no k=1 at all
        return mtrand.logseries(p, size=self._size)

    def _argcheck(self, p):
        return (p > 0) & (p < 1)

    def _pmf(self, k, p):
        # log1p(-p) == log(1 - p), but it stays accurate (and non-zero) when
        # p is so small that 1 - p rounds to 1.
        return -np.power(p, k) * 1.0 / k / log1p(-p)

    def _stats(self, p):
        # Raw moments mu_np are built from r = log(1 - p); log1p avoids
        # r == 0 (and hence inf/nan moments) for very small p.
        r = log1p(-p)
        mu = p / (p - 1.0) / r
        mu2p = -p / r / (p - 1.0)**2
        var = mu2p - mu*mu
        mu3p = -p / r * (1.0+p) / (1.0 - p)**3
        mu3 = mu3p - 3*mu*mu2p + 2*mu**3
        g1 = mu3 / np.power(var, 1.5)

        mu4p = -p / r * (
            1.0 / (p-1)**2 - 6*p / (p - 1)**3 + 6*p*p / (p-1)**4)
        mu4 = mu4p - 4*mu3p*mu + 6*mu2p*mu*mu - 3*mu**4
        g2 = mu4 / var**2 - 3.0
        return mu, var, g1, g2
logser = logser_gen(a=1, name='logser', longname='A logarithmic')


class poisson_gen(rv_discrete):
    """A Poisson discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `poisson` is::

        poisson.pmf(k) = exp(-mu) * mu**k / k!

    for ``k >= 0``.

    `poisson` takes ``mu`` as shape parameter.

    %(example)s

    """
    def _rvs(self, mu):
        size = self._size
        return mtrand.poisson(mu, size)

    def _logpmf(self, k, mu):
        # log(mu**k * exp(-mu) / k!), with log(k!) taken from gammaln.
        return k*log(mu)-gamln(k+1) - mu

    def _pmf(self, k, mu):
        log_mass = self._logpmf(k, mu)
        return exp(log_mass)

    def _cdf(self, x, mu):
        return special.pdtr(floor(x), mu)

    def _sf(self, x, mu):
        return special.pdtrc(floor(x), mu)

    def _ppf(self, q, mu):
        # pdtrik inverts the continuous extension of the cdf; step back by one
        # when the previous integer already reaches probability q.
        upper = ceil(special.pdtrik(q, mu))
        lower = np.maximum(upper - 1, 0)
        keep_lower = special.pdtr(lower, mu) >= q
        return np.where(keep_lower, lower, upper)

    def _stats(self, mu):
        # Mean and variance both equal mu.
        mu_arr = np.asarray(mu)
        skew = sqrt(1.0 / mu_arr)
        kurt = 1.0 / mu_arr
        return mu, mu, skew, kurt
poisson = poisson_gen(name="poisson", longname='A Poisson')


class planck_gen(rv_discrete):
    """A Planck discrete exponential random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `planck` is::

        planck.pmf(k) = (1-exp(-lambda_))*exp(-lambda_*k)

    for ``k >= 0`` and ``lambda_ > 0``.

    `planck` takes ``lambda_`` as shape parameter.

    %(example)s

    """
    def _argcheck(self, lambda_):
        # The pmf above is non-negative and sums to one only for lambda_ > 0.
        # The former lambda_ < 0 branch produced negative "probabilities",
        # e.g. (1-e)*exp(k), so it is rejected.  Returning a comparison (not
        # branching with `if`) also keeps this working for array arguments.
        self.a = 0
        self.b = np.inf
        return lambda_ > 0

    def _pmf(self, k, lambda_):
        # -expm1(-x) == 1 - exp(-x), but it stays accurate for small lambda_.
        fact = -expm1(-lambda_)
        return fact*exp(-lambda_*k)

    def _cdf(self, x, lambda_):
        k = floor(x)
        return -expm1(-lambda_*(k+1))

    def _ppf(self, q, lambda_):
        # Invert the closed-form cdf, then step back by one if the previous
        # integer already reaches probability q.
        vals = ceil(-1.0/lambda_ * log1p(-q)-1)
        vals1 = (vals-1).clip(self.a, np.inf)
        temp = self._cdf(vals1, lambda_)
        return np.where(temp >= q, vals1, vals)

    def _stats(self, lambda_):
        mu = 1/expm1(lambda_)
        var = exp(-lambda_)/(expm1(-lambda_))**2
        g1 = 2*cosh(lambda_/2.0)
        g2 = 4+2*cosh(lambda_)
        return mu, var, g1, g2

    def _entropy(self, lambda_):
        l = lambda_
        C = -expm1(-l)
        return l*exp(-l)/C - log(C)
planck = planck_gen(name='planck', longname='A discrete exponential ')


class boltzmann_gen(rv_discrete):
    """A Boltzmann (Truncated Discrete Exponential) random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `boltzmann` is::

        boltzmann.pmf(k) = (1-exp(-lambda_))*exp(-lambda_*k)/(1-exp(-lambda_*N))

    for ``k = 0,..., N-1``.

    `boltzmann` takes ``lambda_`` and ``N`` as shape parameters.

    %(example)s

    """
    def _argcheck(self, lambda_, N):
        # The distribution is truncated at N-1.  Without an explicit upper
        # bound the support defaulted to [0, inf), where the formulas below
        # give positive mass beyond N-1 and cdf values above 1.
        self.a = 0
        self.b = N - 1
        return (lambda_ > 0) & (N > 0)

    def _pmf(self, k, lambda_, N):
        fact = (1-exp(-lambda_))/(1-exp(-lambda_*N))
        return fact*exp(-lambda_*k)

    def _cdf(self, x, lambda_, N):
        k = floor(x)
        return (1-exp(-lambda_*(k+1)))/(1-exp(-lambda_*N))

    def _ppf(self, q, lambda_, N):
        # Undo the truncation normalisation, then invert the cdf of the
        # untruncated discrete exponential (log1p for accuracy at small
        # qnew).  Finally step back by one if the previous integer already
        # reaches probability q.
        qnew = q*(1-exp(-lambda_*N))
        vals = ceil(-1.0/lambda_ * log1p(-qnew)-1)
        vals1 = (vals-1).clip(0.0, np.inf)
        temp = self._cdf(vals1, lambda_, N)
        return np.where(temp >= q, vals1, vals)

    def _stats(self, lambda_, N):
        z = exp(-lambda_)
        zN = exp(-lambda_*N)
        mu = z/(1.0-z)-N*zN/(1-zN)
        var = z/(1.0-z)**2 - N*N*zN/(1-zN)**2
        trm = (1-zN)/(1-z)
        trm2 = (z*trm**2 - N*N*zN)
        g1 = z*(1+z)*trm**3 - N**3*zN*(1+zN)
        g1 = g1 / trm2**(1.5)
        g2 = z*(1+4*z+z*z)*trm**4 - N**4 * zN*(1+4*zN+zN*zN)
        g2 = g2 / trm2 / trm2
        return mu, var, g1, g2
boltzmann = boltzmann_gen(name='boltzmann',
        longname='A truncated discrete exponential ')


class randint_gen(rv_discrete):
    """A uniform discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `randint` is::

        randint.pmf(k) = 1./(high - low)

    for ``k = low, ..., high - 1``.

    `randint` takes ``low`` and ``high`` as shape parameters.

    Note the difference to the numpy ``random_integers`` which
    returns integers on a *closed* interval ``[low, high]``.

    %(example)s

    """
    def _argcheck(self, low, high):
        # Support is the half-open range [low, high), i.e. low .. high-1.
        self.a = low
        self.b = high - 1
        return high > low

    def _pmf(self, k, low, high):
        inside = (k >= low) & (k < high)
        mass = np.ones_like(k) / (high - low)
        return np.where(inside, mass, 0.)

    def _cdf(self, x, low, high):
        return (floor(x) - low + 1.) / (high - low)

    def _ppf(self, q, low, high):
        # Invert the linear cdf, then step back by one if the previous
        # integer already reaches probability q.
        span = high - low
        upper = ceil(q * span + low) - 1
        lower = (upper - 1).clip(low, high)
        use_lower = self._cdf(lower, low, high) >= q
        return np.where(use_lower, lower, upper)

    def _stats(self, low, high):
        hi, lo = np.asarray(high), np.asarray(low)
        width = hi - lo
        mean = (hi + lo - 1.0) / 2
        var = (width*width - 1) / 12.0
        kurt = -6.0/5.0 * (width*width + 1.0) / (width*width - 1.0)
        # The distribution is symmetric, so the skewness is zero.
        return mean, var, 0.0, kurt

    def _rvs(self, low, high=None):
        """Draw *size* integers uniformly from the half-open range [low, high).

        With ``high`` left as ``None`` the range is [0, low) instead.
        """
        return mtrand.randint(low, high, self._size)

    def _entropy(self, low, high):
        # Uniform over (high - low) points.
        return log(high - low)
randint = randint_gen(name='randint', longname='A discrete uniform '
                      '(random integer)')


# FIXME: problems sampling.
class zipf_gen(rv_discrete):
    """A Zipf discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `zipf` is::

        zipf.pmf(k, a) = 1/(zeta(a) * k**a)

    for ``k >= 1``.

    `zipf` takes ``a`` as shape parameter.

    %(example)s

    """
    def _rvs(self, a):
        size = self._size
        return mtrand.zipf(a, size=size)

    def _argcheck(self, a):
        # zeta(a) only converges for a > 1.
        return a > 1

    def _pmf(self, k, a):
        normaliser = special.zeta(a, 1)
        return 1.0 / normaliser / k**a

    def _munp(self, n, a):
        # E[X**n] = zeta(a - n) / zeta(a); the moment is infinite unless
        # a > n + 1.
        def _finite_moment(a, n):
            return special.zeta(a - n, 1) / special.zeta(a, 1)
        return _lazywhere(a > n + 1, (a, n), _finite_moment, np.inf)
zipf = zipf_gen(a=1, name='zipf', longname='A Zipf')


class dlaplace_gen(rv_discrete):
    """A  Laplacian discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `dlaplace` is::

        dlaplace.pmf(k) = tanh(a/2) * exp(-a*abs(k))

    for ``a > 0``.

    `dlaplace` takes ``a`` as shape parameter.

    %(example)s

    """
    def _pmf(self, k, a):
        scale = tanh(a/2.0)
        return scale * exp(-a * abs(k))

    def _cdf(self, x, a):
        k = floor(x)

        def _nonnegative(k, a):
            # P(X <= k) for k >= 0.
            return 1.0 - exp(-a * k) / (exp(a) + 1)

        def _negative(k, a):
            # P(X <= k) for k < 0.
            return exp(a * (k+1)) / (exp(a) + 1)

        return _lazywhere(k >= 0, (k, a), f=_nonnegative, f2=_negative)

    def _ppf(self, q, a):
        const = 1 + exp(a)
        # The threshold is cdf(0); below it, invert the negative-k branch.
        left_branch = q < 1.0 / (1 + exp(-a))
        vals = ceil(np.where(left_branch, log(q*const) / a - 1,
                             -log((1-q) * const) / a))
        prev = vals - 1
        return np.where(self._cdf(prev, a) >= q, prev, vals)

    def _stats(self, a):
        ea = exp(a)
        denom = ea-1.
        mu2 = 2.*ea/denom**2
        mu4 = 2.*ea*(ea**2+10.*ea+1.) / denom**4
        # Symmetric about zero: mean and skewness vanish.
        return 0., mu2, 0., mu4/mu2**2 - 3.

    def _entropy(self, a):
        half = a/2.0
        return a / sinh(a) - log(tanh(half))
dlaplace = dlaplace_gen(a=-np.inf,
                        name='dlaplace', longname='A discrete Laplacian')


class skellam_gen(rv_discrete):
    """A  Skellam discrete random variable.

    %(before_notes)s

    Notes
    -----
    Probability distribution of the difference of two correlated or
    uncorrelated Poisson random variables.

    Let k1 and k2 be two Poisson-distributed r.v. with expected values
    lam1 and lam2. Then, ``k1 - k2`` follows a Skellam distribution with
    parameters ``mu1 = lam1 - rho*sqrt(lam1*lam2)`` and
    ``mu2 = lam2 - rho*sqrt(lam1*lam2)``, where rho is the correlation
    coefficient between k1 and k2. If the two Poisson-distributed r.v.
    are independent then ``rho = 0``.

    Parameters mu1 and mu2 must be strictly positive.

    For details see: http://en.wikipedia.org/wiki/Skellam_distribution

    `skellam` takes ``mu1`` and ``mu2`` as shape parameters.

    %(example)s

    """
    def _rvs(self, mu1, mu2):
        # Difference of two independent Poisson draws; mu1 is drawn first.
        size = self._size
        plus = mtrand.poisson(mu1, size)
        minus = mtrand.poisson(mu2, size)
        return plus - minus

    def _pmf(self, x, mu1, mu2):
        # The pmf is expressed through the noncentral chi-square density,
        # with separate parameterisations for negative and non-negative x.
        is_negative = x < 0
        neg_branch = _ncx2_pdf(2*mu2, 2*(1-x), 2*mu1)*2
        pos_branch = _ncx2_pdf(2*mu1, 2*(1+x), 2*mu2)*2
        # ncx2.pdf() returns nan's for extremely low probabilities
        return np.where(is_negative, neg_branch, pos_branch)

    def _cdf(self, x, mu1, mu2):
        x = floor(x)
        is_negative = x < 0
        neg_branch = _ncx2_cdf(2*mu2, -2*x, 2*mu1)
        pos_branch = 1-_ncx2_cdf(2*mu1, 2*(x+1), 2*mu2)
        return np.where(is_negative, neg_branch, pos_branch)

    def _stats(self, mu1, mu2):
        mean = mu1 - mu2
        var = mu1 + mu2
        skew = mean / sqrt((var)**3)
        kurt = 1 / var
        return mean, var, skew, kurt
skellam = skellam_gen(a=-np.inf, name="skellam", longname='A Skellam')


# Collect names of classes and objects in this module.
# `pairs` is a list snapshot of every (name, object) binding defined in the
# module namespace at this point, i.e. after all classes/instances above.
pairs = list(globals().items())
# get_distribution_names is imported from _distn_infrastructure; presumably it
# separates rv_discrete instance names from generator-class names (the two
# returned lists) — confirm against its definition there.
_distn_names, _distn_gen_names = get_distribution_names(pairs, rv_discrete)

# Public API of the module: the concatenation of both name lists.
__all__ = _distn_names + _distn_gen_names
