#
# Author:  Travis Oliphant  2002-2011 with contributions from
#          SciPy Developers 2004-2011
#
from __future__ import division, print_function, absolute_import

from scipy import special
from scipy.special import entr, gammaln as gamln
from scipy.misc import logsumexp
from scipy._lib._numpy_compat import broadcast_to

from numpy import floor, ceil, log, exp, sqrt, log1p, expm1, tanh, cosh, sinh

import numpy as np

from ._distn_infrastructure import (
        rv_discrete, _lazywhere, _ncx2_pdf, _ncx2_cdf, get_distribution_names)


class binom_gen(rv_discrete):
    """A binomial discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `binom` is::

       binom.pmf(k) = choose(n, k) * p**k * (1-p)**(n-k)

    for ``k`` in ``{0, 1,..., n}``.

    `binom` takes ``n`` and ``p`` as shape parameters.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, n, p):
        """Draw ``self._size`` binomial variates from the random state."""
        return self._random_state.binomial(n, p, self._size)

    def _argcheck(self, n, p):
        """Return the validity mask for ``(n, p)``; also sets ``self.b``."""
        # The documented support is {0, ..., n}; store n as the b attribute.
        self.b = n
        count_ok = n >= 0
        prob_ok = (p >= 0) & (p <= 1)
        return count_ok & prob_ok

    def _logpmf(self, x, n, p):
        """Log of the pmf at ``floor(x)``, assembled from log-gamma terms."""
        k = floor(x)
        log_denominator = gamln(k + 1) + gamln(n - k + 1)
        log_choose = gamln(n + 1) - log_denominator
        success_term = special.xlogy(k, p)
        failure_term = special.xlog1py(n - k, -p)
        return log_choose + success_term + failure_term

    def _pmf(self, x, n, p):
        """Probability mass, obtained by exponentiating `_logpmf`."""
        log_mass = self._logpmf(x, n, p)
        return exp(log_mass)

    def _cdf(self, x, n, p):
        """Cumulative distribution via ``special.bdtr`` at ``floor(x)``."""
        return special.bdtr(floor(x), n, p)

    def _sf(self, x, n, p):
        """Survival function via ``special.bdtrc`` at ``floor(x)``."""
        return special.bdtrc(floor(x), n, p)

    def _ppf(self, q, n, p):
        """Percent point function.

        ``special.bdtrik`` yields a real-valued solution; it is rounded up
        and the integer just below (never less than 0) is preferred whenever
        its cdf already reaches ``q``.
        """
        upper = ceil(special.bdtrik(q, n, p))
        lower = np.maximum(upper - 1, 0)
        reached = special.bdtr(lower, n, p) >= q
        return np.where(reached, lower, upper)

    def _stats(self, n, p, moments='mv'):
        """Return ``(mean, variance, skewness, excess kurtosis)``.

        Skewness and kurtosis are only computed when 's' / 'k' appear in
        `moments`; otherwise None is returned in their place.
        """
        q = 1.0 - p
        mean = n * p
        variance = n * p * q
        skew = (q - p) / sqrt(variance) if 's' in moments else None
        kurt = (1.0 - 6*p*q) / variance if 'k' in moments else None
        return mean, variance, skew, kurt

    def _entropy(self, n, p):
        """Entropy obtained by summing ``entr(pmf)`` over ``0, ..., n``."""
        support = np.r_[0:n + 1]
        return np.sum(entr(self._pmf(support, n, p)), axis=0)
# Module-level instance of `binom_gen`, constructed with the name 'binom'.
binom = binom_gen(name='binom')


class bernoulli_gen(binom_gen):
    """A Bernoulli discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `bernoulli` is::

       bernoulli.pmf(k) = 1-p  if k = 0
                        = p    if k = 1

    for ``k`` in ``{0, 1}``.

    `bernoulli` takes ``p`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    # Every method except `_argcheck` and `_entropy` delegates to the
    # binomial implementation with the number of trials fixed at 1.

    def _rvs(self, p):
        """Sample through the binomial sampler with a single trial."""
        single_trial = 1
        return binom_gen._rvs(self, single_trial, p)

    def _argcheck(self, p):
        """Return the mask of ``p`` values lying in ``[0, 1]``."""
        lower_ok = p >= 0
        upper_ok = p <= 1
        return lower_ok & upper_ok

    def _logpmf(self, x, p):
        """Log-pmf of `binom` with ``n = 1``."""
        single_trial = 1
        return binom._logpmf(x, single_trial, p)

    def _pmf(self, x, p):
        """Pmf of `binom` with ``n = 1``."""
        single_trial = 1
        return binom._pmf(x, single_trial, p)

    def _cdf(self, x, p):
        """Cdf of `binom` with ``n = 1``."""
        single_trial = 1
        return binom._cdf(x, single_trial, p)

    def _sf(self, x, p):
        """Survival function of `binom` with ``n = 1``."""
        single_trial = 1
        return binom._sf(x, single_trial, p)

    def _ppf(self, q, p):
        """Percent point function of `binom` with ``n = 1``."""
        single_trial = 1
        return binom._ppf(q, single_trial, p)

    def _stats(self, p):
        """Moments of `binom` with ``n = 1`` (default `moments` selection)."""
        single_trial = 1
        return binom._stats(single_trial, p)

    def _entropy(self, p):
        """Closed-form entropy: ``entr(p) + entr(1 - p)``."""
        failure = 1-p
        return entr(p) + entr(failure)
# Module-level instance of `bernoulli_gen`.  ``b=1`` is passed to the
# constructor; presumably the upper support bound (the pmf is documented for
# k in {0, 1}) -- confirm against rv_discrete.
bernoulli = bernoulli_gen(b=1, name='bernoulli')


class nbinom_gen(rv_discrete):
    """A negative binomial discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `nbinom` is::

         nbinom.pmf(k) = choose(k+n-1, n-1) * p**n * (1-p)**k

    for ``k >= 0``.

    `nbinom` takes ``n`` and ``p`` as shape parameters.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, n, p):
        """Draw ``self._size`` negative binomial variates."""
        return self._random_state.negative_binomial(n, p, self._size)

    def _argcheck(self, n, p):
        """Return the mask for ``n > 0`` and ``0 <= p <= 1``."""
        count_ok = n > 0
        prob_ok = (p >= 0) & (p <= 1)
        return count_ok & prob_ok

    def _pmf(self, x, n, p):
        """Probability mass, obtained by exponentiating `_logpmf`."""
        log_mass = self._logpmf(x, n, p)
        return exp(log_mass)

    def _logpmf(self, x, n, p):
        """Log-pmf assembled from log-gamma terms (``x`` is not floored)."""
        log_coeff = gamln(n+x) - gamln(x+1) - gamln(n)
        success_term = n*log(p)
        failure_term = special.xlog1py(x, -p)
        return log_coeff + success_term + failure_term

    def _cdf(self, x, n, p):
        """Cdf through the regularized incomplete beta function."""
        return special.betainc(n, floor(x) + 1, p)

    def _sf_skip(self, x, n, p):
        """Unused survival function; the ``_skip`` suffix keeps it inactive."""
        # skip because special.nbdtrc doesn't work for 0<n<1
        return special.nbdtrc(floor(x), n, p)

    def _ppf(self, q, n, p):
        """Percent point function.

        The real-valued ``special.nbdtrik`` solution is rounded up, and the
        integer just below (clipped at 0) wins when its cdf reaches ``q``.
        """
        upper = ceil(special.nbdtrik(q, n, p))
        lower = np.clip(upper - 1, 0.0, np.inf)
        reached = self._cdf(lower, n, p) >= q
        return np.where(reached, lower, upper)

    def _stats(self, n, p):
        """Return ``(mean, variance, skewness, excess kurtosis)``."""
        Q = 1.0 / p
        P = Q - 1.0
        mean = n*P
        variance = n*P*Q
        skew = (Q+P)/sqrt(variance)
        kurt = (1.0 + 6*P*Q) / variance
        return mean, variance, skew, kurt
# Module-level instance of `nbinom_gen`, constructed with the name 'nbinom'.
nbinom = nbinom_gen(name='nbinom')


class geom_gen(rv_discrete):
    """A geometric discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `geom` is::

        geom.pmf(k) = (1-p)**(k-1)*p

    for ``k >= 1``.

    `geom` takes ``p`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, p):
        """Draw ``self._size`` geometric variates from the random state."""
        return self._random_state.geometric(p, size=self._size)

    def _argcheck(self, p):
        """Return the mask of ``p`` values lying in ``[0, 1]``."""
        return (p <= 1) & (p >= 0)

    def _pmf(self, k, p):
        """Probability mass ``(1-p)**(k-1) * p``."""
        return np.power(1-p, k-1) * p

    def _logpmf(self, k, p):
        """Log of the pmf, using ``xlog1py`` for the ``(1-p)`` power."""
        return special.xlog1py(k - 1, -p) + log(p)

    def _cdf(self, x, p):
        """Cdf ``1 - (1-p)**floor(x)``, computed with expm1/log1p."""
        k = floor(x)
        return -expm1(log1p(-p)*k)

    def _sf(self, x, p):
        """Survival function, obtained by exponentiating `_logsf`."""
        return np.exp(self._logsf(x, p))

    def _logsf(self, x, p):
        """Log survival function ``floor(x) * log(1-p)``."""
        k = floor(x)
        return k*log1p(-p)

    def _ppf(self, q, p):
        """Percent point function.

        Solves ``1 - (1-p)**k = q`` for real ``k``, rounds up, and steps
        back by one when the smaller integer already reaches ``q``.
        """
        # log1p keeps full precision where log(1 - q) / log(1 - p) would
        # round 1 - q or 1 - p to exactly 1.0: a tiny q would yield 0
        # (outside the support k >= 1) and a tiny p a division by zero.
        vals = ceil(log1p(-q) / log1p(-p))
        temp = self._cdf(vals-1, p)
        return np.where((temp >= q) & (vals > 0), vals-1, vals)

    def _stats(self, p):
        """Return ``(mean, variance, skewness, excess kurtosis)``."""
        mu = 1.0/p
        qr = 1.0-p
        var = qr / p / p
        g1 = (2.0-p) / sqrt(qr)
        # (p**2 - 6*p + 6) / (1 - p)
        g2 = np.polyval([1, -6, 6], p)/(1.0-p)
        return mu, var, g1, g2
