File: test_abc_polymorphic.py

package info (click to toggle)
sqlalchemy 1.2.18+ds1-2
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 16,080 kB
  • sloc: python: 239,496; ansic: 1,345; makefile: 264; xml: 17
file content (126 lines) | stat: -rw-r--r-- 3,953 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import function_named


class ABCTest(fixtures.MappedTest):
    """Round-trip a three-level joined-table inheritance chain A <- B <- C.

    Each generated test maps ``A`` to table ``a``, ``B`` to ``b`` and ``C``
    to ``c``, persists a mix of all three classes, and then checks that
    querying each level returns that class's rows plus its subclasses'
    rows, loaded as instances of the correct subclass.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Define tables ``a``, ``b`` and ``c`` on *metadata*.

        ``b.id`` is a foreign key to ``a.id`` and ``c.id`` a foreign key to
        ``b.id``; each is also that table's primary key. ``a.type`` is the
        column used as the polymorphic discriminator by the mappers below.
        """
        # The test functions are created by _make_test while the class body
        # executes, before any table exists; they read the tables through
        # these module-level names when they run.
        global a, b, c
        a = Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("adata", String(30)),
            Column("type", String(30)),
        )
        b = Table(
            "b",
            metadata,
            Column("id", Integer, ForeignKey("a.id"), primary_key=True),
            Column("bdata", String(30)),
        )
        c = Table(
            "c",
            metadata,
            Column("id", Integer, ForeignKey("b.id"), primary_key=True),
            Column("cdata", String(30)),
        )

    def _make_test(fetchtype):
        """Build a round-trip test for one polymorphic loading style.

        Called while the class body executes (not as a method).

        :param fetchtype: ``"union"`` passes explicit join selectables to
            ``with_polymorphic``; any other value passes ``None`` as the
            selectable instead.
        :return: the test function, renamed to ``test_<fetchtype>``.
        """

        def test_roundtrip(self):
            class A(fixtures.ComparableEntity):
                pass

            class B(A):
                pass

            class C(B):
                pass

            if fetchtype == "union":
                # A's selectable outer-joins both subclass tables so that
                # plain A rows are kept; B's inner-joins b (every B row has
                # one) and outer-joins c.
                abc = a.outerjoin(b).outerjoin(c)
                bc = a.join(b).outerjoin(c)
            else:
                abc = bc = None

            mapper(
                A,
                a,
                with_polymorphic=("*", abc),
                polymorphic_on=a.c.type,
                polymorphic_identity="a",
            )
            mapper(
                B,
                b,
                with_polymorphic=("*", bc),
                inherits=A,
                polymorphic_identity="b",
            )
            mapper(C, c, inherits=B, polymorphic_identity="c")

            def make_objects():
                # Returns fresh instances on every call: one list is
                # persisted, another serves as the expected values. Every
                # object carries distinct data so the ordered comparisons
                # below can tell each row apart.
                return [
                    A(adata="a1"),
                    B(bdata="b1", adata="b1"),
                    B(bdata="b2", adata="b2"),
                    B(bdata="b3", adata="b3"),
                    C(cdata="c1", bdata="c1", adata="c1"),
                    C(cdata="c2", bdata="c2", adata="c2"),
                    C(cdata="c3", bdata="c3", adata="c3"),
                ]

            sess = create_session()
            sess.add_all(make_objects())
            sess.flush()
            # Detach everything so the queries below load from the database
            # rather than returning the objects already in the session.
            sess.expunge_all()

            # Ordering by A.id is expected to reproduce the order in which
            # the objects were added above.
            expected = make_objects()

            # Querying the base class returns every row.
            eq_(expected, sess.query(A).order_by(A.id).all())

            # Querying B returns the B and C rows (indexes 1..6).
            eq_(expected[1:], sess.query(B).order_by(A.id).all())

            # Querying C returns only the C rows (indexes 4..6).
            eq_(expected[4:], sess.query(C).order_by(A.id).all())

        test_roundtrip = function_named(test_roundtrip, "test_%s" % fetchtype)
        return test_roundtrip

    test_union = _make_test("union")
    test_none = _make_test("none")