import functools
import os
import struct

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.hazmat.primitives.asymmetric import ec

# Largest record size that may be written into an aes128gcm header.
MAX_RECORD_SIZE = (1 << 31) - 1
MIN_RECORD_SIZE = 3
# Length, in octets, of the derived content-encryption key and of the salt.
KEY_LENGTH = 16
# Length, in octets, of the derived base nonce.
NONCE_LENGTH = 12
# Length, in octets, of the GCM authentication tag appended to each record.
TAG_LENGTH = 16

# Supported content-coding versions, newest first, down to the most obsolete.
# "pad" is the number of plaintext overhead octets per record: for aes128gcm
# the single trailing delimiter octet, for the legacy codings the width of
# the big-endian padding-length prefix.
versions = {
    "aes128gcm": {"pad": 1},
    "aesgcm": {"pad": 2},
    "aesgcm128": {"pad": 1},
}


class ECEException(Exception):
    """Exception for ECE encryption functions.

    :param message: human-readable description of the failure; it is also
        kept on the ``message`` attribute.
    """

    def __init__(self, message):
        # Forward to Exception so that str(exc) and exc.args carry the
        # message, and so that the exception survives a pickle round-trip
        # (unpickling re-invokes the class with exc.args).
        super(ECEException, self).__init__(message)
        self.message = message


def derive_key(
    mode, version, salt, key, private_key, dh, auth_secret, keyid, keylabel="P-256"
):
    """Derive the content-encryption key and the base nonce.

    :param mode: operational mode (encrypt or decrypt)
    :type mode: enumerate('encrypt', 'decrypt')
    :param version: Content Type identifier
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :param salt: encryption salt value; must be exactly KEY_LENGTH (16) octets
    :type salt: bytes
    :param key: raw key, used as the input secret when ``dh`` is None
    :type key: bytes
    :param private_key: local EC private key; required when ``dh`` is given
    :type private_key: object
    :param dh: remote Diffie Hellman public key, either an
        ``ec.EllipticCurvePublicKey`` or its X9.62 encoded point
    :type dh: bytes or ec.EllipticCurvePublicKey
    :param auth_secret: authorization secret, mixed into the secret via an
        extra HKDF step (ignored for aes128gcm without ``dh``)
    :type auth_secret: bytes
    :param keyid: key identifier label (accepted for call symmetry; not used
        by this function)
    :type keyid: str
    :param keylabel: label used in the aesgcm/aesgcm128 DH context
    :type keylabel: str
    :return: ``(key, nonce)`` of KEY_LENGTH and NONCE_LENGTH octets
    :rtype: tuple
    :raises ECEException: on an unknown version or mode, a bad salt, DH
        without a private key, or when no secret can be determined

    """
    # Key-derivation context; only a DH exchange fills it in.
    context = b""

    def build_info(base, info_context):
        # HKDF "info" value: "Content-Encoding: <base>\0<context>".
        return b"Content-Encoding: " + base + b"\0" + info_context

    def derive_dh(mode, version, private_key, dh, keylabel):
        """Run ECDH against ``dh`` and build the matching context.

        :return: ``(shared_secret, context)``
        """

        def length_prefix(key):
            # Prefix ``key`` with its length as a 16-bit big-endian integer.
            return struct.pack("!H", len(key)) + key

        if isinstance(dh, ec.EllipticCurvePublicKey):
            pubkey = dh
            dh = dh.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
        else:
            # ECDH requires both keys to share a curve, so decode the remote
            # point on the curve of the local private key (P-256 in the
            # usual Web Push case).
            pubkey = ec.EllipticCurvePublicKey.from_encoded_point(
                private_key.curve, dh
            )

        encoded = private_key.public_key().public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        # The context always lists the receiver's key before the sender's;
        # which side is "us" depends on the mode.
        if mode == "encrypt":
            sender_pub_key, receiver_pub_key = encoded, dh
        else:
            sender_pub_key, receiver_pub_key = dh, encoded

        if version == "aes128gcm":
            dh_context = b"WebPush: info\x00" + receiver_pub_key + sender_pub_key
        else:
            dh_context = (
                keylabel.encode("utf-8")
                + b"\0"
                + length_prefix(receiver_pub_key)
                + length_prefix(sender_pub_key)
            )

        return private_key.exchange(ec.ECDH(), pubkey), dh_context

    if version not in versions:
        raise ECEException("Invalid version")
    if mode not in ("encrypt", "decrypt"):
        # format() rather than "+" so a non-str mode still yields an
        # ECEException instead of a TypeError.
        raise ECEException("unknown 'mode' specified: {}".format(mode))
    if salt is None or len(salt) != KEY_LENGTH:
        raise ECEException("'salt' must be a 16 octet value")
    if dh is not None:
        if private_key is None:
            raise ECEException("DH requires a private_key")
        secret, context = derive_dh(
            mode=mode,
            version=version,
            private_key=private_key,
            dh=dh,
            keylabel=keylabel,
        )
    else:
        secret = key

    if secret is None:
        raise ECEException("unable to determine the secret")

    if version == "aesgcm":
        keyinfo = build_info(b"aesgcm", context)
        nonceinfo = build_info(b"nonce", context)
    elif version == "aesgcm128":
        keyinfo = b"Content-Encoding: aesgcm128"
        nonceinfo = b"Content-Encoding: nonce"
    else:
        # version was validated above, so this is aes128gcm.
        keyinfo = b"Content-Encoding: aes128gcm\x00"
        nonceinfo = b"Content-Encoding: nonce\x00"
        if dh is None:
            # Only mix the authentication secret when using DH for aes128gcm
            auth_secret = None

    if auth_secret is not None:
        # aes128gcm binds the auth step to the DH context; the legacy
        # codings use a fixed "auth" info string.
        if version == "aes128gcm":
            info = context
        else:
            info = build_info(b"auth", b"")
        hkdf_auth = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=auth_secret,
            info=info,
            backend=default_backend(),
        )
        secret = hkdf_auth.derive(secret)

    hkdf_key = HKDF(
        algorithm=hashes.SHA256(),
        length=KEY_LENGTH,
        salt=salt,
        info=keyinfo,
        backend=default_backend(),
    )
    hkdf_nonce = HKDF(
        algorithm=hashes.SHA256(),
        length=NONCE_LENGTH,
        salt=salt,
        info=nonceinfo,
        backend=default_backend(),
    )
    return hkdf_key.derive(secret), hkdf_nonce.derive(secret)


