File: decrypt_olm.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (134 lines) | stat: -rw-r--r-- 5,589 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional
import asyncio

import olm

from mautrix.errors import DecryptionError, MatchingSessionDecryptionError
from mautrix.types import (
    DecryptedOlmEvent,
    EncryptedOlmEventContent,
    EncryptionAlgorithm,
    IdentityKey,
    OlmCiphertext,
    OlmMsgType,
    ToDeviceEvent,
    UserID,
)
from mautrix.util import background_task

from .base import BaseOlmMachine
from .sessions import Session


class OlmDecryptionMachine(BaseOlmMachine):
    """Decryption of Olm-encrypted to-device events.

    Builds on :class:`BaseOlmMachine` and uses ``self.account``, ``self.crypto_store``,
    ``self.client`` and ``self.log`` (presumably provided by the base class — confirm).
    Repairing broken sessions is delegated to :meth:`_unwedge_session`, which this class
    does not implement.
    """

    async def _decrypt_olm_event(self, evt: ToDeviceEvent) -> DecryptedOlmEvent:
        """Decrypt an Olm-encrypted to-device event addressed to this device.

        Args:
            evt: The encrypted to-device event as received.

        Returns:
            The parsed inner event, with ``sender_key`` set to the outer event's sender key
            and ``source`` set to ``evt``.

        Raises:
            DecryptionError: If the content class or algorithm is unsupported, the event has
                no ciphertext for this device's identity key, or the decrypted payload's
                sender, recipient or recipient ed25519 key don't match this event/device.
                Errors from :meth:`_decrypt_olm_ciphertext` propagate unchanged.
            Exception: Any error raised by ``DecryptedOlmEvent.parse_json`` is re-raised
                as-is after the plaintext is trace-logged.
        """
        if not isinstance(evt.content, EncryptedOlmEventContent):
            raise DecryptionError("unsupported event content class")
        elif evt.content.algorithm != EncryptionAlgorithm.OLM_V1:
            raise DecryptionError("unsupported event encryption algorithm")
        # The ciphertext map is keyed by recipient identity key; pick this device's entry.
        try:
            own_content = evt.content.ciphertext[self.account.identity_key]
        except KeyError:
            # An expected miss, not an internal failure: don't chain the KeyError.
            raise DecryptionError(
                "olm event doesn't contain ciphertext for this device"
            ) from None

        self.log.debug(
            f"Decrypting to-device olm event from {evt.sender}/{evt.content.sender_key}"
        )
        plaintext = await self._decrypt_olm_ciphertext(
            evt.sender, evt.content.sender_key, own_content
        )

        try:
            decrypted_evt: DecryptedOlmEvent = DecryptedOlmEvent.parse_json(plaintext)
        except Exception:
            self.log.trace("Failed to parse olm event plaintext: %s", plaintext)
            raise
        # Bind the inner (encrypted) payload to the outer event and to this device: its
        # claimed sender, recipient and recipient signing key must all agree.
        if decrypted_evt.sender != evt.sender:
            raise DecryptionError("mismatched sender in olm payload")
        elif decrypted_evt.recipient != self.client.mxid:
            raise DecryptionError("mismatched recipient in olm payload")
        elif decrypted_evt.recipient_keys.ed25519 != self.account.signing_key:
            raise DecryptionError("mismatched recipient key in olm payload")
        decrypted_evt.sender_key = evt.content.sender_key
        decrypted_evt.source = evt
        self.log.debug(
            f"Successfully decrypted olm event from {evt.sender}/{decrypted_evt.sender_device} "
            f"(sender key: {decrypted_evt.sender_key}) into a {decrypted_evt.type}"
        )
        return decrypted_evt

    async def _decrypt_olm_ciphertext(
        self, sender: UserID, sender_key: IdentityKey, message: OlmCiphertext
    ) -> str:
        """Decrypt a single Olm ciphertext from ``sender``/``sender_key``.

        Stored sessions with ``sender_key`` are tried first. If none of them decrypts the
        message and it is a prekey message, a new inbound session is created from it and
        used to decrypt. On the failure paths commented below, :meth:`_unwedge_session` is
        scheduled as a background task before the error is raised.

        Args:
            sender: The Matrix user who sent the message.
            sender_key: The sender device's identity key.
            message: The ciphertext addressed to this device.

        Returns:
            The decrypted plaintext.

        Raises:
            DecryptionError: If the message type is unsupported, a normal (non-prekey)
                message can't be decrypted by any stored session, or creating/using a new
                inbound session fails.
            MatchingSessionDecryptionError: If a stored session matched a prekey message but
                failed to decrypt it.
        """
        if message.type not in (OlmMsgType.PREKEY, OlmMsgType.MESSAGE):
            raise DecryptionError("unsupported olm message type")

        try:
            plaintext = await self._try_decrypt_olm_ciphertext(sender_key, message)
        except MatchingSessionDecryptionError:
            self.log.warning(
                f"Found matching session yet decryption failed for sender {sender}"
                f" with key {sender_key}"
            )
            # Unwedge path: a matching session exists but can't decrypt.
            background_task.create(self._unwedge_session(sender, sender_key))
            raise

        # ``None`` means no stored session decrypted the message. Compare against None
        # explicitly so a successfully decrypted empty plaintext isn't treated as a failure.
        if plaintext is None:
            if message.type != OlmMsgType.PREKEY:
                # Unwedge path: only prekey messages can start a new session, so a normal
                # message that no stored session can decrypt is unrecoverable here.
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Decryption failed for normal message")

            self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
            try:
                session = await self._create_inbound_session(sender_key, message.body)
            except olm.OlmSessionError as e:
                # Unwedge path: the prekey message couldn't establish a new session.
                background_task.create(self._unwedge_session(sender, sender_key))
                raise DecryptionError("Failed to create new session from prekey message") from e
            self.log.debug(
                f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"
            )

            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                raise DecryptionError(
                    "Failed to decrypt olm event with session created from prekey message"
                ) from e

            # Save the session again after decrypting with it, as is done for stored
            # sessions in _try_decrypt_olm_ciphertext.
            await self.crypto_store.update_session(sender_key, session)

        return plaintext

    async def _try_decrypt_olm_ciphertext(
        self, sender_key: IdentityKey, message: OlmCiphertext
    ) -> Optional[str]:
        """Try to decrypt ``message`` with the stored sessions for ``sender_key``.

        For prekey messages, only sessions for which ``session.matches(message.body)`` is
        true are tried. For normal messages, every session is tried in turn and decryption
        failures move on to the next one.

        Returns:
            The plaintext from the first session that decrypts the message (that session is
            then written back to the store), or ``None`` if no session succeeded.

        Raises:
            MatchingSessionDecryptionError: If a session matched a prekey message but
                failed to decrypt it.
        """
        sessions = await self.crypto_store.get_sessions(sender_key)
        for session in sessions:
            if message.type == OlmMsgType.PREKEY and not session.matches(message.body):
                continue

            try:
                plaintext = session.decrypt(message)
            except olm.OlmSessionError as e:
                # A matched prekey session that can't decrypt is an error; for normal
                # messages, just try the next session.
                if message.type == OlmMsgType.PREKEY:
                    raise MatchingSessionDecryptionError(
                        "decryption failed with matching session"
                    ) from e
            else:
                # Persist the session after use (presumably because decrypting advances
                # its ratchet state — confirm against the Olm bindings).
                await self.crypto_store.update_session(sender_key, session)
                return plaintext
        return None

    async def _create_inbound_session(self, sender_key: IdentityKey, ciphertext: str) -> Session:
        """Create an inbound session with ``sender_key`` from a prekey message body.

        Both the account and the new session are saved to the crypto store.

        Raises:
            olm.OlmSessionError: Presumably raised by ``account.new_inbound_session`` when
                the prekey message is invalid; callers catch this type.
        """
        session = self.account.new_inbound_session(sender_key, ciphertext)
        # The account is saved too, presumably because new_inbound_session mutates it
        # (e.g. consuming a one-time key) — confirm against the account implementation.
        await self.crypto_store.put_account(self.account)
        await self.crypto_store.add_session(sender_key, session)
        return session

    async def _unwedge_session(self, sender: UserID, sender_key: IdentityKey) -> None:
        """Hook for repairing a broken Olm session with ``sender``/``sender_key``.

        Not implemented here; presumably overridden elsewhere in the machine's class
        hierarchy (verify). It is scheduled via ``background_task.create`` from
        :meth:`_decrypt_olm_ciphertext`.
        """
        raise NotImplementedError()