1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
|
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import json
import olm
from mautrix.errors import (
DecryptedPayloadError,
DecryptionError,
DuplicateMessageIndex,
MismatchingRoomError,
SessionNotFound,
VerificationError,
)
from mautrix.types import (
EncryptedEvent,
EncryptedMegolmEventContent,
EncryptionAlgorithm,
Event,
SessionID,
TrustState,
)
from .device_lists import DeviceListMachine
from .sessions import InboundGroupSession
class MegolmDecryptionMachine(DeviceListMachine):
    """
    Decrypts Megolm-encrypted events and keeps per-session ratchet safety state up to date.

    The methods rely on attributes that are not defined in this class (e.g. ``crypto_store``,
    ``client``, ``account``, ``log``, ``_megolm_decrypt_lock`` and the
    ``*_on_decrypt`` flags); they are expected to be provided elsewhere.
    """

    async def decrypt_megolm_event(self, evt: EncryptedEvent) -> Event:
        """
        Decrypt an event that was encrypted using Megolm.

        Args:
            evt: The whole encrypted event.

        Returns:
            The decrypted event, including some unencrypted metadata from the input event.
            The result carries a ``mautrix`` entry with the resolved trust state, whether the
            keys were forwarded, and a ``was_encrypted`` marker.

        Raises:
            DecryptionError: If the content class or algorithm is unsupported, or if the
                ciphertext could not be decrypted.
            SessionNotFound: If the inbound group session is not in the crypto store.
            DuplicateMessageIndex: If the crypto store rejects the message index.
            VerificationError: If the sending device's keys don't match the session's keys.
            DecryptedPayloadError: If the decrypted payload is not a JSON object with the
                expected fields.
            MismatchingRoomError: If the room ID inside the payload differs from the event's.
        """
        if not isinstance(evt.content, EncryptedMegolmEventContent):
            raise DecryptionError("Unsupported event content class")
        elif evt.content.algorithm != EncryptionAlgorithm.MEGOLM_V1:
            raise DecryptionError("Unsupported event encryption algorithm")

        # Session lookup, decryption, index validation and ratcheting happen under one lock
        # so that the stored session state is read and written consistently.
        async with self._megolm_decrypt_lock:
            session = await self.crypto_store.get_group_session(
                evt.room_id, evt.content.session_id
            )
            if session is None:
                # TODO check if olm session is wedged
                raise SessionNotFound(evt.content.session_id, evt.content.sender_key)
            try:
                plaintext, index = session.decrypt(evt.content.ciphertext)
            except olm.OlmGroupSessionError as e:
                raise DecryptionError("Failed to decrypt megolm event") from e
            if not await self.crypto_store.validate_message_index(
                session.sender_key, SessionID(session.id), evt.event_id, index, evt.timestamp
            ):
                raise DuplicateMessageIndex()
            await self._ratchet_session(session, index)

        forwarded_keys = False
        if (
            evt.content.device_id == self.client.device_id
            and session.signing_key == self.account.signing_key
            and session.sender_key == self.account.identity_key
            and not session.forwarding_chain
        ):
            # The session was created by this very device and was not forwarded.
            trust_level = TrustState.VERIFIED
        else:
            device = await self.get_or_fetch_device_by_key(evt.sender, session.sender_key)
            if not session.forwarding_chain or (
                len(session.forwarding_chain) == 1
                and session.forwarding_chain[0] == session.sender_key
            ):
                # Not forwarded (or only "forwarded" by the original sender itself):
                # trust is derived from the sending device.
                if not device:
                    self.log.debug(
                        f"Couldn't resolve trust level of session {session.id}: "
                        f"sent by unknown device {evt.sender}/{session.sender_key}"
                    )
                    trust_level = TrustState.UNKNOWN_DEVICE
                elif (
                    device.signing_key != session.signing_key
                    or device.identity_key != session.sender_key
                ):
                    raise VerificationError()
                else:
                    trust_level = await self.resolve_trust(device)
            else:
                # Forwarded keys: trust is derived from the last device in the chain.
                forwarded_keys = True
                last_chain_item = session.forwarding_chain[-1]
                received_from = await self.crypto_store.find_device_by_key(
                    evt.sender, last_chain_item
                )
                if received_from:
                    trust_level = await self.resolve_trust(received_from)
                else:
                    self.log.debug(
                        f"Couldn't resolve trust level of session {session.id}: "
                        f"forwarding chain ends with unknown device {last_chain_item}"
                    )
                    trust_level = TrustState.FORWARDED

        try:
            data = json.loads(plaintext)
            room_id = data["room_id"]
            event_type = data["type"]
            content = data["content"]
        except json.JSONDecodeError as e:
            raise DecryptedPayloadError("Failed to parse megolm payload") from e
        except KeyError as e:
            raise DecryptedPayloadError("Megolm payload is missing fields") from e
        except TypeError as e:
            # Valid JSON that isn't an object (e.g. a list or a string) can't be indexed
            # by field name; report it as a payload error instead of leaking a TypeError.
            raise DecryptedPayloadError("Megolm payload is not a JSON object") from e

        if room_id != evt.room_id:
            raise MismatchingRoomError()

        # Relation data in the unencrypted wrapper is only used when the decrypted
        # content does not carry its own.
        if evt.content.relates_to and "m.relates_to" not in content:
            content["m.relates_to"] = evt.content.relates_to.serialize()
        result = Event.deserialize(
            {
                "room_id": evt.room_id,
                "event_id": evt.event_id,
                "sender": evt.sender,
                "origin_server_ts": evt.timestamp,
                "type": event_type,
                "content": content,
            }
        )
        result.unsigned = evt.unsigned
        result.type = result.type.with_class(evt.type.t_class)
        result["mautrix"] = {
            "trust_state": trust_level,
            "forwarded_keys": forwarded_keys,
            "was_encrypted": True,
        }
        return result

    async def _ratchet_session(self, sess: InboundGroupSession, index: int) -> None:
        """
        Update the ratchet safety info of a session after decrypting a message, then
        ratchet, redact or simply store the session as appropriate.

        Args:
            sess: The inbound group session that was used for decryption.
            index: The message index that was just decrypted.
        """
        expected_message_index = sess.ratchet_safety.next_index
        did_modify = True
        if index > expected_message_index:
            # Some indices were skipped; remember them so the session isn't ratcheted past them.
            sess.ratchet_safety.missed_indices += list(range(expected_message_index, index))
            sess.ratchet_safety.next_index = index + 1
        elif index == expected_message_index:
            sess.ratchet_safety.next_index = index + 1
        else:
            try:
                sess.ratchet_safety.missed_indices.remove(index)
            except ValueError:
                # Neither new nor previously missed: nothing changed so far.
                did_modify = False

        # Use presence of received_at as a sign that this is a recent megolm session,
        # and therefore it's safe to drop missed indices entirely.
        lost_threshold = expected_message_index - 10
        missed = sess.ratchet_safety.missed_indices
        if sess.received_at and missed and missed[0] < lost_threshold:
            # Indices are appended in increasing order, so everything before the first
            # index at or above the threshold is considered lost. Split exactly there so
            # that the first index that is still recent stays in missed_indices.
            keep_from = next(
                (pos for pos, missed_index in enumerate(missed) if missed_index >= lost_threshold),
                len(missed),
            )
            sess.ratchet_safety.lost_indices.extend(missed[:keep_from])
            sess.ratchet_safety.missed_indices = missed[keep_from:]
            # The safety info changed, so it must be persisted below.
            did_modify = True

        # Never ratchet past an index that may still arrive.
        ratchet_target_index = sess.ratchet_safety.next_index
        if len(sess.ratchet_safety.missed_indices) > 0:
            ratchet_target_index = min(sess.ratchet_safety.missed_indices)
        self.log.debug(
            f"Ratchet safety info for {sess.id}: {sess.ratchet_safety}, {ratchet_target_index=}"
        )
        sess_id = SessionID(sess.id)
        if (
            sess.max_messages
            and ratchet_target_index >= sess.max_messages
            and not sess.ratchet_safety.missed_indices
            and self.delete_fully_used_keys_on_decrypt
        ):
            self.log.info(f"Deleting fully used session {sess.id}")
            await self.crypto_store.redact_group_session(
                sess.room_id, sess_id, reason="maximum messages reached"
            )
            return
        elif sess.first_known_index < ratchet_target_index and self.ratchet_keys_on_decrypt:
            self.log.info(f"Ratcheting session {sess.id} to {ratchet_target_index}")
            sess = sess.ratchet_to(ratchet_target_index)
        elif not did_modify:
            # Nothing changed and no ratcheting needed: skip the store write.
            return
        await self.crypto_store.put_group_session(sess.room_id, sess.sender_key, sess_id, sess)
|