File: test_distributors.py

package info (click to toggle)
logilab-constraint 0.5.0-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd, stretch
  • size: 384 kB
  • ctags: 497
  • sloc: python: 2,902; xml: 165; makefile: 17
file content (150 lines) | stat: -rw-r--r-- 6,063 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# copyright 2002-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-constraint.
#
# logilab-constraint is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 2.1 of the License, or (at your
# option) any later version.
#
# logilab-constraint is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
# for more details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-constraint. If not, see <http://www.gnu.org/licenses/>.
"""Unit testing for constraint propagation module"""

import unittest
from logilab.common.testlib import TestCase, TestSuite
from logilab.constraint import distributors,fd


class AbstractDistributorTC(TestCase):
    """Shared test logic applied to every distributor implementation.

    Concrete subclasses override the following methods:
     * buildDistributor to create the distributor
     * distributionAssertions to check that distribution was OK

    The class name starts with 'Abstract' so that get_all_cases() leaves
    this base class out of the generated suite.
    """

    def setUp(self):
        """builds two sets of domains and the distributor under test"""
        self.variables = ('v1', 'v2', 'v3')
        # In both fixtures 'v1' holds a single value, 'v2' is the smallest
        # domain holding at least two values and 'v3' is the largest one.
        self.domains1 = {
            'v1': fd.FiniteDomain([1]),
            'v2': fd.FiniteDomain([2, 3]),
            'v3': fd.FiniteDomain([4, 5, 6, 7]),
        }
        self.domains2 = {
            'v1': fd.FiniteDomain([1]),
            'v2': fd.FiniteDomain([2, 3, 4, 5, 6]),
            'v3': fd.FiniteDomain([7, 8, 9, 10, 11, 12, 13]),
        }
        self.distributor = self.buildDistributor()

    def buildDistributor(self):
        """returns a distributor"""
        raise NotImplementedError

    def distributionAssertions(self, initial, distributed):
        """checks the distribution

        `initial` is the mapping of domains handed to distribute() and
        `distributed` is the value distribute() returned for it.
        """
        # The signature matches the call made by testDistribution, so a
        # subclass that forgets to override this method fails with
        # NotImplementedError instead of an unrelated TypeError.
        raise NotImplementedError

    def testFindSmallestDomain(self):
        """tests that the smallest domain is indeed the smallest one
        with at least 2 values inside"""
        dist = self.buildDistributor()
        self.assertEqual('v2', dist.findSmallestDomain(self.domains1))
        self.assertEqual('v2', dist.findSmallestDomain(self.domains2))

    def testFindLargestDomain(self):
        """tests that the largest domain is indeed the largest one"""
        dist = self.buildDistributor()
        self.assertEqual('v3', dist.findLargestDomain(self.domains1))
        self.assertEqual('v3', dist.findLargestDomain(self.domains2))

    def testSingleValueDomainNotDistributed(self):
        """tests that a domain of size 1 is not distributed"""
        for initial_domain in (self.domains1, self.domains2):
            # Distribute the fixture being iterated on, so that both sets
            # of domains are actually exercised.
            distributed_domains = self.distributor.distribute(initial_domain)
            for d in distributed_domains:
                self.assertEqual(d['v1'].size(), initial_domain['v1'].size())

    def testDistribution(self):
        """tests that the right domain is correctly distributed"""
        for initial_domain in (self.domains1, self.domains2):
            distributed_domains = self.distributor.distribute(initial_domain)
            self.distributionAssertions(initial_domain, distributed_domains)

class NaiveDistributorTC(AbstractDistributorTC):
    """Runs the shared distributor tests against NaiveDistributor."""

    def buildDistributor(self):
        """returns the NaiveDistributor under test"""
        return distributors.NaiveDistributor()

    def distributionAssertions(self, initial, distributed):
        """checks that 'v2' was split into a single value and the rest

        Uses the TestCase assertion methods rather than bare `assert`
        statements so that the checks still run under `python -O` and
        report both operands on failure.
        """
        self.assertEqual(len(distributed), 2)
        # Only 'v2' may be touched: 'v1' and 'v3' must keep their values.
        for d in distributed:
            for v in ('v1', 'v3'):
                self.assertEqual(d[v].getValues(), initial[v].getValues())
        self.assertEqual(distributed[0]['v2'].size(), 1)
        self.assertEqual(distributed[1]['v2'].size(),
                         initial['v2'].size() - 1)

class RandomizingDistributorTC(NaiveDistributorTC):
    """Runs the NaiveDistributorTC checks against RandomizingDistributor."""

    def buildDistributor(self):
        """returns the RandomizingDistributor under test"""
        distributor = distributors.RandomizingDistributor()
        return distributor