# Module-level instance of `geom_gen`.  ``a=1`` is passed to the constructor;
# presumably the lower support bound (the pmf is documented for k >= 1) --
# confirm against rv_discrete.
geom = geom_gen(a=1, name='geom', longname="A geometric")


class hypergeom_gen(rv_discrete):
    """A hypergeometric discrete random variable.

    The hypergeometric distribution models drawing objects from a bin.
    M is the total number of objects, n is total number of Type I objects.
    The random variate represents the number of Type I objects in N drawn
    without replacement from the total population.

    %(before_notes)s

    Notes
    -----
    The probability mass function is defined as::

        pmf(k, M, n, N) = choose(n, k) * choose(M - n, N - k) / choose(M, N),
                                       for max(0, N - (M-n)) <= k <= min(n, N)

    %(after_notes)s

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.stats import hypergeom
    >>> import matplotlib.pyplot as plt

    Suppose we have a collection of 20 animals, of which 7 are dogs.  Then if
    we want to know the probability of finding a given number of dogs if we
    choose at random 12 of the 20 animals, we can initialize a frozen
    distribution and plot the probability mass function:

    >>> [M, n, N] = [20, 7, 12]
    >>> rv = hypergeom(M, n, N)
    >>> x = np.arange(0, n+1)
    >>> pmf_dogs = rv.pmf(x)

    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(111)
    >>> ax.plot(x, pmf_dogs, 'bo')
    >>> ax.vlines(x, 0, pmf_dogs, lw=2)
    >>> ax.set_xlabel('# of dogs in our group of chosen animals')
    >>> ax.set_ylabel('hypergeom PMF')
    >>> plt.show()

    Instead of using a frozen distribution we can also use `hypergeom`
    methods directly.  For example, to obtain the cumulative distribution
    function, use:

    >>> prb = hypergeom.cdf(x, M, n, N)

    And to generate random numbers:

    >>> R = hypergeom.rvs(M, n, N, size=10)

    """
    def _rvs(self, M, n, N):
        """Draw variates; numpy expects (ngood, nbad, nsample)."""
        return self._random_state.hypergeometric(n, M-n, N, size=self._size)

    def _argcheck(self, M, n, N):
        """Validate ``(M, n, N)`` and set the support bounds ``a``/``b``."""
        cond = (M > 0) & (n >= 0) & (N >= 0)
        cond &= (n <= M) & (N <= M)
        # Support is max(0, N - (M-n)) <= k <= min(n, N).
        self.a = np.maximum(N-(M-n), 0)
        self.b = np.minimum(n, N)
        return cond

    def _logpmf(self, k, M, n, N):
        """Log of the pmf, written as a sum of log-gamma terms."""
        tot, good = M, n
        bad = tot - good
        return gamln(good+1) - gamln(good-k+1) - gamln(k+1) + gamln(bad+1) \
            - gamln(bad-N+k+1) - gamln(N-k+1) - gamln(tot+1) + gamln(tot-N+1) \
            + gamln(N+1)

    def _pmf(self, k, M, n, N):
        """Probability mass, obtained by exponentiating `_logpmf`."""
        # same as the following but numerically more precise
        # return comb(good, k) * comb(bad, N-k) / comb(tot, N)
        return exp(self._logpmf(k, M, n, N))

    def _stats(self, M, n, N):
        """Return ``(mean, variance, skewness, excess kurtosis)``."""
        # tot, good, sample_size = M, n, N
        # "wikipedia".replace('N', 'M').replace('n', 'N').replace('K', 'n')
        M, n, N = 1.*M, 1.*n, 1.*N
        m = M - n
        p = n/M
        mu = N*p

        var = m*n*N*(M - N)*1.0/(M*M*(M-1))
        g1 = (m - n)*(M-2*N) / (M-2.0) * sqrt((M-1.0) / (m*n*N*(M-N)))

        g2 = M*(M+1) - 6.*N*(M-N) - 6.*n*m
        g2 *= (M-1)*M*M
        g2 += 6.*n*N*(M-N)*m*(5.*M-6)
        g2 /= n * N * (M-N) * m * (M-2.) * (M-3.)
        return mu, var, g1, g2

    def _entropy(self, M, n, N):
        """Entropy obtained by summing ``entr(pmf)`` over the support."""
        # Start at the documented lower support bound rather than at a
        # possibly negative N - (M - n); pmf is zero below the support.
        k = np.r_[max(0, N - (M - n)):min(n, N) + 1]
        vals = self.pmf(k, M, n, N)
        return np.sum(entr(vals), axis=0)

    def _sf(self, k, M, n, N):
        """More precise calculation, 1 - cdf doesn't cut it.

        Sums the pmf over the integers strictly greater than ``k`` up to
        ``N``.  Inputs are broadcast against each other, so scalars are
        accepted; the result has the broadcast shape (at least 1-d).
        """
        # `k` can be an array, in which case M, n and N are arrays too.
        # Unpack element by element so the manual summation can be done
        # with a per-element range.
        k, M, n, N = np.broadcast_arrays(*np.atleast_1d(k, M, n, N))
        res = []
        for quant, tot, good, draw in zip(k.ravel(), M.ravel(),
                                          n.ravel(), N.ravel()):
            # Floor first: P(X > k) only depends on floor(k), and the pmf
            # must be evaluated at integers.  Manual summation is more
            # accurate than integrate.quad.
            k2 = np.arange(floor(quant) + 1, draw + 1)
            res.append(np.sum(self._pmf(k2, tot, good, draw)))
        return np.asarray(res).reshape(k.shape)

    def _logsf(self, k, M, n, N):
        """
        More precise calculation than log(sf)

        Combines the log-pmf over the integers strictly greater than ``k``
        up to ``N`` with logsumexp.  An empty range yields ``-inf``.
        """
        k, M, n, N = np.broadcast_arrays(*np.atleast_1d(k, M, n, N))
        res = []
        for quant, tot, good, draw in zip(k.ravel(), M.ravel(),
                                          n.ravel(), N.ravel()):
            # Integration over probability mass function using logsumexp
            k2 = np.arange(floor(quant) + 1, draw + 1)
            if k2.size == 0:
                # Nothing above k: sf is 0, so its log is -inf.
                res.append(-np.inf)
            else:
                res.append(logsumexp(self._logpmf(k2, tot, good, draw)))
        return np.asarray(res).reshape(k.shape)
