# Owner(s): ["module: distributions"]

import io
from numbers import Number

import pytest

import torch
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
                                            CorrCholeskyTransform, CumulativeDistributionTransform,
                                            ExpTransform, IndependentTransform,
                                            LowerCholeskyTransform, PowerTransform,
                                            ReshapeTransform, SigmoidTransform, TanhTransform,
                                            SoftmaxTransform, SoftplusTransform, StickBreakingTransform,
                                            identity_transform, Transform, _InverseTransform)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix


def get_transforms(cache_size):
    """Build the transform instances that the parametrized tests iterate over.

    Args:
        cache_size: Passed as ``cache_size`` to every transform below that is
            constructed with an explicit cache size. The ``ComposeTransform``,
            ``ReshapeTransform``, ``IndependentTransform`` and
            ``CumulativeDistributionTransform`` wrappers are built without it.

    Returns:
        list: the forward transforms followed by the ``.inv`` of each of them,
        in matching order.
    """
    def random_affine(*shape):
        # Affine transform whose loc and scale are separate normal draws.
        loc = torch.randn(*shape)
        scale = torch.randn(*shape)
        return AffineTransform(loc, scale, cache_size=cache_size)

    def random_power():
        # normal_() overwrites the scalar in place, so 5. is only a placeholder.
        return PowerTransform(exponent=torch.tensor(5.).normal_(),
                              cache_size=cache_size)

    forward = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        random_power(),
        random_power(),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        random_affine(5),
        random_affine(4, 5),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([random_affine(4, 5)]),
        ComposeTransform([
            random_affine(4, 5),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            random_affine(4, 5),
            AffineTransform(1, -2, cache_size=cache_size),
            random_affine(4, 5),
        ]),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(random_affine(5), 1),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    inverses = [t.inv for t in forward]
    return forward + inverses


def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size)
        except RuntimeError:
            return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform


def transform_id(x):
    """Generate the pytest id for a transform parameter.

    An ``_InverseTransform`` is labelled ``Inv(<class of the wrapped
    transform>)``; any other transform is labelled with its own class name.
    The cache size is appended in both cases.
    """
    assert isinstance(x, Transform)
    if isinstance(x, _InverseTransform):
        label = f'Inv({type(x._inv).__name__})'
    else:
        label = f'{type(x).__name__}'
    return f'{label}(cache_size={x._cache_size})'


def generate_data(transform):
    """Produce a deterministic input tensor lying in the domain of ``transform``.

    The global RNG is reseeded with 1 on every call, so repeated calls with the
    same transform return tensors holding the same values.

    Args:
        transform: Transform to generate input for. ``IndependentTransform``
            wrappers are peeled off first.

    Returns:
        A tensor whose shape and value range depend on the transform's domain
        (or on the reshape shapes for ``ReshapeTransform`` and its inverse).

    Raises:
        ValueError: If the domain is not one of the handled constraints.
    """
    torch.manual_seed(1)
    while isinstance(transform, IndependentTransform):
        transform = transform.base_transform

    # Reshape transforms dictate the input shape themselves.
    if isinstance(transform, ReshapeTransform):
        return torch.randn(transform.in_shape)
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)

    # Strip ``independent`` wrappers, but keep real_vector intact since it has
    # a dedicated branch below.
    domain = transform.domain
    while (isinstance(domain, constraints.independent) and
           domain is not constraints.real_vector):
        domain = domain.base_constraint
    codomain = transform.codomain

    if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
        return torch.empty(6, 6).normal_()
    if domain is constraints.real:
        return torch.empty(4, 5).normal_()
    if domain is constraints.real_vector:
        # For corr_cholesky the trailing vector length must match a square
        # matrix size; presumably D * (D - 1) / 2, i.e. 6 for D = 4.
        return torch.empty(3, 6).normal_()
    if domain is constraints.positive:
        return torch.empty(4, 5).normal_().exp()
    if domain is constraints.unit_interval:
        return torch.empty(4, 5).uniform_()
    if isinstance(domain, constraints.interval):
        # Rescale uniform [0, 1) samples onto [lower_bound, upper_bound).
        sample = torch.empty(4, 5).uniform_()
        sample = sample.mul_(domain.upper_bound - domain.lower_bound)
        return sample.add_(domain.lower_bound)
    if domain is constraints.simplex:
        weights = torch.empty(4, 5).normal_().exp()
        weights /= weights.sum(-1, True)
        return weights
    if domain is constraints.corr_cholesky:
        # Lower-triangular, unit-norm rows, non-negative diagonal.
        chol = torch.empty(4, 5, 5).normal_().tril()
        chol /= chol.norm(dim=-1, keepdim=True)
        diag = chol.diagonal(dim1=-1)
        diag.copy_(diag.abs())
        return chol
    raise ValueError('Unsupported domain: {}'.format(domain))


