# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

import asyncio
import logging
import sys

from mautrix import __optional_imports__
from mautrix.appservice import AppService
from mautrix.client import Client, InternalEventType, SyncStore
from mautrix.crypto import CryptoStore, OlmMachine, PgCryptoStore, RejectKeyShare, StateStore
from mautrix.errors import EncryptionError, MForbidden, MNotFound, SessionNotFound
from mautrix.types import (
    JSON,
    DeviceIdentity,
    EncryptedEvent,
    EncryptedMegolmEventContent,
    EventFilter,
    EventType,
    Filter,
    LoginType,
    MessageEvent,
    RequestedKeyInfo,
    RoomEventFilter,
    RoomFilter,
    RoomID,
    RoomKeyWithheldCode,
    Serializable,
    StateEvent,
    StateFilter,
    TrustState,
)
from mautrix.util import background_task
from mautrix.util.async_db import Database
from mautrix.util.logging import TraceLogger

from .. import bridge as br
from .crypto_state_store import PgCryptoStateStore


class EncryptionManager:
    """
    Set up and drive end-to-bridge encryption for a bridge.

    The manager creates its own :class:`Client` for the bridge bot's user ID,
    an :class:`OlmMachine` on top of it, and database-backed crypto and state
    stores. :meth:`start` logs the client in with appservice login and loads the
    Olm machine; :meth:`encrypt` and :meth:`decrypt` wrap the machine's megolm
    operations with retry logic.
    """

    loop: asyncio.AbstractEventLoop
    # Class-level logger shared by all instances.
    log: TraceLogger = logging.getLogger("mau.bridge.e2ee")

    # Matrix client created with the bridge bot's mxid; used for login, key queries
    # and, outside appservice mode, for syncing.
    client: Client
    crypto: OlmMachine
    # Store handed to the Olm machine; the same object is passed to the client as its sync store.
    crypto_store: CryptoStore | SyncStore
    # __init__ always creates this; start() and stop() still guard against a falsy value.
    crypto_db: Database | None
    state_store: StateStore

    # Parsed from bridge.encryption.verification_levels["send"]; not read elsewhere in this class.
    min_send_trust: TrustState
    # Gate checked first by allow_key_share().
    key_sharing_enabled: bool
    # When true, the Olm machine's OTK-count, device-list and to-device handlers are attached
    # to the AppService and start() does not start the client sync loop.
    appservice_mode: bool
    # Both flags come from bridge.encryption.delete_keys and default to False when that
    # config section is falsy.
    periodically_delete_expired_keys: bool
    delete_outdated_inbound: bool

    bridge: br.Bridge
    az: AppService
    # Prefix/suffix pair used to recognise user IDs that should be ignored (see _ignore_user)
    # and passed to the member filter when sharing group sessions.
    _id_prefix: str
    _id_suffix: str

    # One asyncio.Event per room that currently has a group session share in progress.
    _share_session_events: dict[RoomID, asyncio.Event]
    # Task running _periodically_delete_keys(), or None when it isn't running.
    _key_delete_task: asyncio.Task | None

    def __init__(
        self,
        bridge: br.Bridge,
        homeserver_address: str,
        user_id_prefix: str,
        user_id_suffix: str,
        db_url: str,
    ) -> None:
        self.loop = bridge.loop or asyncio.get_event_loop()
        self.bridge = bridge
        self.az = bridge.az
        self.device_name = bridge.name
        self._id_prefix = user_id_prefix
        self._id_suffix = user_id_suffix
        self._share_session_events = {}
        pickle_key = "mautrix.bridge.e2ee"
        self.crypto_db = Database.create(
            url=db_url,
            upgrade_table=PgCryptoStore.upgrade_table,
            log=logging.getLogger("mau.crypto.db"),
        )
        self.crypto_store = PgCryptoStore("", pickle_key, self.crypto_db)
        self.state_store = PgCryptoStateStore(self.crypto_db, bridge.get_portal)
        default_http_retry_count = bridge.config.get("homeserver.http_retry_count", None)
        self.client = Client(
            base_url=homeserver_address,
            mxid=self.az.bot_mxid,
            loop=self.loop,
            sync_store=self.crypto_store,
            log=self.log.getChild("client"),
            default_retry_count=default_http_retry_count,
            state_store=self.bridge.state_store,
        )
        self.crypto = OlmMachine(self.client, self.crypto_store, self.state_store)
        self.client.add_event_handler(InternalEventType.SYNC_STOPPED, self._exit_on_sync_fail)
        self.crypto.allow_key_share = self.allow_key_share
        verification_levels = bridge.config["bridge.encryption.verification_levels"]
        self.min_send_trust = TrustState.parse(verification_levels["send"])
        self.crypto.share_keys_min_trust = TrustState.parse(verification_levels["share"])
        self.crypto.send_keys_min_trust = TrustState.parse(verification_levels["receive"])
        self.key_sharing_enabled = bridge.config["bridge.encryption.allow_key_sharing"]
        self.appservice_mode = bridge.config["bridge.encryption.appservice"]
        if self.appservice_mode:
            self.az.otk_handler = self.crypto.handle_as_otk_counts
            self.az.device_list_handler = self.crypto.handle_as_device_lists
            self.az.to_device_handler = self.crypto.handle_as_to_device_event

        self.periodically_delete_expired_keys = False
        self.delete_outdated_inbound = False
        self._key_delete_task = None
        del_cfg = bridge.config["bridge.encryption.delete_keys"]
        if del_cfg:
            self.crypto.delete_outbound_keys_on_ack = del_cfg["delete_outbound_on_ack"]
            self.crypto.dont_store_outbound_keys = del_cfg["dont_store_outbound"]
            self.crypto.delete_previous_keys_on_receive = del_cfg["delete_prev_on_new_session"]
            self.crypto.ratchet_keys_on_decrypt = del_cfg["ratchet_on_decrypt"]
            self.crypto.delete_fully_used_keys_on_decrypt = del_cfg["delete_fully_used_on_decrypt"]
            self.crypto.delete_keys_on_device_delete = del_cfg["delete_on_device_delete"]
            self.periodically_delete_expired_keys = del_cfg["periodically_delete_expired"]
            self.delete_outdated_inbound = del_cfg["delete_outdated_inbound"]
        self.crypto.disable_device_change_key_rotation = bridge.config[
            "bridge.encryption.rotation.disable_device_change_key_rotation"
        ]

    async def _exit_on_sync_fail(self, data) -> None:
        if data["error"]:
            self.log.critical("Exiting due to crypto sync error")
            sys.exit(32)

    async def allow_key_share(self, device: DeviceIdentity, request: RequestedKeyInfo) -> bool:
        if not self.key_sharing_enabled:
            self.log.debug(
                f"Key sharing not enabled, ignoring key request from "
                f"{device.user_id}/{device.device_id}"
            )
            return False
        elif device.trust == TrustState.BLACKLISTED:
            raise RejectKeyShare(
                f"Rejecting key request from blacklisted device "
                f"{device.user_id}/{device.device_id}",
                code=RoomKeyWithheldCode.BLACKLISTED,
                reason="Your device has been blacklisted by the bridge",
            )
        elif await self.crypto.resolve_trust(device) >= self.crypto.share_keys_min_trust:
            portal = await self.bridge.get_portal(request.room_id)
            if portal is None:
                raise RejectKeyShare(
                    f"Rejecting key request for {request.session_id} from "
                    f"{device.user_id}/{device.device_id}: room is not a portal",
                    code=RoomKeyWithheldCode.UNAVAILABLE,
                    reason="Requested room is not a portal",
                )
            user = await self.bridge.get_user(device.user_id)
            if not await user.is_in_portal(portal):
                raise RejectKeyShare(
                    f"Rejecting key request for {request.session_id} from "
                    f"{device.user_id}/{device.device_id}: user is not in portal",
                    code=RoomKeyWithheldCode.UNAUTHORIZED,
                    reason="You're not in that portal",
                )
            self.log.debug(
                f"Accepting key request for {request.session_id} from "
                f"{device.user_id}/{device.device_id}"
            )
            return True
        else:
            raise RejectKeyShare(
                f"Rejecting key request from unverified device "
                f"{device.user_id}/{device.device_id}",
                code=RoomKeyWithheldCode.UNVERIFIED,
                reason="Your device is not trusted by the bridge",
            )

    def _ignore_user(self, user_id: str) -> bool:
        return (
            user_id.startswith(self._id_prefix)
            and user_id.endswith(self._id_suffix)
            and user_id != self.az.bot_mxid
        )

    async def handle_member_event(self, evt: StateEvent) -> None:
        if self._ignore_user(evt.state_key):
            # We don't want to invalidate group sessions because a ghost left or joined
            return
        await self.crypto.handle_member_event(evt)

    async def _share_session_lock(self, room_id: RoomID) -> bool:
        try:
            event = self._share_session_events[room_id]
        except KeyError:
            self._share_session_events[room_id] = asyncio.Event()
            return True
        else:
            await event.wait()
            return False

    async def encrypt(
        self, room_id: RoomID, event_type: EventType, content: Serializable | JSON
    ) -> tuple[EventType, EncryptedMegolmEventContent]:
        try:
            encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
        except EncryptionError:
            self.log.debug("Got EncryptionError, sharing group session and trying again")
            if await self._share_session_lock(room_id):
                try:
                    users = await self.az.state_store.get_members_filtered(
                        room_id, self._id_prefix, self._id_suffix, self.az.bot_mxid
                    )
                    await self.crypto.share_group_session(room_id, users)
                finally:
                    self._share_session_events.pop(room_id).set()
            encrypted = await self.crypto.encrypt_megolm_event(room_id, event_type, content)
        return EventType.ROOM_ENCRYPTED, encrypted

    async def decrypt(self, evt: EncryptedEvent, wait_session_timeout: int = 5) -> MessageEvent:
        try:
            decrypted = await self.crypto.decrypt_megolm_event(evt)
        except SessionNotFound as e:
            if not wait_session_timeout:
                raise
            self.log.debug(
                f"Couldn't find session {e.session_id} trying to decrypt {evt.event_id},"
                f" waiting {wait_session_timeout} seconds..."
            )
            got_keys = await self.crypto.wait_for_session(
                evt.room_id, e.session_id, timeout=wait_session_timeout
            )
            if got_keys:
                self.log.debug(
                    f"Got session {e.session_id} after waiting, "
                    f"trying to decrypt {evt.event_id} again"
                )
                decrypted = await self.crypto.decrypt_megolm_event(evt)
            else:
                raise
        self.log.trace("Decrypted event %s: %s", evt.event_id, decrypted)
        return decrypted

    async def start(self) -> None:
        flows = await self.client.get_login_flows()
        if not flows.supports_type(LoginType.APPSERVICE):
            self.log.critical(
                "Encryption enabled in config, but homeserver does not support appservice login"
            )
            sys.exit(30)
        self.log.debug("Logging in with bridge bot user")
        if self.crypto_db:
            try:
                await self.crypto_db.start()
            except Exception as e:
                self.bridge._log_db_error(e)
        await self.crypto_store.open()
        device_id = await self.crypto_store.get_device_id()
        if device_id:
            self.log.debug(f"Found device ID in database: {device_id}")
        # We set the API token to the AS token here to authenticate the appservice login
        # It'll get overridden after the login
        self.client.api.token = self.az.as_token
        await self.client.login(
            login_type=LoginType.APPSERVICE,
            device_name=self.device_name,
            device_id=device_id,
            store_access_token=True,
            update_hs_url=False,
        )
        await self.crypto.load()
        if not device_id:
            await self.crypto_store.put_device_id(self.client.device_id)
            self.log.debug(f"Logged in with new device ID {self.client.device_id}")
        elif self.crypto.account.shared:
            await self._verify_keys_are_on_server()
        if self.appservice_mode:
            self.log.info("End-to-bridge encryption support is enabled (appservice mode)")
        else:
            _ = self.client.start(self._filter)
            self.log.info("End-to-bridge encryption support is enabled (sync mode)")
        if self.delete_outdated_inbound:
            deleted = await self.crypto_store.redact_outdated_group_sessions()
            if len(deleted) > 0:
                self.log.debug(
                    f"Deleted {len(deleted)} inbound keys which lacked expiration metadata"
                )
        if self.periodically_delete_expired_keys:
            self._key_delete_task = background_task.create(self._periodically_delete_keys())
        background_task.create(self._resync_encryption_info())

    async def _resync_encryption_info(self) -> None:
        rows = await self.crypto_db.fetch(
            """SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'"""
        )
        room_ids = [row["room_id"] for row in rows]
        if not room_ids:
            return
        self.log.debug(f"Resyncing encryption state event in rooms: {room_ids}")
        for room_id in room_ids:
            try:
                evt = await self.client.get_state_event(room_id, EventType.ROOM_ENCRYPTION)
            except (MNotFound, MForbidden) as e:
                self.log.debug(f"Failed to get encryption state in {room_id}: {e}")
                q = """
                    UPDATE mx_room_state SET encryption=NULL
                    WHERE room_id=$1 AND encryption='{"resync":true}'
                """
                await self.crypto_db.execute(q, room_id)
            else:
                self.log.debug(f"Resynced encryption state in {room_id}: {evt}")
                q = """
                    UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2
                    WHERE room_id=$3 AND max_age IS NULL and max_messages IS NULL
                """
                await self.crypto_db.execute(
                    q, evt.rotation_period_ms, evt.rotation_period_msgs, room_id
                )

    async def _verify_keys_are_on_server(self) -> None:
        self.log.debug("Making sure keys are still on server")
        try:
            resp = await self.client.query_keys([self.client.mxid])
        except Exception:
            self.log.critical(
                "Failed to query own keys to make sure device still exists", exc_info=True
            )
            sys.exit(33)
        try:
            own_keys = resp.device_keys[self.client.mxid][self.client.device_id]
            if len(own_keys.keys) > 0:
                return
        except KeyError:
            pass
        self.log.critical("Existing device doesn't have keys on server, resetting crypto")
        await self.crypto.crypto_store.delete()
        await self.client.logout_all()
        sys.exit(34)

    async def stop(self) -> None:
        if self._key_delete_task:
            self._key_delete_task.cancel()
            self._key_delete_task = None
        self.client.stop()
        await self.crypto_store.close()
        if self.crypto_db:
            await self.crypto_db.stop()

    @property
    def _filter(self) -> Filter:
        all_events = EventType.find("*")
        return Filter(
            account_data=EventFilter(types=[all_events]),
            presence=EventFilter(not_types=[all_events]),
            room=RoomFilter(
                include_leave=False,
                state=StateFilter(not_types=[all_events]),
                timeline=RoomEventFilter(not_types=[all_events]),
                account_data=RoomEventFilter(not_types=[all_events]),
                ephemeral=RoomEventFilter(not_types=[all_events]),
            ),
        )

    async def _periodically_delete_keys(self) -> None:
        while True:
            deleted = await self.crypto_store.redact_expired_group_sessions()
            if deleted:
                self.log.info(f"Deleted expired megolm sessions: {deleted}")
            else:
                self.log.debug("No expired megolm sessions found")
            await asyncio.sleep(24 * 60 * 60)