# Module-level instance of `hypergeom_gen`, constructed with the name
# 'hypergeom'.
hypergeom = hypergeom_gen(name='hypergeom')


# FIXME: Fails _cdfvec
class logser_gen(rv_discrete):
    """A Logarithmic (Log-Series, Series) discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `logser` is::

        logser.pmf(k) = - p**k / (k*log(1-p))

    for ``k >= 1``.

    `logser` takes ``p`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, p):
        """Draw ``self._size`` log-series variates from the random state."""
        # looks wrong for p>0.5, too few k=1
        # trying to use generic is worse, no k=1 at all
        return self._random_state.logseries(p, size=self._size)

    def _argcheck(self, p):
        """Return the mask of ``p`` values strictly inside ``(0, 1)``."""
        return (p > 0) & (p < 1)

    def _pmf(self, k, p):
        """Probability mass ``-p**k / (k * log(1-p))``."""
        # log1p(-p) stays accurate for small p, where log(1 - p) rounds
        # 1 - p to 1.0 and the division blows up.
        return -np.power(p, k) * 1.0 / k / log1p(-p)

    def _stats(self, p):
        """Return ``(mean, variance, skewness, excess kurtosis)``.

        Central moments are derived from the raw moments ``mu``, ``mu2p``,
        ``mu3p`` and ``mu4p``.
        """
        # Same precision argument as in _pmf: use log1p for log(1 - p).
        r = log1p(-p)
        mu = p / (p - 1.0) / r
        mu2p = -p / r / (p - 1.0)**2
        var = mu2p - mu*mu
        mu3p = -p / r * (1.0+p) / (1.0 - p)**3
        mu3 = mu3p - 3*mu*mu2p + 2*mu**3
        g1 = mu3 / np.power(var, 1.5)

        mu4p = -p / r * (
            1.0 / (p-1)**2 - 6*p / (p - 1)**3 + 6*p*p / (p-1)**4)
        mu4 = mu4p - 4*mu3p*mu + 6*mu2p*mu*mu - 3*mu**4
        g2 = mu4 / var**2 - 3.0
        return mu, var, g1, g2
# Module-level instance of `logser_gen`.  ``a=1`` is passed to the
# constructor; presumably the lower support bound (the pmf is documented for
# k >= 1) -- confirm against rv_discrete.
logser = logser_gen(a=1, name='logser', longname='A logarithmic')


class poisson_gen(rv_discrete):
    """A Poisson discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `poisson` is::

        poisson.pmf(k) = exp(-mu) * mu**k / k!

    for ``k >= 0``.

    `poisson` takes ``mu`` as shape parameter.

    %(after_notes)s

    %(example)s

    """

    # Override rv_discrete._argcheck to allow mu=0.
    def _argcheck(self, mu):
        """Accept any non-negative ``mu`` (zero included)."""
        return mu >= 0

    def _rvs(self, mu):
        """Draw ``self._size`` Poisson variates from the random state."""
        return self._random_state.poisson(mu, self._size)

    def _logpmf(self, k, mu):
        """Log-pmf ``k*log(mu) - log(k!) - mu`` via ``xlogy``/``gamln``."""
        return special.xlogy(k, mu) - gamln(k + 1) - mu

    def _pmf(self, k, mu):
        """Probability mass, obtained by exponentiating `_logpmf`."""
        log_mass = self._logpmf(k, mu)
        return exp(log_mass)

    def _cdf(self, x, mu):
        """Cumulative distribution via ``special.pdtr`` at ``floor(x)``."""
        return special.pdtr(floor(x), mu)

    def _sf(self, x, mu):
        """Survival function via ``special.pdtrc`` at ``floor(x)``."""
        return special.pdtrc(floor(x), mu)

    def _ppf(self, q, mu):
        """Percent point function.

        The real-valued ``special.pdtrik`` solution is rounded up, and the
        integer just below (never less than 0) wins when its cdf reaches
        ``q``.
        """
        upper = ceil(special.pdtrik(q, mu))
        lower = np.maximum(upper - 1, 0)
        reached = special.pdtr(lower, mu) >= q
        return np.where(reached, lower, upper)

    def _stats(self, mu):
        """Return ``(mean, variance, skewness, excess kurtosis)``.

        Skewness and kurtosis are reported as ``inf`` wherever ``mu`` is
        not positive, so the reciprocals are never evaluated there.
        """
        rate = np.asarray(mu)
        positive = rate > 0

        def inv_sqrt(x):
            return sqrt(1.0/x)

        def inverse(x):
            return 1.0/x

        skew = _lazywhere(positive, (rate,), inv_sqrt, np.inf)
        kurt = _lazywhere(positive, (rate,), inverse, np.inf)
        return mu, mu, skew, kurt

# Module-level instance of `poisson_gen`, constructed with the name
# "poisson".
poisson = poisson_gen(name="poisson", longname='A Poisson')


class planck_gen(rv_discrete):
    """A Planck discrete exponential random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `planck` is::

        planck.pmf(k) = (1-exp(-lambda_))*exp(-lambda_*k)

    for ``k*lambda_ >= 0``.

    `planck` takes ``lambda_`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    def _argcheck(self, lambda_):
        """Accept any non-zero ``lambda_`` and set the support bounds.

        A positive ``lambda_`` gives ``a = 0, b = inf``; otherwise
        ``a = -inf, b = 0``.
        """
        self.a = np.where(lambda_ > 0, 0, -np.inf)
        self.b = np.where(lambda_ > 0, np.inf, 0)
        return lambda_ != 0

    def _pmf(self, k, lambda_):
        """Probability mass ``(1-exp(-lambda_)) * exp(-lambda_*k)``."""
        # -expm1(-x) equals 1 - exp(-x) but does not cancel for small x.
        fact = -expm1(-lambda_)
        return fact*exp(-lambda_*k)

    def _cdf(self, x, lambda_):
        """Cdf ``1 - exp(-lambda_*(floor(x)+1))``, computed with expm1."""
        k = floor(x)
        return -expm1(-lambda_*(k+1))

    def _ppf(self, q, lambda_):
        """Percent point function.

        Inverts the cdf for real ``k``, rounds up, and prefers the integer
        just below (clipped at ``self.a``) when its cdf reaches ``q``.
        """
        vals = ceil(-1.0/lambda_ * log1p(-q)-1)
        vals1 = (vals-1).clip(self.a, np.inf)
        temp = self._cdf(vals1, lambda_)
        return np.where(temp >= q, vals1, vals)

    def _stats(self, lambda_):
        """Return the tuple ``(mu, var, g1, g2)`` of moment statistics."""
        # expm1(x) equals exp(x) - 1 without cancellation for small x.
        mu = 1/expm1(lambda_)
        var = exp(-lambda_)/(expm1(-lambda_))**2
        g1 = 2*cosh(lambda_/2.0)
        g2 = 4+2*cosh(lambda_)
        return mu, var, g1, g2

    def _entropy(self, lambda_):
        """Closed-form entropy ``lam*exp(-lam)/C - log(C)``.

        Here ``C = 1 - exp(-lam)``, evaluated with expm1.
        """
        lam = lambda_
        C = -expm1(-lam)
        return lam*exp(-lam)/C - log(C)
# Module-level instance of `planck_gen`, constructed with the name 'planck'.
planck = planck_gen(name='planck', longname='A discrete exponential ')


class boltzmann_gen(rv_discrete):
    """A Boltzmann (Truncated Discrete Exponential) random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `boltzmann` is::

        boltzmann.pmf(k) = (1-exp(-lambda_))*exp(-lambda_*k)/(1-exp(-lambda_*N))

    for ``k = 0,..., N-1``.

    `boltzmann` takes ``lambda_`` and ``N`` as shape parameters.

    %(after_notes)s

    %(example)s

    """
    def _pmf(self, k, lambda_, N):
        """Probability mass of the truncated discrete exponential."""
        # Ratio of expm1 terms equals (1-exp(-l))/(1-exp(-l*N)) but stays
        # finite for small lambda_, where both differences round to zero.
        fact = expm1(-lambda_)/expm1(-lambda_*N)
        return fact*exp(-lambda_*k)

    def _cdf(self, x, lambda_, N):
        """Cdf ``(1-exp(-l*(floor(x)+1))) / (1-exp(-l*N))`` via expm1."""
        k = floor(x)
        return expm1(-lambda_*(k+1))/expm1(-lambda_*N)

    def _ppf(self, q, lambda_, N):
        """Percent point function.

        Inverts the cdf for real ``k``, rounds up, and prefers the integer
        just below (clipped at 0) when its cdf reaches ``q``.
        """
        # qnew = q * (1 - exp(-lambda_*N)), written with expm1.
        qnew = -q*expm1(-lambda_*N)
        vals = ceil(-1.0/lambda_ * log1p(-qnew)-1)
        vals1 = (vals-1).clip(0.0, np.inf)
        temp = self._cdf(vals1, lambda_, N)
        return np.where(temp >= q, vals1, vals)

    def _stats(self, lambda_, N):
        """Return the tuple ``(mu, var, g1, g2)`` of moment statistics."""
        z = exp(-lambda_)
        zN = exp(-lambda_*N)
        mu = z/(1.0-z)-N*zN/(1-zN)
        var = z/(1.0-z)**2 - N*N*zN/(1-zN)**2
        trm = (1-zN)/(1-z)
        trm2 = (z*trm**2 - N*N*zN)
        g1 = z*(1+z)*trm**3 - N**3*zN*(1+zN)
        g1 = g1 / trm2**(1.5)
        g2 = z*(1+4*z+z*z)*trm**4 - N**4 * zN*(1+4*zN+zN*zN)
        g2 = g2 / trm2 / trm2
        return mu, var, g1, g2
# Module-level instance of `boltzmann_gen`, constructed with the name
# 'boltzmann'.
boltzmann = boltzmann_gen(name='boltzmann',
        longname='A truncated discrete exponential ')


class randint_gen(rv_discrete):
    """A uniform discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `randint` is::

        randint.pmf(k) = 1./(high - low)

    for ``k = low, ..., high - 1``.

    `randint` takes ``low`` and ``high`` as shape parameters.

    %(after_notes)s

    %(example)s

    """
    def _argcheck(self, low, high):
        """Require ``high > low``; also set the bounds ``a`` and ``b``."""
        # Support runs from low to high - 1 inclusive.
        self.a = low
        self.b = high - 1
        return (high > low)

    def _pmf(self, k, low, high):
        """Constant mass ``1/(high - low)`` inside the support, else 0."""
        inside = (k >= low) & (k < high)
        mass = np.ones_like(k) / (high - low)
        return np.where(inside, mass, 0.)

    def _cdf(self, x, low, high):
        """Cdf ``(floor(x) - low + 1) / (high - low)``."""
        width = high - low
        return (floor(x) - low + 1.) / width

    def _ppf(self, q, low, high):
        """Percent point function.

        A candidate is computed from the inverted cdf; the integer just
        below it (clipped to ``[low, high]``) is preferred when its cdf
        already reaches ``q``.
        """
        candidate = ceil(q * (high - low) + low) - 1
        below = np.clip(candidate - 1, low, high)
        reached = self._cdf(below, low, high) >= q
        return np.where(reached, below, candidate)

    def _stats(self, low, high):
        """Return ``(mean, variance, skewness, excess kurtosis)``."""
        upper = np.asarray(high)
        lower = np.asarray(low)
        span = upper - lower
        mean = (upper + lower - 1.0) / 2
        variance = (span*span - 1) / 12.0
        skew = 0.0
        kurt = -6.0/5.0 * (span*span + 1.0) / (span*span - 1.0)
        return mean, variance, skew, kurt

    def _rvs(self, low, high):
        """An array of *size* random integers >= ``low`` and < ``high``."""
        size = self._size
        if size is not None:
            # Numpy's RandomState.randint() doesn't broadcast its arguments,
            # so low and high are expanded to the requested size here and
            # the vectorized sampler is called without a `size` argument.
            low = broadcast_to(low, size)
            high = broadcast_to(high, size)
        draw = np.vectorize(self._random_state.randint, otypes=[np.int_])
        return draw(low, high)

    def _entropy(self, low, high):
        """Entropy ``log(high - low)``."""
        width = high - low
        return log(width)

# Module-level instance of `randint_gen`, constructed with the name
# 'randint'.
randint = randint_gen(name='randint', longname='A discrete uniform '
                      '(random integer)')


# FIXME: problems sampling.
class zipf_gen(rv_discrete):
    """A Zipf discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `zipf` is::

        zipf.pmf(k, a) = 1/(zeta(a) * k**a)

    for ``k >= 1``.

    `zipf` takes ``a`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, a):
        """Draw ``self._size`` Zipf variates from the random state."""
        return self._random_state.zipf(a, size=self._size)

    def _argcheck(self, a):
        """Require ``a > 1``."""
        return a > 1

    def _pmf(self, k, a):
        """Probability mass ``1 / (zeta(a) * k**a)``."""
        normalizer = special.zeta(a, 1)
        return 1.0 / normalizer / k**a

    def _munp(self, n, a):
        """Raw moment of order ``n``: ``zeta(a-n)/zeta(a)`` or ``inf``.

        The zeta ratio is only evaluated where ``a > n + 1``; everywhere
        else the fill value ``inf`` is returned.
        """
        def zeta_ratio(a, n):
            return special.zeta(a - n, 1) / special.zeta(a, 1)

        return _lazywhere(a > n + 1, (a, n), zeta_ratio, np.inf)
