File: test_constraints.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (127 lines) | stat: -rw-r--r-- 5,638 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Owner(s): ["module: distributions"]

import pytest

import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.testing._internal.common_cuda import TEST_CUDA


# Parametrization table for test_constraint.
#
# Each row pairs one (possibly batched) 2x2 matrix with the expected result of
# ``check(value).all()`` for, in order, constraints.symmetric,
# constraints.positive_semidefinite and constraints.positive_definite.
# The rows are expanded into one ``(constraint, expected, value)`` triple per
# constraint: matrix by matrix, and symmetric / PSD / PD within each matrix.
EXAMPLES = [
    (constraint, expected, value)
    for value, flags in [
        ([[2., 0], [2., 2]], (False, False, False)),
        ([[3., -5], [-5., 3]], (True, False, False)),
        ([[1., 2], [2., 4]], (True, True, False)),
        ([[[1., -2], [-2., 1]], [[2., 3], [3., 2]]], (True, False, False)),
        ([[[1., -2], [-2., 4]], [[1., -1], [-1., 1]]], (True, True, False)),
        ([[[4., 2], [2., 4]], [[3., -1], [-1., 3]]], (True, True, True)),
    ]
    for constraint, expected in zip(
        (constraints.symmetric, constraints.positive_semidefinite, constraints.positive_definite),
        flags,
    )
]

# Parametrization table shared by test_biject_to and test_transform_to.
#
# Each entry is ``(constraint, *constructor_args)``:
#   * a 1-tuple holds a ready-made constraint object, which build_constraint
#     returns unchanged;
#   * a longer tuple holds a constraint factory followed by the arguments
#     build_constraint passes to it. List arguments are converted to
#     double-precision tensors there; scalar arguments are passed through as-is.
# The row order determines the generated pytest parameter ids.
CONSTRAINTS = [
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    (constraints.greater_than, [-10., -2, 0, 2, 10]),
    (constraints.greater_than, 0),
    (constraints.greater_than, 2),
    (constraints.greater_than, -2),
    (constraints.greater_than_eq, 0),
    (constraints.greater_than_eq, 2),
    (constraints.greater_than_eq, -2),
    (constraints.less_than, [-10., -2, 0, 2, 10]),
    (constraints.less_than, 0),
    (constraints.less_than, 2),
    (constraints.less_than, -2),
    (constraints.unit_interval,),
    # (lower bounds, upper bounds); each lower bound is below its upper bound.
    (constraints.interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.interval, -2, -1),
    (constraints.interval, 1, 2),
    (constraints.half_open_interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.half_open_interval, -2, -1),
    (constraints.half_open_interval, 1, 2),
    (constraints.simplex,),
    # The tests feed this one an input whose last dim is 6 (D = 4), not 5.
    (constraints.corr_cholesky,),
    (constraints.lower_cholesky,),
]


def build_constraint(constraint_fn, args, is_cuda=False):
    """Instantiate a constraint from a parametrization entry.

    Args:
        constraint_fn: A ready-made constraint object (used when ``args`` is
            empty) or a constraint factory to be called with ``args``.
        args: Positional constructor arguments. Python lists are converted to
            double-precision tensors; any other value is passed through as-is.
        is_cuda: If True, list arguments are materialized on the CUDA device
            instead of the CPU.

    Returns:
        ``constraint_fn`` itself when ``args`` is empty, otherwise the result of
        ``constraint_fn(*converted_args)``.
    """
    if not args:
        return constraint_fn
    # torch.tensor with explicit dtype/device replaces the legacy
    # torch.DoubleTensor / torch.cuda.DoubleTensor constructors; torch.tensor()
    # is the recommended way to build a tensor from existing data.
    device = 'cuda' if is_cuda else 'cpu'
    converted = (
        torch.tensor(x, dtype=torch.double, device=device) if isinstance(x, list) else x
        for x in args
    )
    return constraint_fn(*converted)

@pytest.mark.parametrize('constraint_fn, result, value', EXAMPLES)
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])
def test_constraint(constraint_fn, result, value, is_cuda):
    """Check ``constraint_fn.check`` on ``value`` against the expected ``result``.

    ``value`` is a nested list of (possibly batched) matrices. The test passes
    when ``check(...).all()`` equals ``result``.
    """
    # torch.tensor with explicit dtype/device instead of the legacy
    # torch.DoubleTensor / torch.cuda.DoubleTensor constructors.
    device = 'cuda' if is_cuda else 'cpu'
    tensor = torch.tensor(value, dtype=torch.double, device=device)
    # bool() turns the 0-dim result into a plain Python bool for a clear diff.
    assert bool(constraint_fn.check(tensor).all()) == result


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [
    False,
    pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, reason='CUDA not found.')),
])
def test_biject_to(constraint_fn, args, is_cuda):
    """Round-trip a random input through ``biject_to(constraint)``.

    Skips when ``biject_to`` raises NotImplementedError for the constraint.
    Otherwise asserts that:
      * the transform reports itself as bijective,
      * its output satisfies the constraint,
      * ``inv`` recovers the original input, and
      * ``log_abs_det_jacobian`` has the input's leading (non-event) shape.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        bijection = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
    assert bijection.bijective, f"biject_to({constraint}) is not bijective"

    # corr_cholesky needs a last dim of D * (D - 1) / 2, which is 6 for D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    unconstrained = torch.randn(size, size, dtype=torch.double)
    if is_cuda:
        unconstrained = unconstrained.cuda()

    constrained = bijection(unconstrained)
    assert constraint.check(constrained).all(), '\n'.join([
        f"Failed to biject_to({constraint})",
        f"x = {unconstrained}",
        f"biject_to(...)(x) = {constrained}",
    ])
    recovered = bijection.inv(constrained)
    assert torch.allclose(unconstrained, recovered), f"Error in biject_to({constraint}) inverse"

    log_det = bijection.log_abs_det_jacobian(unconstrained, constrained)
    batch_ndim = unconstrained.dim() - bijection.domain.event_dim
    assert log_det.shape == unconstrained.shape[:batch_ndim]


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [
    False,
    pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, reason='CUDA not found.')),
])
def test_transform_to(constraint_fn, args, is_cuda):
    """Check that ``transform_to(constraint)`` maps a random input into the constraint.

    ``inv`` is only checked as a pseudoinverse: re-applying the transform to
    ``inv(y)`` must reproduce ``y``.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    transform = transform_to(constraint)

    # corr_cholesky needs a last dim of D * (D - 1) / 2, which is 6 for D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    unconstrained = torch.randn(size, size, dtype=torch.double)
    if is_cuda:
        unconstrained = unconstrained.cuda()

    constrained = transform(unconstrained)
    assert constraint.check(constrained).all(), f"Failed to transform_to({constraint})"
    round_trip = transform(transform.inv(constrained))
    assert torch.allclose(constrained, round_trip), \
        f"Error in transform_to({constraint}) pseudoinverse"


if __name__ == '__main__':
    # pytest.main() only returns the exit code; re-raise it so that running this
    # file as a script reports failures through the process exit status.
    raise SystemExit(pytest.main([__file__]))