File: encryption_manager.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (187 lines) | stat: -rw-r--r-- 7,123 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

import asyncio
import logging

from mautrix import __optional_imports__
from mautrix.errors import DecryptionError, EncryptionError, MNotFound
from mautrix.types import (
    EncryptedEvent,
    EncryptedMegolmEventContent,
    EventContent,
    EventID,
    EventType,
    RoomID,
)
from mautrix.util.logging import TraceLogger

from . import client, dispatcher, store_updater

if __optional_imports__:
    from .. import crypto as crypt


class EncryptingAPI(store_updater.StoreUpdatingAPI):
    """
    EncryptingAPI is a wrapper around StoreUpdatingAPI that automatically encrypts messages.

    For automatic decryption, see :class:`DecryptionDispatcher`.
    """

    # The olm machine; ``None`` until one is assigned through the :attr:`crypto` property.
    _crypto: crypt.OlmMachine | None
    encryption_blacklist: set[EventType] = {EventType.REACTION}
    """A set of event types which shouldn't be encrypted even in encrypted rooms."""
    crypto_log: TraceLogger = logging.getLogger("mau.client.crypto")
    """The logger to use for crypto-related things."""
    # Rooms that currently have a group session share in progress, mapped to the event
    # that is set once that share finishes (whether it succeeded or raised).
    _share_session_events: dict[RoomID, asyncio.Event]

    def __init__(self, *args, crypto_log: TraceLogger | None = None, **kwargs) -> None:
        """
        Args:
            *args: Positional arguments forwarded to the parent class.
            crypto_log: Logger to use instead of the class-level :attr:`crypto_log`.
                The class-level default is kept when this is ``None``.
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(*args, **kwargs)
        # Explicit None check: only a missing argument should fall back to the default.
        if crypto_log is not None:
            self.crypto_log = crypto_log
        self._crypto = None
        self._share_session_events = {}

    @property
    def crypto(self) -> crypt.OlmMachine | None:
        """The :class:`crypto.OlmMachine` to use for e2ee stuff."""
        return self._crypto

    @crypto.setter
    def crypto(self, crypto: crypt.OlmMachine) -> None:
        """
        Args:
            crypto: The olm machine to use for crypto

        Raises:
            ValueError: if :attr:`state_store` is not set.
        """
        if not self.state_store:
            raise ValueError("State store must be set to use encryption")
        self._crypto = crypto

    @property
    def crypto_enabled(self) -> bool:
        """``True`` if both the olm machine and state store are set properly."""
        return bool(self.crypto) and bool(self.state_store)

    async def encrypt(
        self, room_id: RoomID, event_type: EventType, content: EventContent
    ) -> EncryptedMegolmEventContent:
        """
        Encrypt a message for the given room. Automatically creates and shares a group session
        if necessary.

        The first encryption attempt is made directly. If it raises
        :class:`EncryptionError`, a group session is shared and encryption is attempted once
        more; an error from that second attempt is not caught here.

        Args:
            room_id: The room to encrypt the event to.
            event_type: The type of event.
            content: The content of the event.

        Returns:
            The content of the encrypted event.
        """
        try:
            return await self.crypto.encrypt_megolm_event(room_id, event_type, content)
        except EncryptionError:
            self.crypto_log.debug("Got EncryptionError, sharing group session and trying again")
            await self.share_group_session(room_id)
            self.crypto_log.trace(
                f"Shared group session, now trying to encrypt in {room_id} again"
            )
            return await self.crypto.encrypt_megolm_event(room_id, event_type, content)

    async def _share_session_lock(self, room_id: RoomID) -> bool:
        """
        Claim the right to share a group session for a room, or wait for the current sharer.

        Args:
            room_id: The room whose session is about to be shared.

        Returns:
            ``True`` if no share was in progress and the caller has been registered as the
            one doing it (the caller must then pop and set the event when done).
            ``False`` if a share was already in progress; in that case this only returns
            after that share's event has been set.
        """
        try:
            event = self._share_session_events[room_id]
        except KeyError:
            self._share_session_events[room_id] = asyncio.Event()
            return True
        else:
            await event.wait()
            return False

    async def share_group_session(self, room_id: RoomID) -> None:
        """
        Create and share a Megolm session for the given room.

        If another call is already sharing a session for the same room, this waits for that
        call to finish and then returns without sharing a second session.

        Args:
            room_id: The room to share the session for.
        """
        if not await self._share_session_lock(room_id):
            # Logged via crypto_log like every other crypto-related message in this class.
            self.crypto_log.silly(
                "Group session was already being shared, so didn't share new one"
            )
            return
        try:
            if not await self.state_store.has_full_member_list(room_id):
                self.crypto_log.trace(
                    f"Don't have full member list for {room_id}, fetching from server"
                )
                # Only the keys of the returned mapping are passed on as the member list.
                members = list(await self.get_joined_members(room_id))
            else:
                self.crypto_log.trace(f"Fetching member list for {room_id} from state store")
                members = await self.state_store.get_members(room_id)
            await self.crypto.share_group_session(room_id, members)
        finally:
            # Always release waiters, even if sharing failed, so they never block forever.
            self._share_session_events.pop(room_id).set()

    async def send_message_event(
        self,
        room_id: RoomID,
        event_type: EventType,
        content: EventContent,
        disable_encryption: bool = False,
        **kwargs,
    ) -> EventID:
        """
        A wrapper around :meth:`ClientAPI.send_message_event` that encrypts messages if the target
        room is encrypted.

        Encryption is only considered when an olm machine is set, the event type is not in
        :attr:`encryption_blacklist` and ``disable_encryption`` is false.

        Args:
            room_id: The room to send the message to.
            event_type: The unencrypted event type.
            content: The unencrypted event content.
            disable_encryption: Set to ``True`` if you want to force-send an unencrypted message.
            **kwargs: Additional parameters to pass to :meth:`ClientAPI.send_message_event`.

        Returns:
            The ID of the event that was sent.
        """
        if self.crypto and event_type not in self.encryption_blacklist and not disable_encryption:
            is_encrypted = await self.state_store.is_encrypted(room_id)
            if is_encrypted is None:
                # The state store has no answer: look the encryption state event up
                # directly and treat a missing event as "not encrypted".
                try:
                    await self.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
                    is_encrypted = True
                except MNotFound:
                    is_encrypted = False
            if is_encrypted:
                content = await self.encrypt(room_id, event_type, content)
                event_type = EventType.ROOM_ENCRYPTED
        return await super().send_message_event(room_id, event_type, content, **kwargs)


class DecryptionDispatcher(dispatcher.SimpleDispatcher):
    """
    Dispatcher that decrypts incoming encrypted events and re-dispatches the result.

    It is meant to be used with a :class:`client.Syncer` so that event handlers receive the
    unencrypted versions of events. The easiest way to use it is through
    :class:`client.Client`, which automatically registers this dispatcher when
    :attr:`EncryptingAPI.crypto` is set.
    """

    event_type = EventType.ROOM_ENCRYPTED
    client: client.Client

    async def handle(self, evt: EncryptedEvent) -> None:
        """
        Decrypt a single encrypted event and dispatch the decrypted event.

        A :class:`DecryptionError` is logged as a warning and the event is dropped;
        nothing is dispatched in that case.

        Args:
            evt: The encrypted event to decrypt.
        """
        api = self.client
        api.crypto_log.trace(f"Decrypting {evt.event_id} in {evt.room_id}...")
        try:
            decrypted = await api.crypto.decrypt_megolm_event(evt)
        except DecryptionError as e:
            api.crypto_log.warning(f"Failed to decrypt {evt.event_id}: {e}")
        else:
            api.crypto_log.trace(f"Decrypted {evt.event_id}: {decrypted}")
            # Dispatch under the same source the encrypted event arrived with.
            api.dispatch_event(decrypted, evt.source)