# Module-level instance of `zipf_gen`.  ``a=1`` is passed to the constructor;
# presumably the lower support bound (the pmf is documented for k >= 1) --
# confirm against rv_discrete.
zipf = zipf_gen(a=1, name='zipf', longname='A Zipf')


class dlaplace_gen(rv_discrete):
    """A Laplacian discrete random variable.

    %(before_notes)s

    Notes
    -----
    The probability mass function for `dlaplace` is::

        dlaplace.pmf(k) = tanh(a/2) * exp(-a*abs(k))

    for ``a > 0``.

    `dlaplace` takes ``a`` as shape parameter.

    %(after_notes)s

    %(example)s

    """
    def _pmf(self, k, a):
        """Probability mass ``tanh(a/2) * exp(-a*abs(k))``."""
        normalizer = tanh(a/2.0)
        return normalizer * exp(-a * abs(k))

    def _cdf(self, x, a):
        """Cdf at ``floor(x)``, with separate formulas for each sign."""
        k = floor(x)

        def nonnegative_branch(k, a):
            return 1.0 - exp(-a * k) / (exp(a) + 1)

        def negative_branch(k, a):
            return exp(a * (k+1)) / (exp(a) + 1)

        return _lazywhere(k >= 0, (k, a),
                          f=nonnegative_branch, f2=negative_branch)

    def _ppf(self, q, a):
        """Percent point function.

        The cdf is inverted with the formula matching the side ``q`` falls
        on, rounded up, and the integer just below is preferred when its
        cdf already reaches ``q``.
        """
        const = 1 + exp(a)
        threshold = 1.0 / (1 + exp(-a))
        lower_side = log(q*const) / a - 1
        upper_side = -log((1-q) * const) / a
        upper = ceil(np.where(q < threshold, lower_side, upper_side))
        lower = upper - 1
        reached = self._cdf(lower, a) >= q
        return np.where(reached, lower, upper)

    def _stats(self, a):
        """Return ``(mean, variance, skewness, excess kurtosis)``.

        Mean and skewness are returned as 0.
        """
        ea = exp(a)
        second = 2.*ea/(ea-1.)**2
        fourth = 2.*ea*(ea**2+10.*ea+1.) / (ea-1.)**4
        kurt = fourth/second**2 - 3.
        return 0., second, 0., kurt

    def _entropy(self, a):
        """Closed-form entropy ``a/sinh(a) - log(tanh(a/2))``."""
        half_tanh = tanh(a/2.0)
        return a / sinh(a) - log(half_tanh)
