1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Dict, List, Optional, Union
import asyncio
import uuid
from mautrix.types import (
DecryptedOlmEvent,
DeviceID,
EncryptionAlgorithm,
EventType,
ForwardedRoomKeyEventContent,
IdentityKey,
KeyRequestAction,
RequestedKeyInfo,
RoomID,
RoomKeyRequestEventContent,
SessionID,
UserID,
)
from .base import BaseOlmMachine
from .sessions import InboundGroupSession
class KeyRequestingMachine(BaseOlmMachine):
    """Olm machine layer that requests Megolm session keys from other devices
    and imports the forwarded keys they send back."""

    async def request_room_key(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        from_devices: Dict[UserID, List[DeviceID]],
        timeout: Optional[Union[int, float]] = None,
    ) -> bool:
        """
        Request keys for a Megolm group session from other devices.

        Once the keys are received, or if this task is cancelled (via the ``timeout`` parameter),
        a cancel request event is sent to the remaining devices. If the ``timeout`` is set to zero
        or less, this will return immediately, and the extra key requests will not be cancelled.

        Args:
            room_id: The room where the session is used.
            sender_key: The key of the user who created the session.
            session_id: The ID of the session.
            from_devices: A dict from user ID to list of device IDs whom to ask for the keys.
                The dict passed in is not modified.
            timeout: The maximum number of seconds to wait for the keys. If the timeout is
                ``None``, the wait time is not limited, but the task can still be cancelled.
                If it's zero or less, this returns immediately and will never cancel requests.

        Returns:
            ``True`` if the keys were received and are now in the crypto store,
            ``False`` otherwise (including if the method didn't wait at all).
        """
        request_id = str(uuid.uuid1())
        request = RoomKeyRequestEventContent(
            action=KeyRequestAction.REQUEST,
            body=RequestedKeyInfo(
                algorithm=EncryptionAlgorithm.MEGOLM_V1,
                room_id=room_id,
                sender_key=sender_key,
                session_id=session_id,
            ),
            request_id=request_id,
            requesting_device_id=self.client.device_id,
        )

        fut: Optional[asyncio.Future] = None
        if timeout is None or timeout > 0:
            # Register the waiter before sending, so that a reply handled by
            # _receive_forwarded_room_key while the send is in flight finds it.
            fut = asyncio.get_running_loop().create_future()
            self._key_request_waiters[session_id] = fut

        got_keys = False
        responder: Optional[tuple] = None
        try:
            await self._send_key_request_event(request, from_devices)
            if fut is None:
                # Timeout is set and <=0, don't wait for keys
                return False
            try:
                responder = await asyncio.wait_for(fut, timeout=timeout)
                got_keys = True
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Timeouts and cancellation of this task are absorbed on purpose so that
                # the cancel requests below still get sent; the method then returns False.
                pass
        finally:
            # Runs even if sending the request failed, so the waiter never leaks.
            self._discard_key_request_waiter(session_id, fut)

        # Cancel the request on every device except the one that answered (if any).
        remaining = self._devices_without(from_devices, responder)
        if remaining:
            cancel = RoomKeyRequestEventContent(
                action=KeyRequestAction.CANCEL,
                request_id=request_id,
                requesting_device_id=self.client.device_id,
            )
            await self._send_key_request_event(cancel, remaining)
        return got_keys

    async def _send_key_request_event(
        self,
        content: RoomKeyRequestEventContent,
        devices: Dict[UserID, List[DeviceID]],
    ) -> None:
        """Send ``content`` as a ``ROOM_KEY_REQUEST`` to-device event to every listed device."""
        await self.client.send_to_device(
            EventType.ROOM_KEY_REQUEST,
            {
                user_id: {device_id: content for device_id in device_ids}
                for user_id, device_ids in devices.items()
            },
        )

    def _discard_key_request_waiter(
        self, session_id: SessionID, fut: Optional[asyncio.Future]
    ) -> None:
        """Unregister ``fut`` as the key waiter for ``session_id`` if it is still registered.

        A concurrent request for the same session may have replaced the waiter; in that
        case the newer waiter is left in place for its own request to clean up.
        """
        if fut is not None and self._key_request_waiters.get(session_id) is fut:
            del self._key_request_waiters[session_id]

    @staticmethod
    def _devices_without(
        from_devices: Dict[UserID, List[DeviceID]],
        responder: Optional[tuple],
    ) -> Dict[UserID, List[DeviceID]]:
        """Return a new ``{user_id: [device_id, ...]}`` dict equal to ``from_devices`` minus
        the ``(user_id, device_id)`` pair in ``responder``, dropping users left with no devices.
        """
        skip_user, skip_device = responder if responder is not None else (None, None)
        remaining: Dict[UserID, List[DeviceID]] = {}
        for user_id, device_ids in from_devices.items():
            kept = [
                device_id
                for device_id in device_ids
                if not (user_id == skip_user and device_id == skip_device)
            ]
            if kept:
                remaining[user_id] = kept
        return remaining

    async def _receive_forwarded_room_key(self, evt: DecryptedOlmEvent) -> None:
        """
        Import a forwarded Megolm room key carried by a decrypted Olm event.

        The key is ignored if the crypto store already has the session, or if the session
        ID in the event does not match the ID of the imported session. On success the
        session is stored, marked as received, and the pending :meth:`request_room_key`
        waiter for it (if any) is resolved with ``(evt.sender, evt.sender_device)``.
        """
        key: ForwardedRoomKeyEventContent = evt.content
        if await self.crypto_store.has_group_session(key.room_id, key.session_id):
            self.log.debug(
                f"Ignoring received session {key.session_id} from {evt.sender}/"
                f"{evt.sender_device}, as crypto store says we have it already"
            )
            return
        if not key.beeper_max_messages or not key.beeper_max_age_ms:
            # Either limit is missing/zero in the event; let _fill_encryption_info populate it.
            await self._fill_encryption_info(key)
        # Append the forwarder's sender key to the forwarding key chain.
        key.forwarding_key_chain.append(evt.sender_key)
        sess = InboundGroupSession.import_session(
            key.session_key,
            key.signing_key,
            key.sender_key,
            key.room_id,
            key.forwarding_key_chain,
            max_age=key.beeper_max_age_ms,
            max_messages=key.beeper_max_messages,
            is_scheduled=key.beeper_is_scheduled,
        )
        if key.session_id != sess.id:
            self.log.warning(
                f"Mismatched session ID while importing forwarded key from "
                f"{evt.sender}/{evt.sender_device}: '{key.session_id}' != '{sess.id}'"
            )
            return
        await self.crypto_store.put_group_session(
            key.room_id, key.sender_key, key.session_id, sess
        )
        self._mark_session_received(key.session_id)
        self.log.debug(
            f"Imported {key.session_id} for {key.room_id} "
            f"from {evt.sender}/{evt.sender_device}"
        )
        waiter = self._key_request_waiters.get(key.session_id)
        # The waiter may already be done (cancelled by a timeout, or resolved by a second
        # forward of the same key that raced past has_group_session); set_result would raise.
        if waiter is not None and not waiter.done():
            waiter.set_result((evt.sender, evt.sender_device))
|