File: store.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (257 lines) | stat: -rw-r--r-- 10,532 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Any, NamedTuple
import json

from mautrix.types import (
    Member,
    Membership,
    MemberStateEventContent,
    PowerLevelStateEventContent,
    RoomEncryptionStateEventContent,
    RoomID,
    Serializable,
    UserID,
)
from mautrix.util.async_db import Database, Scheme

from ..abstract import StateStore
from .upgrade import upgrade_table


class RoomState(NamedTuple):
    """Typed bundle of per-room state, mirroring the ``mx_room_state`` columns.

    NOTE(review): this tuple is not constructed anywhere in this module; the field
    names match the ``mx_room_state`` columns used by ``PgStateStore`` — confirm
    how callers populate it.
    """

    # May be None: ``PgStateStore.is_encrypted`` reads this column as ``bool | None``.
    is_encrypted: bool | None
    has_full_member_list: bool
    # May be None: the store checks this column with ``IS NULL`` / ``IS NOT NULL``.
    encryption: RoomEncryptionStateEventContent | None
    # May be None: the store checks this column with ``IS NOT NULL``.
    power_levels: PowerLevelStateEventContent | None


class PgStateStore(StateStore):
    """Matrix client state store backed by an SQL database.

    Supports PostgreSQL, CockroachDB and (through the non-Postgres branches) other
    schemes such as SQLite. Member profiles live in ``mx_user_profile`` (one row per
    room/user pair); per-room state (encryption, power levels, whether the full member
    list is known) lives in ``mx_room_state`` (one row per room).
    """

    upgrade_table = upgrade_table

    db: Database

    def __init__(self, db: Database) -> None:
        self.db = db

    async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
        """Return the cached membership/profile of ``user_id`` in ``room_id``.

        Returns:
            The cached member, or ``None`` if nothing is stored for the pair.
        """
        res = await self.db.fetchrow(
            "SELECT membership, displayname, avatar_url "
            "FROM mx_user_profile WHERE room_id=$1 AND user_id=$2",
            room_id,
            user_id,
        )
        if res is None:
            return None
        return Member(
            membership=Membership.deserialize(res["membership"]),
            displayname=res["displayname"],
            avatar_url=res["avatar_url"],
        )

    async def set_member(
        self, room_id: RoomID, user_id: UserID, member: Member | MemberStateEventContent
    ) -> None:
        """Insert or overwrite the membership and profile of ``user_id`` in ``room_id``."""
        q = (
            "INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) "
            "VALUES ($1, $2, $3, $4, $5) "
            "ON CONFLICT (room_id, user_id) DO UPDATE "
            "SET membership=$3, displayname=$4, avatar_url=$5"
        )
        await self.db.execute(
            q, room_id, user_id, member.membership.value, member.displayname, member.avatar_url
        )

    async def set_membership(
        self, room_id: RoomID, user_id: UserID, membership: Membership
    ) -> None:
        """Set only the membership of ``user_id`` in ``room_id``, keeping any stored profile."""
        q = (
            "INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) "
            "ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3"
        )
        await self.db.execute(q, room_id, user_id, membership.value)

    async def get_members(
        self,
        room_id: RoomID,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> list[UserID]:
        """Return the IDs of users in ``room_id`` whose membership is one of ``memberships``."""
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND membership=ANY($2)"
            res = await self.db.fetch(q, room_id, membership_values)
        else:
            # No array parameters outside Postgres: expand to one placeholder per value.
            membership_placeholders = ("?," * len(memberships)).rstrip(",")
            q = (
                "SELECT user_id FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders})"
            )
            res = await self.db.fetch(q, room_id, *membership_values)
        return [profile["user_id"] for profile in res]

    async def get_member_profiles(
        self,
        room_id: RoomID,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> dict[UserID, Member]:
        """Return cached profiles of users in ``room_id`` with one of ``memberships``.

        Returns:
            A mapping of user ID to the deserialized member row.
        """
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = (
                "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
                "WHERE room_id=$1 AND membership=ANY($2)"
            )
            res = await self.db.fetch(q, room_id, membership_values)
        else:
            membership_placeholders = ("?," * len(memberships)).rstrip(",")
            q = (
                "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders})"
            )
            res = await self.db.fetch(q, room_id, *membership_values)
        return {profile["user_id"]: Member.deserialize(profile) for profile in res}

    async def get_members_filtered(
        self,
        room_id: RoomID,
        not_prefix: str,
        not_suffix: str,
        not_id: str,
        memberships: tuple[Membership, ...] = (Membership.JOIN, Membership.INVITE),
    ) -> list[UserID]:
        """Return members of ``room_id`` with one of ``memberships``, minus excluded users.

        Excludes ``not_id`` and every user ID matching the SQL ``LIKE`` pattern
        ``{not_prefix}%{not_suffix}``.

        NOTE(review): ``_`` and ``%`` inside ``not_prefix``/``not_suffix`` are not
        escaped, so they act as ``LIKE`` wildcards.
        """
        not_like = f"{not_prefix}%{not_suffix}"
        membership_values = [membership.value for membership in memberships]
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = (
                "SELECT user_id FROM mx_user_profile "
                "WHERE room_id=$1 AND membership=ANY($2) "
                "AND user_id != $3 AND user_id NOT LIKE $4"
            )
            res = await self.db.fetch(q, room_id, membership_values, not_id, not_like)
        else:
            membership_placeholders = ("?," * len(memberships)).rstrip(",")
            q = (
                "SELECT user_id FROM mx_user_profile "
                f"WHERE room_id=? AND membership IN ({membership_placeholders}) "
                "AND user_id != ? AND user_id NOT LIKE ?"
            )
            res = await self.db.fetch(q, room_id, *membership_values, not_id, not_like)
        return [profile["user_id"] for profile in res]

    async def set_members(
        self,
        room_id: RoomID,
        members: dict[UserID, Member | MemberStateEventContent],
        only_membership: Membership | None = None,
    ) -> None:
        """Replace the cached member list of ``room_id`` in a single transaction.

        Args:
            room_id: The room whose members are being replaced.
            members: The new member rows, keyed by user ID.
            only_membership: If ``None``, every cached row of the room is deleted
                before inserting ``members``. Otherwise only rows with this membership,
                plus rows for users present in ``members``, are deleted; other rows
                are kept.

        When ``only_membership`` is ``None`` or ``Membership.JOIN``, the room is also
        flagged as having a full member list.
        """
        columns = ["room_id", "user_id", "membership", "displayname", "avatar_url"]
        records = [
            # Serialize with .value, the same way set_member/set_membership do.
            (room_id, user_id, member.membership.value, member.displayname, member.avatar_url)
            for user_id, member in members.items()
        ]
        async with self.db.acquire() as conn, conn.transaction():
            if only_membership is None:
                await conn.execute("DELETE FROM mx_user_profile WHERE room_id=$1", room_id)
            elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
                await conn.execute(
                    "DELETE FROM mx_user_profile WHERE room_id=$1 "
                    "AND (membership=$2 OR user_id = ANY($3))",
                    room_id,
                    only_membership.value,
                    list(members.keys()),
                )
            else:
                # Use plain "?" placeholders throughout rather than mixing in "$1".
                member_placeholders = ("?," * len(members)).rstrip(",")
                await conn.execute(
                    "DELETE FROM mx_user_profile WHERE room_id=? "
                    f"AND (membership=? OR user_id IN ({member_placeholders}))",
                    room_id,
                    only_membership.value,
                    *members.keys(),
                )

            if self.db.scheme == Scheme.POSTGRES:
                await conn.copy_records_to_table(
                    "mx_user_profile", records=records, columns=columns
                )
            else:
                q = (
                    "INSERT INTO mx_user_profile (room_id, user_id, membership, "
                    "displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)"
                )
                await conn.executemany(q, records)

            if only_membership is None or only_membership == Membership.JOIN:
                # Upsert: a plain UPDATE would silently do nothing when the room has no
                # mx_room_state row yet, losing the flag.
                await conn.execute(
                    "INSERT INTO mx_room_state (room_id, has_full_member_list) VALUES ($1, true) "
                    "ON CONFLICT (room_id) DO UPDATE SET has_full_member_list=true",
                    room_id,
                )

    async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
        """Return rooms marked encrypted in which ``user_id`` has any cached member row.

        Note that no membership filter is applied, so rows of any membership match.
        """
        q = (
            "SELECT mx_user_profile.room_id FROM mx_user_profile "
            "LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id "
            "WHERE mx_user_profile.user_id=$1 AND mx_room_state.is_encrypted=true"
        )
        rows = await self.db.fetch(q, user_id)
        return [row["room_id"] for row in rows]

    async def has_full_member_list(self, room_id: RoomID) -> bool:
        """Return whether the room is flagged as having a complete cached member list."""
        return bool(
            await self.db.fetchval(
                "SELECT has_full_member_list FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def has_power_levels_cached(self, room_id: RoomID) -> bool:
        """Return whether power levels are stored for the room."""
        return bool(
            await self.db.fetchval(
                "SELECT power_levels IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def get_power_levels(self, room_id: RoomID) -> PowerLevelStateEventContent | None:
        """Return the stored power levels of the room, or ``None`` if none are stored."""
        power_levels_json = await self.db.fetchval(
            "SELECT power_levels FROM mx_room_state WHERE room_id=$1", room_id
        )
        if power_levels_json is None:
            return None
        return PowerLevelStateEventContent.parse_json(power_levels_json)

    async def set_power_levels(
        self, room_id: RoomID, content: PowerLevelStateEventContent | dict[str, Any]
    ) -> None:
        """Store the room's power levels as JSON, creating the state row if needed."""
        await self.db.execute(
            "INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) "
            "ON CONFLICT (room_id) DO UPDATE SET power_levels=$2",
            room_id,
            json.dumps(content.serialize() if isinstance(content, Serializable) else content),
        )

    async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
        """Return whether encryption info is stored for the room.

        Mirrors ``has_power_levels_cached``: true only when the ``encryption``
        column holds a value. The previous ``IS NULL`` check was inverted.
        """
        return bool(
            await self.db.fetchval(
                "SELECT encryption IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
            )
        )

    async def is_encrypted(self, room_id: RoomID) -> bool | None:
        """Return the stored ``is_encrypted`` flag, or ``None`` if the room has no state row."""
        return await self.db.fetchval(
            "SELECT is_encrypted FROM mx_room_state WHERE room_id=$1", room_id
        )

    async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
        """Return the stored encryption event content if the room is marked encrypted.

        Returns ``None`` when the room has no state row, is not marked encrypted,
        or has no encryption content stored.
        """
        row = await self.db.fetchrow(
            "SELECT is_encrypted, encryption FROM mx_room_state WHERE room_id=$1", room_id
        )
        if row is None or not row["is_encrypted"] or row["encryption"] is None:
            return None
        return RoomEncryptionStateEventContent.parse_json(row["encryption"])

    async def set_encryption_info(
        self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
    ) -> None:
        """Mark the room encrypted and store its encryption content as JSON."""
        q = (
            "INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) "
            "ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2"
        )
        await self.db.execute(
            q,
            room_id,
            json.dumps(content.serialize() if isinstance(content, Serializable) else content),
        )