def iv(base, counter):
    """Generate an initialization vector for record number ``counter``.

    The first 4 octets of ``base`` are copied unchanged. The remaining 8
    octets are read as a big-endian unsigned 64-bit integer, XORed with
    ``counter``, and written back in the same format.

    :raises ECEException: if ``counter`` is negative or does not fit in
        64 bits
    """
    # Any non-zero bits above bit 63 (or a negative value, whose shift
    # stays -1) means the counter cannot be represented.
    if counter >> 64:
        raise ECEException("Counter too big")
    head, tail = base[:4], base[4:]
    (tail_value,) = struct.unpack("!Q", tail)
    return head + struct.pack("!Q", tail_value ^ counter)


def decrypt(
    content,
    salt=None,
    key=None,
    private_key=None,
    dh=None,
    auth_secret=None,
    keyid=None,
    keylabel="P-256",
    rs=4096,
    version="aes128gcm",
):
    """
    Decrypt a data block

    :param content: Data to be decrypted; for aes128gcm this starts with the
        header (salt, record size, keyid) followed by the records
    :type content: bytes
    :param salt: Encryption salt (taken from the header for aes128gcm)
    :type salt: bytes
    :param key: raw key material, used when no DH exchange is performed
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info
    :type keyid: str
    :param dh: Remote Diffie Hellman sequence (omit for aes128gcm; when a
        private_key is given it defaults to the header keyid)
    :type dh: bytes
    :param keylabel: label used in the aesgcm/aesgcm128 DH context
    :type keylabel: str
    :param rs: Record size (taken from the header for aes128gcm)
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Decrypted message content
    :rtype: bytes
    :raises ECEException: on an unknown version, an unparsable header, a
        too-small record size, a truncated message, bad padding or
        delimiters, or an authentication failure

    """

    def parse_content_header(content):
        """Parse an aes128gcm content body and extract the header values.

        Layout: salt (16) | rs (uint32, big-endian) | idlen (uint8) |
        keyid (idlen) | records.

        :param content: The encrypted body of the message
        :type content: bytes

        """
        id_len = struct.unpack("!B", content[20:21])[0]
        return {
            "salt": content[:16],
            "rs": struct.unpack("!L", content[16:20])[0],
            "keyid": content[21 : 21 + id_len],
            "content": content[21 + id_len :],
        }

    def decrypt_record(key, nonce, counter, record):
        # A record shorter than its tag can only come from truncation.
        # Without this check the short tag would reach modes.GCM, which
        # rejects it with ValueError rather than InvalidTag, so callers
        # would not get an ECEException.
        if len(record) < TAG_LENGTH:
            raise ECEException("Message truncated")
        decryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter), tag=record[-TAG_LENGTH:]),
            backend=default_backend(),
        ).decryptor()
        return decryptor.update(record[:-TAG_LENGTH]) + decryptor.finalize()

    def unpad_legacy(data):
        # Legacy records start with a big-endian padding length of
        # pad_size octets, followed by that many zero octets.
        pad_size = versions[version]["pad"]
        if len(data) < pad_size:
            raise ECEException("Bad padding")
        pad = functools.reduce(
            lambda x, y: x << 8 | y,
            struct.unpack("!" + ("B" * pad_size), data[0:pad_size]),
        )
        if pad_size + pad > len(data) or data[pad_size : pad_size + pad] != (
            b"\x00" * pad
        ):
            raise ECEException("Bad padding")
        return data[pad_size + pad :]

    def unpad(data, last):
        # Trailing zero octets are padding; the last non-zero octet is the
        # delimiter: 0x01 for intermediate records, 0x02 for the final one.
        body = data.rstrip(b"\x00")
        if not body:
            raise ECEException("all zero record plaintext")
        delimiter = body[-1:]
        if not last and delimiter != b"\x01":
            raise ECEException("record delimiter != 1")
        if last and delimiter != b"\x02":
            raise ECEException("last record delimiter != 2")
        return body[:-1]

    if version not in versions:
        raise ECEException("Invalid version")

    overhead = versions[version]["pad"]
    if version == "aes128gcm":
        try:
            content_header = parse_content_header(content)
        except Exception:
            raise ECEException("Could not parse the content header")
        salt = content_header["salt"]
        rs = content_header["rs"]
        keyid = content_header["keyid"]
        # With a private key and no explicit dh, the header keyid carries
        # the sender's public key.
        if private_key is not None and not dh:
            dh = keyid
        else:
            keyid = keyid.decode("utf-8")
        content = content_header["content"]
        overhead += TAG_LENGTH

    (key_, nonce_) = derive_key(
        "decrypt",
        version=version,
        salt=salt,
        key=key,
        private_key=private_key,
        dh=dh,
        auth_secret=auth_secret,
        keyid=keyid,
        keylabel=keylabel,
    )
    if rs <= overhead:
        raise ECEException("Record size too small")
    chunk = rs
    if version != "aes128gcm":
        chunk += TAG_LENGTH  # account for tags in old versions
        # The legacy encoder always ends with a short (possibly
        # padding-only) record, so an exact multiple means records are
        # missing.
        if len(content) % chunk == 0:
            raise ECEException("Message truncated")

    plaintext = []
    try:
        for counter, start in enumerate(range(0, len(content), chunk)):
            data = decrypt_record(key_, nonce_, counter, content[start : start + chunk])
            if version == "aes128gcm":
                last = (start + chunk) >= len(content)
                plaintext.append(unpad(data, last))
            else:
                plaintext.append(unpad_legacy(data))
    except InvalidTag as ex:
        raise ECEException("Decryption error: {}".format(repr(ex)))
    return b"".join(plaintext)


