File: attachments.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (166 lines) | stat: -rw-r--r-- 5,363 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright 2018 Zil0 (under the Apache 2.0 license)
# Copyright © 2019 Damir Jelić <poljar@termina.org.uk> (under the Apache 2.0 license)
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

from __future__ import annotations

from typing import Generator, Iterable
import binascii
import struct

import unpaddedbase64

from mautrix.errors import DecryptionError
from mautrix.types import EncryptedFile, JSONWebKey

try:
    from Crypto import Random
    from Crypto.Cipher import AES
    from Crypto.Hash import SHA256
    from Crypto.Util import Counter
except ImportError:
    from Cryptodome import Random
    from Cryptodome.Cipher import AES
    from Cryptodome.Hash import SHA256
    from Cryptodome.Util import Counter


def decrypt_attachment(
    ciphertext: bytes | bytearray | memoryview, key: str, hash: str, iv: str, inplace: bool = False
) -> bytes | bytearray | memoryview:
    """Decrypt an encrypted attachment.

    The SHA-256 hash of the ciphertext is verified before any decryption is attempted.

    Args:
        ciphertext: The data to decrypt.
        key: Unpadded base64 encoded AES key (the ``key`` field of the attachment's
             :class:`JSONWebKey`).
        hash: Unpadded base64 encoded SHA-256 hash of the ciphertext.
        iv: Unpadded base64 encoded 16 byte AES-CTR IV. The first 8 bytes are the counter
            block prefix, the last 8 bytes are the big-endian initial counter value.
        inplace: Should the decryption be performed in-place?
                 The input must be a bytearray or writable memoryview to use this.
    Returns:
        The plaintext. With ``inplace=True`` this is the ``ciphertext`` object itself,
        which now holds the decrypted data.
    Raises:
        DecryptionError: if the hash, key or IV can't be decoded, the integrity check
            fails, the AES cipher can't be created, or the CTR counter overflows.
    """
    # binascii.Error (bad base64) and UnicodeEncodeError (non-ASCII input) are both
    # ValueError subclasses, so catching ValueError covers every decoding failure.
    try:
        expected_hash = unpaddedbase64.decode_base64(hash)
    except (ValueError, TypeError) as e:
        raise DecryptionError("Error decoding hash") from e

    h = SHA256.new()
    h.update(ciphertext)

    if h.digest() != expected_hash:
        raise DecryptionError("Mismatched SHA-256 digest")

    try:
        byte_key: bytes = unpaddedbase64.decode_base64(key)
    except (ValueError, TypeError) as e:
        raise DecryptionError("Error decoding key") from e

    try:
        byte_iv: bytes = unpaddedbase64.decode_base64(iv)
    except (ValueError, TypeError) as e:
        raise DecryptionError("Error decoding IV") from e
    if len(byte_iv) != 16:
        raise DecryptionError("Invalid IV length")
    prefix = byte_iv[:8]
    try:
        # A non-zero IV counter is not spec-compliant, but some clients still do it,
        # so decode the counter part too.
        initial_value = struct.unpack(">Q", byte_iv[8:])[0]
    except struct.error as e:
        raise DecryptionError("Error decoding IV") from e

    ctr = Counter.new(64, prefix=prefix, initial_value=initial_value)

    try:
        cipher = AES.new(byte_key, AES.MODE_CTR, counter=ctr)
    except ValueError as e:
        raise DecryptionError("Failed to create AES cipher") from e

    try:
        if inplace:
            cipher.decrypt(ciphertext, ciphertext)
            return ciphertext
        return cipher.decrypt(ciphertext)
    except OverflowError as e:
        # PyCryptodome refuses to let the 64-bit counter wrap around; that can only
        # happen here when a sender used a non-zero initial counter value.
        raise DecryptionError("AES-CTR counter overflowed") from e


def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]:
    """Encrypt a whole attachment in memory.

    Args:
        plaintext: The data to encrypt.

    Returns:
        A ``(ciphertext, info)`` tuple. ``ciphertext`` is every encrypted chunk joined
        together. ``info`` is the :class:`EncryptedFile` needed to decrypt it, i.e. the
        final item produced by ``encrypted_attachment_generator()``.
    """
    # The generator yields the encrypted chunks first and the decryption info last.
    *encrypted_parts, decryption_info = encrypted_attachment_generator(plaintext)
    ciphertext = b"".join(encrypted_parts)
    return ciphertext, decryption_info