# Two independent sets of transform instances: one built with cache_size=1 and
# one with cache_size=0. ALL_TRANSFORMS is both sets plus the identity transform.
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = [*TRANSFORMS_CACHE_ACTIVE, *TRANSFORMS_CACHE_INACTIVE, identity_transform]


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    """Inverting a transform twice must hand back the very same object.

    The id function belongs only on the ``parametrize`` decorator; the test
    function itself takes just the parametrized ``transform``.
    """
    assert transform.inv.inv is transform


@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    """Check ``==``/``!=`` over every pair of cache-inactive transforms.

    A transform must compare equal to itself and unequal to every other
    instance in the list. The identity transform must equal its own inverse.
    """
    same_instance = x is y
    if not same_instance:
        assert x != y
    else:
        assert x == y
    assert identity_transform == identity_transform.inv


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    """With ``cache_size == 1`` a repeated call on the same input returns the cached output.

    Transforms built without a cache are first switched over via
    ``with_cache(1)``. The test is skipped when the forward call is not
    implemented.
    """
    cached = transform.with_cache(1) if transform._cache_size == 0 else transform
    assert cached._cache_size == 1

    x = generate_data(cached).requires_grad_()
    try:
        first = cached(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    # Identity, not just equality: the second call must return the cached tensor.
    second = cached(x)
    assert second is first


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize('test_cached', [True, False])
def test_forward_inverse(transform, test_cached):
    """Round-trip data through ``transform`` and ``transform.inv``.

    Checks the output shapes against ``forward_shape``/``inverse_shape``. For
    bijective transforms ``inv(t(x))`` must reproduce ``x``; for the others only
    ``t(inv(t(x)))`` is required to reproduce ``t(x)``. With ``test_cached``
    false, the inverse is fed a clone of ``y`` so that the cache is bypassed.
    Skipped when the forward or the uncached inverse is not implemented.
    """
    x = generate_data(transform).requires_grad_()
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    assert y.shape == transform.forward_shape(x.shape)

    if not test_cached:
        try:
            x2 = transform.inv(y.clone())  # bypass cache
        except NotImplementedError:
            pytest.skip('Not implemented.')
    else:
        x2 = transform.inv(y)  # should be implemented at least by caching
    assert x2.shape == transform.inverse_shape(y.shape)
    y2 = transform(x2)

    if not transform.bijective:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
            '{} t(t.inv(t(-))) error'.format(transform),
            'x = {}'.format(x),
            'y = t(x) = {}'.format(y),
            'x2 = t.inv(y) = {}'.format(x2),
            'y2 = t(x2) = {}'.format(y2),
        ])
        return

    # verify function inverse
    assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
        '{} t.inv(t(-)) error'.format(transform),
        'x = {}'.format(x),
        'y = t(x) = {}'.format(y),
        'x2 = t.inv(y) = {}'.format(x2),
    ])


def test_compose_transform_shapes():
    """``event_dim`` of single transforms and of pairwise compositions."""
    elementwise = ExpTransform()
    vector = SoftmaxTransform()
    matrix = LowerCholeskyTransform()

    singles = ((elementwise, 0), (vector, 1), (matrix, 2))
    for single, expected_dim in singles:
        assert single.event_dim == expected_dim

    # The expected values below equal the larger event_dim of the two parts.
    pairs = (
        ([elementwise, vector], 1),
        ([elementwise, matrix], 2),
        ([vector, matrix], 2),
    )
    for parts, expected_dim in pairs:
        assert ComposeTransform(parts).event_dim == expected_dim


