File: store_updater.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (298 lines) | stat: -rw-r--r-- 11,209 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

import asyncio

from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
    JSON,
    EventID,
    EventType,
    Member,
    Membership,
    MemberStateEventContent,
    RoomAlias,
    RoomID,
    StateEvent,
    StateEventContent,
    SyncToken,
    UserID,
)

from .api import ClientAPI
from .state_store import StateStore


class StoreUpdatingAPI(ClientAPI):
    """
    StoreUpdatingAPI is a wrapper around the medium-level ClientAPI that optionally updates
    a client state store with outgoing state events (after they're successfully sent).

    Each override first delegates to the corresponding :class:`ClientAPI` method. It touches
    :attr:`state_store` only after that call returns, so a request that raises never updates
    the store. Besides outgoing events, state fetched from the server (``get_state``,
    ``get_state_event``, ``get_joined_members``, ``get_members``) is also written to the store.
    When :attr:`state_store` is ``None``, no store updates happen.
    """

    # Optional cache that is kept in sync with successful requests; None disables caching.
    state_store: StateStore | None

    def __init__(self, *args, state_store: StateStore | None = None, **kwargs) -> None:
        """
        Args:
            *args: Passed through to :class:`ClientAPI`.
            state_store: The state store to update, or ``None`` to disable store updates.
            **kwargs: Passed through to :class:`ClientAPI`.
        """
        super().__init__(*args, **kwargs)
        self.state_store = state_store

    async def join_room_by_id(
        self,
        room_id: RoomID,
        third_party_signed: JSON = None,
        extra_content: dict[str, JSON] | None = None,
    ) -> RoomID:
        """
        Join a room by ID, then record our own membership as ``join`` in the state store.

        The store is not updated when ``extra_content`` was sent.

        Returns:
            The room ID returned by the server.
        """
        room_id = await super().join_room_by_id(
            room_id, third_party_signed=third_party_signed, extra_content=extra_content
        )
        # The store update below records only the membership value. NOTE(review): presumably
        # it is skipped with extra_content because the sent event then carries fields that
        # this update would not reflect; confirm the intent.
        if room_id and not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
        return room_id

    async def join_room(
        self,
        room_id_or_alias: RoomID | RoomAlias,
        servers: list[str] | None = None,
        third_party_signed: JSON = None,
        max_retries: int = 4,
    ) -> RoomID:
        """
        Join a room by ID or alias, then record our own membership as ``join`` in the store.

        Returns:
            The room ID returned by the server.
        """
        room_id = await super().join_room(
            room_id_or_alias, servers, third_party_signed, max_retries
        )
        if room_id and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.JOIN)
        return room_id

    async def leave_room(
        self,
        room_id: RoomID,
        reason: str | None = None,
        extra_content: dict[str, JSON] | None = None,
        raise_not_in_room: bool = False,
    ) -> None:
        """
        Leave a room, then record our own membership as ``leave`` in the state store.

        The store is not updated when ``extra_content`` was sent.
        """
        await super().leave_room(room_id, reason, extra_content, raise_not_in_room)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.LEAVE)

    async def knock_room(
        self,
        room_id_or_alias: RoomID | RoomAlias,
        reason: str | None = None,
        servers: list[str] | None = None,
    ) -> RoomID:
        """
        Knock on a room, then record our own membership as ``knock`` in the state store.

        Returns:
            The room ID returned by the server.
        """
        room_id = await super().knock_room(room_id_or_alias, reason, servers)
        if room_id and self.state_store:
            await self.state_store.set_membership(room_id, self.mxid, Membership.KNOCK)
        return room_id

    async def invite_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str | None = None,
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Invite a user, then record their membership as ``invite`` in the state store.

        The store is not updated when ``extra_content`` was sent.
        """
        await super().invite_user(room_id, user_id, reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.INVITE)

    async def kick_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Kick a user, then record their membership as ``leave`` in the state store.

        The store is not updated when ``extra_content`` was sent.
        """
        await super().kick_user(room_id, user_id, reason=reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)

    async def ban_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Ban a user, then record their membership as ``ban`` in the state store.

        The store is not updated when ``extra_content`` was sent.
        """
        await super().ban_user(room_id, user_id, reason=reason, extra_content=extra_content)
        if not extra_content and self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.BAN)

    async def unban_user(
        self,
        room_id: RoomID,
        user_id: UserID,
        reason: str = "",
        extra_content: dict[str, JSON] | None = None,
    ) -> None:
        """
        Unban a user, then record their membership as ``leave`` in the state store.
        """
        await super().unban_user(room_id, user_id, reason=reason, extra_content=extra_content)
        # NOTE(review): unlike invite/kick/ban/leave, this updates the store even when
        # extra_content was sent. Confirm whether that difference is intentional.
        if self.state_store:
            await self.state_store.set_membership(room_id, user_id, Membership.LEAVE)

    async def get_state(self, room_id: RoomID) -> list[StateEvent]:
        """
        Fetch the full current state of a room and write it into the state store.

        Member events go to the store in a single ``set_members`` call keyed by state key.
        Every other event goes through ``update_state``. All store writes run concurrently.

        Returns:
            The state events returned by the server.
        """
        state = await super().get_state(room_id)
        if self.state_store:
            update_members = self.state_store.set_members(
                room_id,
                {evt.state_key: evt.content for evt in state if evt.type == EventType.ROOM_MEMBER},
            )
            await asyncio.gather(
                update_members,
                *[
                    self.state_store.update_state(evt)
                    for evt in state
                    if evt.type != EventType.ROOM_MEMBER
                ],
            )
        return state

    async def create_room(self, *args, **kwargs) -> RoomID:
        """
        Create a room, then seed the state store with its invitees and initial state.

        Only keyword arguments are inspected (``invitees``, ``initial_state`` and
        ``beeper_auto_join_invites``). Values passed positionally are not reflected in the store.
        Invitees are stored as ``invite``, or as ``join`` when ``beeper_auto_join_invites`` is
        truthy. Each ``initial_state`` entry is indexed with ``["type"]`` and ``["content"]``
        and optionally ``.get("state_key")``. It is stored as a synthetic event with a
        placeholder event ID and a zero timestamp.

        Returns:
            The ID of the created room.
        """
        room_id = await super().create_room(*args, **kwargs)
        if self.state_store:
            invitee_membership = Membership.INVITE
            if kwargs.get("beeper_auto_join_invites"):
                invitee_membership = Membership.JOIN
            # `or []` rather than a .get() default: callers may pass invitees=None or
            # initial_state=None explicitly. Iterating None would raise after the room was
            # already created on the server.
            for user_id in kwargs.get("invitees") or []:
                await self.state_store.set_membership(room_id, user_id, invitee_membership)
            for evt in kwargs.get("initial_state") or []:
                await self.state_store.update_state(
                    StateEvent(
                        type=EventType.find(evt["type"], t_class=EventType.Class.STATE),
                        room_id=room_id,
                        event_id=EventID("$fake-create-id"),
                        sender=self.mxid,
                        state_key=evt.get("state_key", ""),
                        timestamp=0,
                        content=evt["content"],
                    )
                )
        return room_id

    async def send_state_event(
        self,
        room_id: RoomID,
        event_type: EventType,
        content: StateEventContent | dict[str, JSON],
        state_key: str = "",
        **kwargs,
    ) -> EventID:
        """
        Send a state event, then store a synthetic copy of it in the state store.

        The stored copy uses the event ID returned by the server, ourselves as the sender, and a
        zero timestamp.

        Returns:
            The ID of the sent event.
        """
        event_id = await super().send_state_event(
            room_id, event_type, content, state_key, **kwargs
        )
        if self.state_store:
            fake_event = StateEvent(
                type=event_type,
                room_id=room_id,
                event_id=event_id,
                sender=self.mxid,
                state_key=state_key,
                timestamp=0,
                content=content,
            )
            await self.state_store.update_state(fake_event)
        return event_id

    async def get_state_event(
        self, room_id: RoomID, event_type: EventType, state_key: str = ""
    ) -> StateEventContent:
        """
        Fetch a single state event's content, then store it in the state store.

        The server only returns content, so the stored event has an empty event ID, an empty
        sender and a zero timestamp.

        Returns:
            The state event content returned by the server.
        """
        event = await super().get_state_event(room_id, event_type, state_key)
        if self.state_store:
            fake_event = StateEvent(
                type=event_type,
                room_id=room_id,
                event_id=EventID(""),
                sender=UserID(""),
                state_key=state_key,
                timestamp=0,
                content=event,
            )
            await self.state_store.update_state(fake_event)
        return event

    async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:
        """
        Fetch the joined members of a room and write them into the state store.

        The store call passes ``only_membership=Membership.JOIN``. NOTE(review): presumably
        this tells the store the map covers only joined members; confirm.

        Returns:
            The mapping of user ID to member info returned by the server.
        """
        members = await super().get_joined_members(room_id)
        if self.state_store:
            await self.state_store.set_members(room_id, members, only_membership=Membership.JOIN)
        return members

    async def get_members(
        self,
        room_id: RoomID,
        at: SyncToken | None = None,
        membership: Membership | None = None,
        not_membership: Membership | None = None,
    ) -> list[StateEvent]:
        """
        Fetch member events of a room and write them into the state store.

        The store is not updated when ``not_membership`` is ``join``. Otherwise the events are
        stored keyed by state key, with ``only_membership`` set to the ``membership`` filter.

        Returns:
            The member events returned by the server.
        """
        member_events = await super().get_members(room_id, at, membership, not_membership)
        # NOTE(review): presumably a result that excludes joined members is not trusted as a
        # member list for the store; confirm the exact reason for this check.
        if self.state_store and not_membership != Membership.JOIN:
            await self.state_store.set_members(
                room_id,
                {evt.state_key: evt.content for evt in member_events},
                only_membership=membership,
            )
        return member_events

    async def fill_member_event(
        self,
        room_id: RoomID,
        user_id: UserID,
        content: MemberStateEventContent,
    ) -> MemberStateEventContent | None:
        """
        Fill a membership event content that is going to be sent in :meth:`send_member_event`.

        This is used to set default fields like the displayname and avatar, which are usually set
        by the server in the sugar membership endpoints like /join and /invite, but are not set
        automatically when sending member events manually.

        This implementation in StoreUpdatingAPI will first try to call the default implementation
        (which calls :attr:`fill_member_event_callback`). If that doesn't return anything, this
        will try to get the profile from the current member event in the state store (if one is
        configured), and then fall back to fetching the global profile from the server.

        Args:
            room_id: The room where the member event is going to be sent.
            user_id: The user whose membership is changing.
            content: The new member event content.

        Returns:
            The filled member event content, or ``None`` if nothing was filled. Nothing is
            filled when the content already has a displayname or avatar, or when no profile
            info was found.
        """
        callback_content = await super().fill_member_event(room_id, user_id, content)
        if callback_content is not None:
            self.log.trace("Filled new member event for %s using callback", user_id)
            return callback_content

        if content.displayname is None and content.avatar_url is None:
            # state_store is optional (typed `StateStore | None`). Without one, skip straight
            # to the server profile lookup instead of raising AttributeError.
            existing_member = (
                await self.state_store.get_member(room_id, user_id) if self.state_store else None
            )
            if existing_member is not None:
                self.log.trace(
                    "Found existing member event %s to fill new member event for %s",
                    existing_member,
                    user_id,
                )
                content.displayname = existing_member.displayname
                content.avatar_url = existing_member.avatar_url
                return content

            try:
                profile = await self.get_profile(user_id)
            except (MNotFound, MForbidden):
                # A missing or inaccessible profile is expected; treat it as "no info".
                profile = None
            if profile:
                self.log.trace(
                    "Fetched profile %s to fill new member event of %s", profile, user_id
                )
                content.displayname = profile.displayname
                content.avatar_url = profile.avatar_url
                return content
            else:
                self.log.trace("Didn't find profile info to fill new member event of %s", user_id)
        else:
            self.log.trace(
                "Member event for %s already contains displayname or avatar, not re-filling",
                user_id,
            )
        return None