class DichotomyDistributorTC(AbstractDistributorTC):
    """Runs the shared distributor tests against DichotomyDistributor."""

    def buildDistributor(self):
        """returns the DichotomyDistributor under test"""
        return distributors.DichotomyDistributor()

    def distributionAssertions(self, initial, distributed):
        """checks that 'v2' was split in two parts losing no value

        Uses the TestCase assertion methods rather than bare `assert`
        statements so that the checks still run under `python -O` and
        report both operands on failure.
        """
        self.assertEqual(len(distributed), 2)
        # Only 'v2' may be touched: 'v1' and 'v3' must keep their values.
        for d in distributed:
            for v in ('v1', 'v3'):
                self.assertEqual(d[v].getValues(), initial[v].getValues())
        # The two parts together must be as large as the original domain.
        self.assertEqual(
            distributed[0]['v2'].size() + distributed[1]['v2'].size(),
            initial['v2'].size())

class SplitDistributorTC(AbstractDistributorTC):
    """Runs the shared distributor tests against SplitDistributor."""

    # Argument handed to SplitDistributor; the assertions below expect it
    # to cap the number of domains produced by one distribution.
    nb_subdomains = 4

    def buildDistributor(self):
        """returns the SplitDistributor under test"""
        return distributors.SplitDistributor(self.nb_subdomains)

    def distributionAssertions(self, initial, distributed):
        """checks that 'v2' was split in at most `nb_subdomains` parts

        Uses the TestCase assertion methods rather than bare `assert`
        statements so that the checks still run under `python -O` and
        report both operands on failure.
        """
        self.assertEqual(len(distributed),
                         min(self.nb_subdomains, initial['v2'].size()))
        # Only 'v2' may be touched: 'v1' and 'v3' must keep their values.
        for d in distributed:
            for v in ('v1', 'v3'):
                self.assertEqual(d[v].getValues(), initial[v].getValues())
        # sum() replaces reduce(operator.add, ...): `reduce` is no longer
        # a builtin on Python 3, while sum() works on every version.
        tot_size = sum(d['v2'].size() for d in distributed)
        self.assertEqual(tot_size, initial['v2'].size())

class EnumeratorDistributorTC(AbstractDistributorTC):
    """Runs the shared distributor tests against EnumeratorDistributor."""

    def buildDistributor(self):
        """returns the EnumeratorDistributor under test"""
        return distributors.EnumeratorDistributor()

    def distributionAssertions(self, initial, distributed):
        """checks that one single-valued domain was built per 'v2' value

        Uses the TestCase assertion methods rather than bare `assert`
        statements so that the checks still run under `python -O` and
        report both operands on failure.
        """
        self.assertEqual(len(distributed), initial['v2'].size())
        for d in distributed:
            # Only 'v2' may be touched: 'v1' and 'v3' keep their values.
            for v in ('v1', 'v3'):
                self.assertEqual(d[v].getValues(), initial[v].getValues())
            self.assertEqual(d['v2'].size(), 1)


def get_all_cases(module):
    """Return the concrete test case classes found in `module`.

    A module attribute is selected when it is a class deriving from
    TestCase, it is not the TestCase base class itself, and the name it is
    bound to in the module does not start with 'Abstract'.

    The result is a list sorted by class name.
    """
    # inspect.isclass() replaces the comparison against types.ClassType and
    # types.TypeType, which no longer exist on Python 3.
    import inspect
    all_cases = []
    for name in dir(module):
        if name.startswith('Abstract'):
            continue
        obj = getattr(module, name)
        if not inspect.isclass(obj) or not issubclass(obj, TestCase):
            continue
        if obj is TestCase:
            # The imported base class is visible in the module namespace
            # but is not a test case of its own.
            continue
        all_cases.append(obj)
    # Classes cannot be ordered on Python 3, so sort on their names.
    all_cases.sort(key=lambda case: case.__name__)
    return all_cases

def suite(cases=None):
    """Return a TestSuite holding the tests of every class in `cases`.

    When `cases` is None or empty, the test case classes are collected
    from this module with get_all_cases().

    Note that this configures the shared unittest.defaultTestLoader
    (method prefix and sorting), as it always did.
    """
    import sys
    if not cases:
        # Look the running module up in sys.modules instead of importing
        # it through a hard-coded name: this also works when the file is
        # run as a script (without executing it a second time) or when its
        # directory is not on sys.path.
        cases = get_all_cases(sys.modules[__name__])
    loader = unittest.defaultTestLoader
    loader.testMethodPrefix = 'test'
    loader.sortTestMethodsUsing = None  # disable sorting
    suites = [loader.loadTestsFromTestCase(tc) for tc in cases]
    return TestSuite(suites)

# When executed as a script, run the tests returned by suite().
if __name__ == '__main__':
    unittest.main(defaultTest="suite")