def _prepare_encryption() -> tuple[bytes, bytes, AES, SHA256.SHA256Hash]:
    """Generate fresh key material and set up the cipher and hash state for encryption.

    Returns:
        A ``(key, iv, cipher, sha256)`` tuple:

        - ``key``: 32 random bytes used as the AES key.
        - ``iv``: 8 random bytes used as the counter block prefix.
        - ``cipher``: AES in CTR mode with a 64-bit counter following ``iv``, starting at 0.
        - ``sha256``: a fresh SHA-256 hash object for hashing the ciphertext.
    """
    rng = Random.new()
    aes_key = rng.read(32)
    # Only the first half of the 16-byte counter block is random; the second half
    # is the 64-bit block counter.
    nonce_prefix = rng.read(8)
    block_counter = Counter.new(64, prefix=nonce_prefix, initial_value=0)

    return (
        aes_key,
        nonce_prefix,
        AES.new(aes_key, AES.MODE_CTR, counter=block_counter),
        SHA256.new(),
    )


def inplace_encrypt_attachment(data: bytearray | memoryview) -> EncryptedFile:
    """Encrypt an attachment in place, overwriting the plaintext with the ciphertext.

    Args:
        data: A writable buffer holding the plaintext. It holds the ciphertext on return.

    Returns:
        The :class:`EncryptedFile` info needed to decrypt the data.
    """
    aes_key, nonce_prefix, ctr_cipher, digest = _prepare_encryption()

    # Encrypt into the same buffer, then hash the resulting ciphertext.
    ctr_cipher.encrypt(plaintext=data, output=data)
    digest.update(data)

    return _get_decryption_info(aes_key, nonce_prefix, digest)


def encrypted_attachment_generator(
    data: bytes | bytearray | memoryview | Iterable[bytes],
) -> Generator[bytes | EncryptedFile, None, None]:
    """Generator to encrypt data in order to send it as an encrypted
    attachment.

    Unlike ``encrypt_attachment()``, this function lazily encrypts and yields
    data, thus it can be used to encrypt large files without fully loading them
    into memory if an iterable of bytes is passed as data.

    Args:
        data: The data to encrypt, either a single bytes-like object
              (``bytes``, ``bytearray`` or ``memoryview``) or an iterable of
              ``bytes`` chunks.

    Yields:
        The encrypted bytes for each chunk of data.
        The last yielded value will be an :class:`EncryptedFile` containing the
        info needed to decrypt the data.
    """

    key, iv, cipher, sha256 = _prepare_encryption()

    # Iterating a bytes-like object directly would yield individual ints rather
    # than byte chunks, so treat any single bytes-like object as one chunk.
    if isinstance(data, (bytes, bytearray, memoryview)):
        data = (data,)

    for chunk in data:
        # Encrypting and hashing are CPU-bound; async callers may want to drive
        # this generator from an executor.
        encrypted_chunk = cipher.encrypt(chunk)
        sha256.update(encrypted_chunk)
        yield encrypted_chunk

    yield _get_decryption_info(key, iv, sha256)


def _get_decryption_info(key: bytes, iv: bytes, sha256: SHA256.SHA256Hash) -> EncryptedFile:
    """Build the ``v2`` :class:`EncryptedFile` metadata for an encrypted attachment.

    Args:
        key: The raw AES key. It is published URL-safe base64 encoded inside the JWK.
        iv: The 8-byte counter block prefix. It is padded with an all-zero 8-byte
            counter to form the 16-byte IV.
        sha256: Hash object that has been fed the complete ciphertext.

    Returns:
        The :class:`EncryptedFile` describing how to decrypt the attachment.
    """
    full_iv = iv + bytes(8)
    jwk = JSONWebKey(
        key_type="oct",
        algorithm="A256CTR",
        extractable=True,
        key_ops=["encrypt", "decrypt"],
        key=unpaddedbase64.encode_base64(key, urlsafe=True),
    )
    ciphertext_hashes = {"sha256": unpaddedbase64.encode_base64(sha256.digest())}
    return EncryptedFile(
        version="v2",
        iv=unpaddedbase64.encode_base64(full_iv),
        hashes=ciphertext_hashes,
        key=jwk,
    )