def encrypt(
    content,
    salt=None,
    key=None,
    private_key=None,
    dh=None,
    auth_secret=None,
    keyid=None,
    keylabel="P-256",
    rs=4096,
    version="aes128gcm",
):
    """
    Encrypt a data block

    :param content: block of data to encrypt
    :type content: bytes
    :param salt: Encryption salt (a random 16 octet salt is generated when
        omitted)
    :type salt: bytes
    :param key: Encryption key data
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info; for
        aes128gcm it is written to the header (str is UTF-8 encoded, bytes
        are used as-is)
    :type keyid: str or bytes
    :param dh: Remote Diffie Hellman sequence
    :type dh: bytes
    :param keylabel: label used in the aesgcm/aesgcm128 DH context
    :type keylabel: str
    :param rs: Record size
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Encrypted message content (with the header for aes128gcm)
    :rtype: bytes
    :raises ECEException: on an unknown version, a too-small or too-large
        record size, or a keyid longer than 255 octets

    """

    def encrypt_record(key, nonce, counter, buf, last):
        encryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter)),
            backend=default_backend(),
        ).encryptor()

        if version == "aes128gcm":
            # Append the delimiter: 0x02 marks the final record, 0x01 the rest.
            data = encryptor.update(buf + (b"\x02" if last else b"\x01"))
        else:
            # Legacy: a zero padding length of pad_size octets, then the data.
            data = encryptor.update((b"\x00" * versions[version]["pad"]) + buf)
        data += encryptor.finalize()
        data += encryptor.tag
        return data

    def compose_aes128gcm(salt, content, rs, keyid):
        """Compose the header and content of an aes128gcm encrypted
        message body

        :param salt: The sender's salt value
        :type salt: bytes
        :param content: The encrypted body of the message
        :type content: bytes
        :param rs: Override for the content length
        :type rs: int
        :param keyid: The keyid to use for this message
        :type keyid: bytes

        """
        if len(keyid) > 255:
            raise ECEException("keyid is too long")
        header = salt
        if rs > MAX_RECORD_SIZE:
            raise ECEException("Too much content")
        header += struct.pack("!L", rs)
        header += struct.pack("!B", len(keyid))
        header += keyid
        return header + content

    if version not in versions:
        raise ECEException("Invalid version")

    if salt is None:
        salt = os.urandom(KEY_LENGTH)

    (key_, nonce_) = derive_key(
        "encrypt",
        version=version,
        salt=salt,
        key=key,
        private_key=private_key,
        dh=dh,
        auth_secret=auth_secret,
        keyid=keyid,
        keylabel=keylabel,
    )

    overhead = versions[version]["pad"]
    if version == "aes128gcm":
        overhead += TAG_LENGTH
        # The delimiter lives inside each record, so no extra record is
        # needed for exact multiples of the chunk size. An empty payload
        # must still produce one (final, delimiter-only) record, otherwise
        # the output contains no record marked as last.
        end = max(len(content), 1)
    else:
        # The extra one ensures that we produce a padding-only record if
        # the data length is an exact multiple of the chunk size.
        end = len(content) + 1
    if rs <= overhead:
        raise ECEException("Record size too small")
    chunk_size = rs - overhead

    records = []
    for counter, start in enumerate(range(0, end, chunk_size)):
        records.append(
            encrypt_record(
                key_,
                nonce_,
                counter,
                content[start : start + chunk_size],
                (start + chunk_size) >= end,
            )
        )
    result = b"".join(records)

    if version == "aes128gcm":
        if keyid is None and private_key is not None:
            # Default keyid: our own public key, so the receiver can run DH.
            kid = private_key.public_key().public_bytes(
                Encoding.X962, PublicFormat.UncompressedPoint
            )
        elif isinstance(keyid, bytes):
            kid = keyid
        else:
            kid = (keyid or "").encode("utf-8")
        return compose_aes128gcm(salt, result, rs, keyid=kid)
    return result