# Module-level instance of `dlaplace_gen`.  ``a=-np.inf`` is passed to the
# constructor; presumably the lower support bound, since the pmf depends on
# abs(k) -- confirm against rv_discrete.
dlaplace = dlaplace_gen(a=-np.inf,
                        name='dlaplace', longname='A discrete Laplacian')


class skellam_gen(rv_discrete):
    """A Skellam discrete random variable.

    %(before_notes)s

    Notes
    -----
    Probability distribution of the difference of two correlated or
    uncorrelated Poisson random variables.

    Let k1 and k2 be two Poisson-distributed r.v. with expected values
    lam1 and lam2. Then, ``k1 - k2`` follows a Skellam distribution with
    parameters ``mu1 = lam1 - rho*sqrt(lam1*lam2)`` and
    ``mu2 = lam2 - rho*sqrt(lam1*lam2)``, where rho is the correlation
    coefficient between k1 and k2. If the two Poisson-distributed r.v.
    are independent then ``rho = 0``.

    Parameters mu1 and mu2 must be strictly positive.

    For details see: http://en.wikipedia.org/wiki/Skellam_distribution

    `skellam` takes ``mu1`` and ``mu2`` as shape parameters.

    %(after_notes)s

    %(example)s

    """
    def _rvs(self, mu1, mu2):
        """Sample as the difference of two Poisson draws (mu1 first)."""
        size = self._size
        first = self._random_state.poisson(mu1, size)
        second = self._random_state.poisson(mu2, size)
        return first - second

    def _pmf(self, x, mu1, mu2):
        """Probability mass expressed through `_ncx2_pdf`.

        The argument order of `_ncx2_pdf` is swapped depending on the sign
        of ``x``.
        """
        negative_side = _ncx2_pdf(2*mu2, 2*(1-x), 2*mu1)*2
        nonnegative_side = _ncx2_pdf(2*mu1, 2*(1+x), 2*mu2)*2
        # ncx2.pdf() returns nan's for extremely low probabilities
        return np.where(x < 0, negative_side, nonnegative_side)

    def _cdf(self, x, mu1, mu2):
        """Cdf at ``floor(x)`` expressed through `_ncx2_cdf`."""
        k = floor(x)
        negative_side = _ncx2_cdf(2*mu2, -2*k, 2*mu1)
        nonnegative_side = 1-_ncx2_cdf(2*mu1, 2*(k+1), 2*mu2)
        return np.where(k < 0, negative_side, nonnegative_side)

    def _stats(self, mu1, mu2):
        """Return ``(mean, variance, skewness, excess kurtosis)``."""
        mean = mu1 - mu2
        variance = mu1 + mu2
        skew = mean / sqrt((variance)**3)
        kurt = 1 / variance
        return mean, variance, skew, kurt
# Module-level instance of `skellam_gen`.  ``a=-np.inf`` is passed to the
# constructor; presumably the lower support bound, since the variate is a
# difference of two Poisson variables -- confirm against rv_discrete.
skellam = skellam_gen(a=-np.inf, name="skellam", longname='A Skellam')


# Collect names of classes and objects in this module.
# A snapshot of the module globals is taken first, so the names bound below
# are not part of the (name, object) pairs handed to the helper.
pairs = list(globals().items())
# NOTE(review): get_distribution_names presumably selects the rv_discrete
# instances and their generator classes from `pairs` -- confirm in
# _distn_infrastructure.
_distn_names, _distn_gen_names = get_distribution_names(pairs, rv_discrete)

# Public API: the concatenation of the two name lists returned above.
__all__ = _distn_names + _distn_gen_names
