1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
# Copyright (c) 2023 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Any, Awaitable, Callable, TypedDict
import asyncio
import functools
import json
import olm
from mautrix.errors import MForbidden, MNotFound
from mautrix.types import (
DeviceID,
EncryptionKeyAlgorithm,
EventType,
IdentityKey,
KeyID,
RequestedKeyInfo,
RoomEncryptionStateEventContent,
RoomID,
RoomKeyEventContent,
SessionID,
SigningKey,
TrustState,
UserID,
)
from mautrix.util.logging import TraceLogger
from .. import client as cli, crypto
class SignedObject(TypedDict):
signatures: dict[UserID, dict[str, str]]
unsigned: Any
class BaseOlmMachine:
client: cli.Client
log: TraceLogger
crypto_store: crypto.CryptoStore
state_store: crypto.StateStore
account: account.OlmAccount
send_keys_min_trust: TrustState
share_keys_min_trust: TrustState
allow_key_share: Callable[[crypto.DeviceIdentity, RequestedKeyInfo], Awaitable[bool]]
delete_outbound_keys_on_ack: bool
dont_store_outbound_keys: bool
delete_previous_keys_on_receive: bool
ratchet_keys_on_decrypt: bool
delete_fully_used_keys_on_decrypt: bool
delete_keys_on_device_delete: bool
disable_device_change_key_rotation: bool
# Futures that wait for responses to a key request
_key_request_waiters: dict[SessionID, asyncio.Future]
# Futures that wait for a session to be received (either normally or through a key request)
_inbound_session_waiters: dict[SessionID, asyncio.Future]
_prev_unwedge: dict[IdentityKey, float]
_fetch_keys_lock: asyncio.Lock
_megolm_decrypt_lock: asyncio.Lock
_share_keys_lock: asyncio.Lock
_last_key_share: float
_cs_fetch_attempted: set[UserID]
async def wait_for_session(
self, room_id: RoomID, session_id: SessionID, timeout: float = 3
) -> bool:
try:
fut = self._inbound_session_waiters[session_id]
except KeyError:
fut = asyncio.get_running_loop().create_future()
self._inbound_session_waiters[session_id] = fut
try:
return await asyncio.wait_for(asyncio.shield(fut), timeout)
except asyncio.TimeoutError:
return await self.crypto_store.has_group_session(room_id, session_id)
def _mark_session_received(self, session_id: SessionID) -> None:
try:
self._inbound_session_waiters.pop(session_id).set_result(True)
except KeyError:
return
async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None:
encryption_info = await self.state_store.get_encryption_info(evt.room_id)
if not encryption_info:
self.log.warning(
f"Encryption info for {evt.room_id} not found in state store, fetching from server"
)
try:
encryption_info = await self.client.get_state_event(
evt.room_id, EventType.ROOM_ENCRYPTION
)
except (MNotFound, MForbidden) as e:
self.log.warning(
f"Failed to get encryption info for {evt.room_id} from server: {e},"
" using defaults"
)
encryption_info = RoomEncryptionStateEventContent()
if not encryption_info:
self.log.warning(
f"Didn't find encryption info for {evt.room_id} on server either,"
" using defaults"
)
encryption_info = RoomEncryptionStateEventContent()
if not evt.beeper_max_age_ms:
evt.beeper_max_age_ms = encryption_info.rotation_period_ms
if not evt.beeper_max_messages:
evt.beeper_max_messages = encryption_info.rotation_period_msgs
canonical_json = functools.partial(
json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True
)
def verify_signature_json(
data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
) -> bool:
data_copy = {**data}
data_copy.pop("unsigned", None)
signatures = data_copy.pop("signatures")
key_id = str(KeyID(EncryptionKeyAlgorithm.ED25519, key_name))
try:
signature = signatures[user_id][key_id]
except KeyError:
return False
signed_data = canonical_json(data_copy)
try:
olm.ed25519_verify(key, signed_data, signature)
return True
except olm.OlmVerifyError:
return False
|