# Shared transforms and base distributions for the shape parametrization below.
transform0, transform1, transform2 = ExpTransform(), SoftmaxTransform(), LowerCholeskyTransform()
base_dist0 = Normal(loc=torch.zeros(4, 4), scale=torch.ones(4, 4))
base_dist1 = Dirichlet(concentration=torch.ones(4, 4))
base_dist2 = Normal(loc=torch.zeros(3, 4, 4), scale=torch.ones(3, 4, 4))


@pytest.mark.parametrize('batch_shape, event_shape, dist', [
    # Each row: expected batch_shape, expected event_shape, distribution under test.
    # Rows built on base_dist0.
    ((4, 4), (), base_dist0),
    ((4,), (4,), base_dist1),
    ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
    # Rows built on base_dist1.
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
    # Rows built on base_dist2.
    ((3, 4, 4), (), base_dist2),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
])
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    """Check batch/event shapes of (transformed) distributions.

    Asserts that ``dist.batch_shape`` and ``dist.event_shape`` match the
    expected values, then draws one ``rsample`` and evaluates ``log_prob`` on
    it. The test is skipped if ``log_prob`` raises ``NotImplementedError``.
    """
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    x = dist.rsample()
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip('Not implemented.')


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    """The traced forward pass must agree with eager execution.

    Skipped when tracing raises ``NotImplementedError``.
    """
    def eager_fn(inp):
        return transform(inp)

    example = generate_data(transform).requires_grad_()
    try:
        traced_fn = torch.jit.trace(eager_fn, (example,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Compare on a newly generated tensor.
    # NOTE: generate_data reseeds the RNG on every call, so this is a distinct
    # tensor object but it carries the same values as the trace input.
    fresh = generate_data(transform).requires_grad_()
    assert torch.allclose(eager_fn(fresh), traced_fn(fresh), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    """The traced inverse pass must agree with eager execution.

    Input data is generated for ``transform.inv``, i.e. it lies in the domain
    of the inverse. Skipped when tracing raises ``NotImplementedError``.
    """
    def eager_fn(inp):
        return transform.inv(inp)

    example = generate_data(transform.inv).requires_grad_()
    try:
        traced_fn = torch.jit.trace(eager_fn, (example,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Compare on a newly generated tensor.
    # NOTE: generate_data reseeds the RNG on every call, so this is a distinct
    # tensor object but it carries the same values as the trace input.
    fresh = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(eager_fn(fresh), traced_fn(fresh), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    """The traced ``log_abs_det_jacobian`` must agree with eager execution.

    Skipped when tracing raises ``NotImplementedError``.
    """
    def eager_fn(inp):
        out = transform(inp)
        return transform.log_abs_det_jacobian(inp, out)

    example = generate_data(transform).requires_grad_()
    try:
        traced_fn = torch.jit.trace(eager_fn, (example,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Compare on a newly generated tensor.
    # NOTE: generate_data reseeds the RNG on every call, so this is a distinct
    # tensor object but it carries the same values as the trace input.
    fresh = generate_data(transform).requires_grad_()
    assert torch.allclose(eager_fn(fresh), traced_fn(fresh), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    """Compare ``log_abs_det_jacobian`` against a Jacobian computed by autograd.

    The analytic value must have the batch shape of the input (input shape
    minus the domain's event dims) and must match, within ``atol=1e-5``:

    1. zeros for ``ReshapeTransform`` and its inverse;
    2. the log of the absolute Jacobian diagonal for transforms with
       ``event_dim == 0`` (off-diagonal entries are asserted to be zero);
    3. the log-abs-determinant of the per-batch Jacobian otherwise.

    Skipped when the forward call or ``log_abs_det_jacobian`` is not implemented.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    # Test shape
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    batch_ndims = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[batch_ndims:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
        # The expected value is zero for every batch element, so it has the
        # batch shape. (Indexing x.shape with an integer instead of slicing
        # would build a vector that only matches through broadcasting.)
        expected = x.new_zeros(target_shape)
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(lambda v: tril_matrix_to_vec(transform(v), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(lambda v: transform(vec_to_tril_matrix(v, diag=-1)),
                           tril_matrix_to_vec(x_, diag=-1))
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda v: transform(v)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
        # after reshaping the event dims (see above) to give a batched square matrix whose determinant
        # can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape)
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[..., :out_ndims]  # Remove extra zero-valued dims (for inverse stick-breaking).
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize("event_dims",
                         [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
                         ids=str)
def test_compose_affine(event_dims):
    """Composing affine transforms of different ``event_dim`` yields the largest one.

    The same expectation is checked on the ``support`` of a
    ``TransformedDistribution`` over a ``Normal`` base and over a ``Dirichlet``
    base (whose own event dim of 1 is the lower bound there).
    """
    transforms = []
    for e in event_dims:
        transforms.append(AffineTransform(torch.zeros((1,) * e), 1, event_dim=e))
    transform = ComposeTransform(transforms)
    largest = max(event_dims)
    assert transform.codomain.event_dim == largest
    assert transform.domain.event_dim == largest

    # Normal base: expanded so that it has enough dims for the transform domain.
    normal_base = Normal(0, 1)
    if transform.domain.event_dim:
        normal_base = normal_base.expand((1,) * transform.domain.event_dim)
    normal_dist = TransformedDistribution(normal_base, transform.parts)
    assert normal_dist.support.event_dim == largest

    # Dirichlet base: already supplies one event dim of its own.
    dirichlet_base = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        dirichlet_base = dirichlet_base.expand((1,) * (transform.domain.event_dim - 1))
    dirichlet_dist = TransformedDistribution(dirichlet_base, transforms)
    assert dirichlet_dist.support.event_dim == max(1, largest)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    """A chain of reshapes maps ``(3, 2)`` events to ``(2, 3)`` events.

    Verifies the composed transform's event dims, the shape of its output and
    the batch/event shapes of a ``TransformedDistribution`` built from the
    same chain.
    """
    reshapes = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    composed = ComposeTransform(reshapes)
    assert composed.codomain.event_dim == 2
    assert composed.domain.event_dim == 2

    in_event_shape, out_event_shape = (3, 2), (2, 3)
    data = torch.randn(batch_shape + in_event_shape)
    assert composed(data).shape == batch_shape + out_event_shape

    dist = TransformedDistribution(Normal(data, 1), reshapes)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == out_event_shape
    assert dist.support.event_dim == 2


@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
    """Exercise ``TransformedDistribution`` over a grid of base and transform shapes.

    Construction must raise ``ValueError`` when the base distribution has fewer
    dims than the transform's domain needs. Otherwise sample shapes, sample
    diversity and ``log_prob`` shapes (on full and partial samples) are checked.
    """
    shape = torch.Size([2, 3, 4, 5])
    base_ndims = base_batch_dim + base_event_dim
    base_dist = Normal(0, 1).expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)

    candidates = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
                  ReshapeTransform((4, 5), (20,)),
                  ReshapeTransform((3, 20), (6, 10))]
    transforms = candidates[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_ndims < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    distinct_values = set(x.reshape(-1).tolist())
    assert len(distinct_values) >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    full_log_prob = d.log_prob(x)
    assert full_log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples: drop leading dims until only
    # the event dims are left.
    y = x
    for _ in range(x.dim() - len(d.event_shape)):
        y = y[0]
    partial_log_prob = d.log_prob(y)
    assert partial_log_prob.shape == d.batch_shape


def test_save_load_transform():
    """A ``TransformedDistribution`` survives a ``torch.save``/``torch.load`` round trip.

    Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    """
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    try:
        # A whole distribution object is pickled here, not just tensors, so
        # full unpickling is requested explicitly. Releases that default to
        # weights_only=True would otherwise refuse to load it. The data was
        # written by this test a few lines above, so it is trusted.
        other = torch.load(stream, weights_only=False)
    except TypeError:
        # Releases whose torch.load predates the ``weights_only`` argument are
        # expected to reject it; rewind and load the original way.
        stream.seek(0)
        other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))


# Allow running this module directly: hand this file over to pytest.
if __name__ == '__main__':
    pytest.main([__file__])
