File: key_request.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (158 lines) | stat: -rw-r--r-- 5,954 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Union
import asyncio
import uuid

from mautrix.types import (
    DecryptedOlmEvent,
    DeviceID,
    EncryptionAlgorithm,
    EventType,
    ForwardedRoomKeyEventContent,
    IdentityKey,
    KeyRequestAction,
    RequestedKeyInfo,
    RoomID,
    RoomKeyRequestEventContent,
    SessionID,
    UserID,
)

from .base import BaseOlmMachine
from .sessions import InboundGroupSession


class KeyRequestingMachine(BaseOlmMachine):
    """
    Olm machine part that requests Megolm room keys from other devices and imports the
    forwarded keys that come back.

    Relies on attributes provided elsewhere (presumably by ``BaseOlmMachine``): ``client``,
    ``crypto_store``, ``log``, ``_key_request_waiters`` (dict-like, session ID -> future),
    ``_fill_encryption_info`` and ``_mark_session_received``.
    """

    async def request_room_key(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        from_devices: Dict[UserID, List[DeviceID]],
        timeout: Optional[Union[int, float]] = None,
    ) -> bool:
        """
        Request keys for a Megolm group session from other devices.

        Once the keys are received, or if this task is cancelled or the ``timeout`` expires,
        a cancel request event is sent to the devices that did not answer. If the ``timeout``
        is set to zero or less, this returns right after sending the requests, and the extra
        key requests will not be cancelled.

        Args:
            room_id: The room where the session is used.
            sender_key: The key of the user who created the session.
            session_id: The ID of the session.
            from_devices: A dict from user ID to list of device IDs whom to ask for the keys.
                          The dict is not modified.
            timeout: The maximum number of seconds to wait for the keys. If the timeout is
                     ``None``, the wait time is not limited, but the task can still be cancelled.
                     If it's zero or less, this returns immediately and will never cancel requests.

        Returns:
            ``True`` if the keys were received and are now in the crypto store,
            ``False`` otherwise (including if the method didn't wait at all).
        """
        request_id = str(uuid.uuid1())
        request = RoomKeyRequestEventContent(
            action=KeyRequestAction.REQUEST,
            body=RequestedKeyInfo(
                algorithm=EncryptionAlgorithm.MEGOLM_V1,
                room_id=room_id,
                sender_key=sender_key,
                session_id=session_id,
            ),
            request_id=request_id,
            requesting_device_id=self.client.device_id,
        )

        # Work on a copy: the device that answers is removed from it below, and the caller's
        # dict must not be mutated.
        remaining: Dict[UserID, List[DeviceID]] = {
            user_id: list(devices) for user_id, devices in from_devices.items()
        }

        wait = timeout is None or timeout > 0
        fut: Optional[asyncio.Future] = None
        if wait:
            fut = asyncio.get_running_loop().create_future()
            # Register before sending so that a fast response can't arrive before the waiter.
            # NOTE(review): a concurrent request for the same session replaces the earlier
            # waiter, so the earlier request will only finish via timeout/cancellation.
            self._key_request_waiters[session_id] = fut
        got_keys = False
        try:
            await self.client.send_to_device(
                EventType.ROOM_KEY_REQUEST,
                {
                    user_id: {device_id: request for device_id in devices}
                    for user_id, devices in remaining.items()
                },
            )
            if fut is None:
                # Timeout is set and <=0, don't wait for keys
                return False
            try:
                user_id, device_id = await asyncio.wait_for(fut, timeout=timeout)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Deliberately not re-raised: cancellation is part of the documented contract
                # (cancel requests are still sent and False is returned).
                pass
            else:
                got_keys = True
                # The device that answered doesn't need a cancellation.
                devices = remaining.get(user_id)
                if devices is not None:
                    devices = [dev for dev in devices if dev != device_id]
                    if devices:
                        remaining[user_id] = devices
                    else:
                        del remaining[user_id]
        finally:
            # Remove the waiter even if sending failed, but only if it's still ours: a concurrent
            # request for the same session may have replaced (or already removed) the entry.
            if fut is not None and self._key_request_waiters.get(session_id) is fut:
                del self._key_request_waiters[session_id]

        if remaining:
            cancel = RoomKeyRequestEventContent(
                action=KeyRequestAction.CANCEL,
                request_id=request_id,
                requesting_device_id=self.client.device_id,
            )
            await self.client.send_to_device(
                EventType.ROOM_KEY_REQUEST,
                {
                    user_id: {device_id: cancel for device_id in devices}
                    for user_id, devices in remaining.items()
                },
            )
        return got_keys

    async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
        """
        Import a forwarded Megolm room key from a decrypted Olm event.

        The key is ignored if the crypto store already has the session, or if the session ID
        computed from the imported key doesn't match the one claimed in the event. Otherwise
        the session is stored, and a pending ``request_room_key`` waiter for that session
        (if any) is resolved with ``(sender user ID, sender device ID)``.
        """
        key: ForwardedRoomKeyEventContent = evt.content
        if await self.crypto_store.has_group_session(key.room_id, key.session_id):
            self.log.debug(
                f"Ignoring received session {key.session_id} from {evt.sender}/"
                f"{evt.sender_device}, as crypto store says we have it already"
            )
            return
        if not key.beeper_max_messages or not key.beeper_max_age_ms:
            # Fill in the expiry limits that the forwarding device didn't include.
            await self._fill_encryption_info(key)
        # Record the device that forwarded the key in the chain.
        key.forwarding_key_chain.append(evt.sender_key)
        sess = InboundGroupSession.import_session(
            key.session_key,
            key.signing_key,
            key.sender_key,
            key.room_id,
            key.forwarding_key_chain,
            max_age=key.beeper_max_age_ms,
            max_messages=key.beeper_max_messages,
            is_scheduled=key.beeper_is_scheduled,
        )
        if key.session_id != sess.id:
            # The forwarded key doesn't belong to the session it claims to be; don't store it.
            self.log.warning(
                f"Mismatched session ID while importing forwarded key from "
                f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'"
            )
            return

        await self.crypto_store.put_group_session(
            key.room_id, key.sender_key, key.session_id, sess
        )
        self._mark_session_received(key.session_id)

        self.log.debug(
            f"Imported {key.session_id} for {key.room_id} "
            f"from {evt.sender}/{evt.sender_device}"
        )

        waiter = self._key_request_waiters.get(key.session_id)
        # The waiter may already be done (e.g. cancelled because the requesting task timed out
        # or was cancelled, before its cleanup removed it); set_result would then raise
        # InvalidStateError.
        if waiter is not None and not waiter.done():
            waiter.set_result((evt.sender, evt.sender_device))