File: filter_public.py

package info (click to toggle)
sqlalchemy 2.0.40%2Bds1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, trixie
  • size: 26,404 kB
  • sloc: python: 410,002; makefile: 230; sh: 7
file content (202 lines) | stat: -rw-r--r-- 5,964 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Illustrates a global criteria applied to entities of a particular type.

The example here is the "public" flag, a simple boolean that indicates
the rows are part of a publicly viewable subcategory.  Rows that do not
include this flag are not shown unless a special option is passed to the
query.

Uses for this kind of recipe include tables that have "soft deleted" rows
marked as "deleted" that should be skipped, rows that have access control rules
that should be applied on a per-request basis, etc.


"""

from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import orm
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import true
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker


@event.listens_for(Session, "do_orm_execute")
def _add_filtering_criteria(execute_state):
    """Session-wide hook that restricts ORM statements to "public" rows.

    Registered for the ``do_orm_execute`` event on ``Session``.  Unless the
    statement is exempt (see below), a ``with_loader_criteria`` option
    targeting ``HasPrivate`` is attached so that every entity deriving from
    that mixin, aliases included, is limited to rows whose ``public``
    column is true.

    A statement is passed through unmodified when any of these holds:

    * it is a column load,
    * it is a relationship load,
    * the ``include_private`` execution option is set to a true value.

    :param execute_state: the execution state object handed to the
     ``do_orm_execute`` event; its ``statement`` attribute is replaced in
     place when the criteria is added.

    """

    # with_loader_criteria carries over on its own to relationship loads,
    # lazy loads included, so a relationship load is assumed to have
    # received the option already from its top-level query.
    if execute_state.is_column_load or execute_state.is_relationship_load:
        return

    # explicit per-statement opt-out requested by the caller
    if execute_state.execution_options.get("include_private", False):
        return

    public_only = orm.with_loader_criteria(
        HasPrivate,
        lambda cls: cls.public == true(),
        include_aliases=True,
    )
    execute_state.statement = execute_state.statement.options(public_only)


class HasPrivate:
    """Mixin that identifies a class as having private entities.

    Mapped classes that inherit from this mixin carry a ``public`` flag
    column.  The ``_add_filtering_criteria`` hook in this module targets
    this mixin, limiting ORM SELECT statements to rows whose ``public``
    value is true unless the ``include_private`` execution option is set.
    """

    # Boolean flag declared NOT NULL; no default is configured here, so
    # each row must be given an explicit value.
    public = Column(Boolean, nullable=False)


if __name__ == "__main__":
    # Self-checking demonstration: build an in-memory SQLite database, load
    # a mix of public and private rows, then assert that the
    # "do_orm_execute" hook above hides private rows unless the
    # include_private execution option is passed.

    # declarative_base() is taken from sqlalchemy.orm; calling it through
    # sqlalchemy.ext.declarative is deprecated under SQLAlchemy 2.0.
    Base = orm.declarative_base()

    class User(HasPrivate, Base):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        addresses = relationship("Address", back_populates="user")

    class Address(HasPrivate, Base):
        __tablename__ = "address"

        id = Column(Integer, primary_key=True)
        email = Column(String)
        user_id = Column(Integer, ForeignKey("user.id"))

        user = relationship("User", back_populates="addresses")

    engine = create_engine("sqlite://", echo=True)
    Base.metadata.create_all(engine)

    # deliberately not named "Session", so the imported Session class that
    # the event hook is registered on is not rebound at module level
    session_factory = sessionmaker(bind=engine)

    # the context manager closes the session when the demo finishes or an
    # assertion fails
    with session_factory() as sess:
        # the primary key values asserted further down rely on the rows
        # being numbered in the order they are listed here
        sess.add_all(
            [
                User(
                    name="u1",
                    public=True,
                    addresses=[
                        Address(email="u1a1", public=True),
                        Address(email="u1a2", public=True),
                    ],
                ),
                User(
                    name="u2",
                    public=True,
                    addresses=[
                        Address(email="u2a1", public=False),
                        Address(email="u2a2", public=True),
                    ],
                ),
                User(
                    name="u3",
                    public=False,
                    addresses=[
                        Address(email="u3a1", public=False),
                        Address(email="u3a2", public=False),
                    ],
                ),
                User(
                    name="u4",
                    public=False,
                    addresses=[
                        Address(email="u4a1", public=False),
                        Address(email="u4a2", public=True),
                    ],
                ),
                User(
                    name="u5",
                    public=True,
                    addresses=[
                        Address(email="u5a1", public=True),
                        Address(email="u5a2", public=False),
                    ],
                ),
            ]
        )

        sess.commit()

        # now querying Address or User objects only gives us the public
        # ones
        for user in sess.query(User).options(
            orm.selectinload(User.addresses)
        ):
            assert user.public

            # the addresses collection will also be "public only", which
            # works for all relationship loaders including joinedload
            for address in user.addresses:
                assert address.public

        # works for columns too
        cols = (
            sess.query(User.id, Address.id)
            .join(User.addresses)
            .order_by(User.id, Address.id)
            .all()
        )
        assert cols == [(1, 1), (1, 2), (2, 4), (5, 9)]

        # same query with the filter switched off: every user/address pair
        cols = (
            sess.query(User.id, Address.id)
            .join(User.addresses)
            .order_by(User.id, Address.id)
            .execution_options(include_private=True)
            .all()
        )
        assert cols == [
            (1, 1),
            (1, 2),
            (2, 3),
            (2, 4),
            (3, 5),
            (3, 6),
            (4, 7),
            (4, 8),
            (5, 9),
            (5, 10),
        ]

        # count all public addresses
        assert sess.query(Address).count() == 5

        # count all addresses public and private
        assert (
            sess.query(Address)
            .execution_options(include_private=True)
            .count()
            == 10
        )

        # load an Address that is public, but its parent User is private
        # (2.0 style query)
        a1 = sess.execute(select(Address).filter_by(email="u4a2")).scalar()

        # assuming the User isn't already in the Session, it returns None
        assert a1.user is None

        # however, if that user is present in the session, then a
        # many-to-one does a simple get() and it will be present
        sess.expire(a1, ["user"])
        u1 = sess.execute(
            select(User)
            .filter_by(name="u4")
            .execution_options(include_private=True)
        ).scalar()
        assert a1.user is u1