1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
import asyncio
import olm
from mautrix.errors import DecryptionError, MatchingSessionDecryptionError
from mautrix.types import (
DecryptedOlmEvent,
EncryptedOlmEventContent,
EncryptionAlgorithm,
IdentityKey,
OlmCiphertext,
OlmMsgType,
ToDeviceEvent,
UserID,
)
from mautrix.util import background_task
from .base import BaseOlmMachine
from .sessions import Session
class OlmDecryptionMachine(BaseOlmMachine):
    """Decryption of Olm-encrypted (``EncryptionAlgorithm.OLM_V1``) to-device events.

    Uses ``self.account``, ``self.crypto_store``, ``self.client`` and ``self.log``,
    which are expected to be provided by :class:`BaseOlmMachine`. Recovery from broken
    ("wedged") sessions is scheduled through :meth:`_unwedge_session`, which this class
    does not implement itself.
    """

    async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent:
        """Decrypt an Olm to-device event addressed to this device.

        Args:
            evt: The encrypted to-device event. Its content must be an
                :class:`EncryptedOlmEventContent` using ``EncryptionAlgorithm.OLM_V1``.

        Returns:
            The decrypted event, with ``sender_key`` set to the envelope's sender key
            and ``source`` set to ``evt``.

        Raises:
            DecryptionError: If the content class or algorithm is unsupported, there is
                no ciphertext for this device's identity key, decryption fails, or the
                sender, recipient, or recipient ed25519 key inside the decrypted payload
                does not match the envelope / this device.
            MatchingSessionDecryptionError: If an existing session matched the message
                but failed to decrypt it.
            Exception: Whatever ``DecryptedOlmEvent.parse_json`` raises for plaintext
                that cannot be parsed (re-raised after a trace log).
        """
        if not isinstance(evt.content, EncryptedOlmEventContent):
            raise DecryptionError("unsupported event content class")
        elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1:
            raise DecryptionError("unsupported event encryption algorithm")
        try:
            # The ciphertext map is indexed by recipient identity key; select ours.
            own_content = evt.content.ciphertext[self.account.identity_key]
        except KeyError:
            # The KeyError itself carries no extra information, so don't chain it.
            raise DecryptionError(
                "olm event doesn't contain ciphertext for this device"
            ) from None
        self.log.debug(
            f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}"
        )
        plaintext = await self._decrypt_olm_ciphertext(
            evt.sender, evt.content.sender_key, own_content
        )
        try:
            decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext)
        except Exception:
            self.log.trace("Failed to parse olm event plaintext: %s", plaintext)
            raise
        # The encrypted payload repeats the sender and the intended recipient. They must
        # agree with the unencrypted envelope and with this device; a mismatch presumably
        # means the ciphertext was re-sent under a different envelope.
        if decrypted_evt.sender != evt.sender:
            raise DecryptionError("mismatched sender in olm payload")
        elif decrypted_evt.recipient != self.client.mxid:
            raise DecryptionError("mismatched recipient in olm payload")
        elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key:
            raise DecryptionError("mismatched recipient key in olm payload")
        decrypted_evt.sender_key = evt.content.sender_key
        decrypted_evt.source = evt
        self.log.debug(
            f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} "
            f"(sender key: {decrypted_evt.sender_key}) into a {decrypted_evt.type}"
        )
        return decrypted_evt

    async def _decrypt_olm_ciphertext(
        self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext
    ) -> str:
        """Decrypt a single Olm ciphertext sent by ``sender`` from ``sender_key``.

        Existing sessions with ``sender_key`` are tried first. If none of them can
        decrypt the message and it is a pre-key message, a new inbound session is
        created from it, used to decrypt it, and persisted. When the session with the
        sender appears broken, :meth:`_unwedge_session` is scheduled as a background task.

        Args:
            sender: The user ID that sent the event (used for logging and unwedging).
            sender_key: The sender device's identity key.
            message: The ciphertext addressed to this device.

        Returns:
            The decrypted plaintext.

        Raises:
            DecryptionError: If the message type is unsupported, a normal (non-prekey)
                message could not be decrypted by any stored session, or creating or
                using a new inbound session from a pre-key message failed.
            MatchingSessionDecryptionError: If a stored session matched a pre-key
                message but failed to decrypt it.
        """
        if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE):
            raise DecryptionError("unsupported olm message type")
        try:
            plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message)
        except MatchingSessionDecryptionError:
            self.log.warning(
                f"Found matching session yet decryption failed for sender {sender}"
                f" with key {sender_key}"
            )
            background_task.create(self._unwedge_session(sender, sender_key))
            raise
        # ``None`` is the only "no session could decrypt it" signal; an empty string is
        # a (degenerate) successful decryption and must not be retried as a new session.
        if plaintext is not None:
            return plaintext

        if message.type != OlmMsgType.PREKEY:
            # Only pre-key messages can establish a new session, so an undecryptable
            # normal message means our session state with the sender is out of sync.
            background_task.create(self._unwedge_session(sender, sender_key))
            raise DecryptionError("Decryption failed for normal message")

        self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
        try:
            session = await self._create_inbound_session(sender_key, message.body)
        except olm.OlmSessionError as e:
            background_task.create(self._unwedge_session(sender, sender_key))
            raise DecryptionError("Failed to create new session from prekey message") from e
        self.log.debug(
            f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
        )
        try:
            plaintext = session.decrypt(message)
        except olm.OlmSessionError as e:
            raise DecryptionError(
                "Failed to decrypt olm event with session created from prekey message"
            ) from e
        # Store the session again: decrypting presumably advanced its internal state
        # after _create_inbound_session first saved it.
        await self.crypto_store.update_session(sender_key, session)
        return plaintext

    async def _try_decrypt_olm_ciphertext(
        self, sender_key: IdentityKey, message: OlmCiphertext
    ) -> Optional[str]:
        """Try to decrypt ``message`` with the stored sessions for ``sender_key``.

        For pre-key messages, only sessions whose ``matches(message.body)`` is true are
        tried. The first session that decrypts successfully is saved through
        ``crypto_store.update_session``.

        Args:
            sender_key: The sender device's identity key.
            message: The ciphertext addressed to this device.

        Returns:
            The plaintext, or ``None`` if no stored session could decrypt the message.

        Raises:
            MatchingSessionDecryptionError: If a session matched a pre-key message but
                failed to decrypt it.
        """
        sessions = await self.crypto_store.get_sessions(sender_key)
        for session in sessions:
            if message.type == OlmMsgType.PREKEY and not session.matches(message.body):
                continue
            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                if message.type == OlmMsgType.PREKEY:
                    raise MatchingSessionDecryptionError(
                        "decryption failed with matching session"
                    ) from e
                # A normal message doesn't say which session it belongs to; try the next.
                continue
            await self.crypto_store.update_session(sender_key, session)
            return plaintext
        return None

    async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
        """Create and persist a new inbound session from a pre-key message body.

        The account is saved as well, since ``new_inbound_session`` presumably mutates
        it (e.g. consuming a one-time key) — TODO confirm against the account class.

        Args:
            sender_key: The sender device's identity key.
            ciphertext: The body of the pre-key message.

        Returns:
            The newly created session, already added to the crypto store.

        Raises:
            olm.OlmSessionError: Presumably raised by ``new_inbound_session`` when the
                message cannot establish a session (the caller catches this type).
        """
        session = self.account.new_inbound_session(sender_key, ciphertext)
        await self.crypto_store.put_account(self.account)
        await self.crypto_store.add_session(sender_key, session)
        return session

    async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
        """Recover from a broken Olm session with ``sender``/``sender_key``.

        Not implemented here; it is scheduled via ``background_task.create`` by
        :meth:`_decrypt_olm_ciphertext`, so a subclass or sibling class is expected to
        override it.

        Raises:
            NotImplementedError: Always, in this class.
        """
        raise NotImplementedError()
|