# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from mautrix.client.state_store import SyncStore
from mautrix.types import (
    CrossSigner,
    CrossSigningUsage,
    DeviceID,
    DeviceIdentity,
    EventID,
    IdentityKey,
    RoomID,
    SessionID,
    SigningKey,
    SyncToken,
    TOFUSigningKey,
    UserID,
)

from ..account import OlmAccount
from ..sessions import InboundGroupSession, OutboundGroupSession, Session
from .abstract import CryptoStore


class MemoryCryptoStore(CryptoStore, SyncStore):
    """Crypto store and sync store that keep all state in plain Python attributes.

    Every piece of data is held in dictionaries (or simple attributes) on the
    instance. Objects handed to the store are kept by reference, which is why
    the ``update_*`` methods have nothing to do.
    """

    # Device ID set through put_device_id(), or None if none is set.
    _device_id: DeviceID | None
    # Sync token set through put_next_batch(), or None if none is set.
    _sync_token: SyncToken | None
    # Olm account set through put_account(), or None if none is set.
    _account: OlmAccount | None
    # (sender key, session ID, message index) -> (event ID, timestamp) first seen for it.
    _message_indices: dict[tuple[IdentityKey, SessionID, int], tuple[EventID, int]]
    # User ID -> device ID -> device identity.
    _devices: dict[UserID, dict[DeviceID, DeviceIdentity]]
    # Identity key -> Olm sessions in the order they were added.
    _olm_sessions: dict[IdentityKey, list[Session]]
    # (room ID, session ID) -> inbound group session.
    _inbound_sessions: dict[tuple[RoomID, SessionID], InboundGroupSession]
    # Room ID -> outbound group session (at most one per room).
    _outbound_sessions: dict[RoomID, OutboundGroupSession]
    # Signer -> target -> signature.
    _signatures: dict[CrossSigner, dict[CrossSigner, str]]
    # User ID -> key usage -> stored cross-signing key.
    _cross_signing_keys: dict[UserID, dict[CrossSigningUsage, TOFUSigningKey]]

    def __init__(self, account_id: str, pickle_key: str) -> None:
        """Create an empty store.

        Args:
            account_id: Stored as ``self.account_id``; not otherwise used by this class.
            pickle_key: Stored as ``self.pickle_key``; not otherwise used by this class.
        """
        self.account_id = account_id
        self.pickle_key = pickle_key

        self._sync_token = None
        self._device_id = None
        self._account = None
        self._message_indices = {}
        self._devices = {}
        self._olm_sessions = {}
        self._inbound_sessions = {}
        self._outbound_sessions = {}
        self._signatures = {}
        self._cross_signing_keys = {}

    async def get_device_id(self) -> DeviceID | None:
        """Return the stored device ID, or ``None`` if none is set."""
        return self._device_id

    async def put_device_id(self, device_id: DeviceID) -> None:
        """Store ``device_id``, replacing any previous value."""
        self._device_id = device_id

    async def put_next_batch(self, next_batch: SyncToken) -> None:
        """Store ``next_batch`` as the current sync token."""
        self._sync_token = next_batch

    async def get_next_batch(self) -> SyncToken | None:
        """Return the stored sync token, or ``None`` if none is set."""
        return self._sync_token

    async def delete(self) -> None:
        """Clear the account, device ID, Olm sessions and outbound group sessions.

        The sync token, inbound group sessions, devices, message indices,
        signatures and cross-signing keys are left untouched.
        """
        self._account = None
        self._device_id = None
        self._olm_sessions = {}
        self._outbound_sessions = {}

    async def put_account(self, account: OlmAccount) -> None:
        """Store ``account``, replacing any previous value."""
        self._account = account

    async def get_account(self) -> OlmAccount | None:
        """Return the stored account, or ``None`` if none is set."""
        return self._account

    async def has_session(self, key: IdentityKey) -> bool:
        """Return whether an Olm session list exists for ``key``."""
        return key in self._olm_sessions

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        """Return the Olm sessions stored for ``key``.

        When sessions exist, the returned list is the store's own list object;
        otherwise a new empty list is returned.
        """
        return self._olm_sessions.get(key, [])

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        """Return the most recently added Olm session for ``key``, or ``None`` if there is none."""
        try:
            return self._olm_sessions[key][-1]
        except (KeyError, IndexError):
            return None

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        """Append ``session`` to the list of Olm sessions for ``key``."""
        self._olm_sessions.setdefault(key, []).append(session)

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        """Do nothing: the store holds the session object itself, not a copy."""
        # This is a no-op as the session object is the same one previously added.
        pass

    async def put_group_session(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        session: InboundGroupSession,
    ) -> None:
        """Store an inbound group session under ``(room_id, session_id)``.

        ``sender_key`` is accepted but not used as part of the storage key.
        An existing session with the same room and session ID is replaced.
        """
        self._inbound_sessions[(room_id, session_id)] = session

    async def get_group_session(
        self, room_id: RoomID, session_id: SessionID
    ) -> InboundGroupSession | None:
        """Return the inbound group session for ``(room_id, session_id)``, or ``None``."""
        return self._inbound_sessions.get((room_id, session_id))

    async def redact_group_session(
        self, room_id: RoomID, session_id: SessionID, reason: str
    ) -> None:
        """Remove the inbound group session for ``(room_id, session_id)`` if present.

        ``reason`` is accepted but not recorded by this store.
        """
        self._inbound_sessions.pop((room_id, session_id), None)

    async def redact_group_sessions(
        self, room_id: RoomID, sender_key: IdentityKey, reason: str
    ) -> list[SessionID]:
        """Remove every inbound group session matching the given filters.

        A falsy ``room_id`` or ``sender_key`` disables that filter; sessions
        are matched on their own ``room_id`` and ``sender_key`` attributes.
        ``reason`` is accepted but not recorded by this store.

        Returns:
            The IDs of the sessions that were removed.

        Raises:
            ValueError: If both ``room_id`` and ``sender_key`` are falsy.
        """
        if not room_id and not sender_key:
            raise ValueError("Either room_id or sender_key must be provided")
        deleted = []
        # Iterate over a snapshot because entries are removed during the loop.
        for key, item in list(self._inbound_sessions.items()):
            if (not room_id or item.room_id == room_id) and (
                not sender_key or item.sender_key == sender_key
            ):
                deleted.append(SessionID(item.id))
                del self._inbound_sessions[key]
        return deleted

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        """Not supported by this store.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        """Not supported by this store.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        """Return whether an inbound group session exists for ``(room_id, session_id)``."""
        return (room_id, session_id) in self._inbound_sessions

    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Store ``session`` under its ``room_id``, replacing any session already stored there."""
        self._outbound_sessions[session.room_id] = session

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        """Do nothing: the store holds the session object itself, not a copy."""
        # This is a no-op as the session object is the same one previously added.
        pass

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        """Return the outbound group session for ``room_id``, or ``None`` if there is none."""
        return self._outbound_sessions.get(room_id)

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        """Remove the outbound group session for ``room_id`` if present."""
        self._outbound_sessions.pop(room_id, None)

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        """Remove the outbound group sessions of all ``rooms``; missing rooms are ignored."""
        for room_id in rooms:
            self._outbound_sessions.pop(room_id, None)

    async def validate_message_index(
        self,
        sender_key: IdentityKey,
        session_id: SessionID,
        event_id: EventID,
        index: int,
        timestamp: int,
    ) -> bool:
        """Check a message index against what was previously recorded for it.

        The first time a ``(sender_key, session_id, index)`` combination is
        seen, the event ID and timestamp are recorded and ``True`` is returned.
        On later calls the result is ``True`` only if both the event ID and the
        timestamp equal the recorded values.
        """
        try:
            return self._message_indices[(sender_key, session_id, index)] == (event_id, timestamp)
        except KeyError:
            # First sighting of this index: remember which event used it.
            self._message_indices[(sender_key, session_id, index)] = (event_id, timestamp)
            return True

    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        """Return the device map stored for ``user_id``, or ``None`` if the user has no entry."""
        return self._devices.get(user_id)

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        """Return the stored identity of one device, or ``None`` if user or device is unknown."""
        return self._devices.get(user_id, {}).get(device_id)

    async def find_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        """Return the first stored device of ``user_id`` whose ``identity_key`` matches.

        Returns ``None`` when the user has no entry or no device matches.
        """
        for device in self._devices.get(user_id, {}).values():
            if device.identity_key == identity_key:
                return device
        return None

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        """Replace the whole device map of ``user_id`` with ``devices``."""
        self._devices[user_id] = devices

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        """Return the members of ``users`` that have a device map in this store, in input order."""
        return [user_id for user_id in users if user_id in self._devices]

    async def put_cross_signing_key(
        self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
    ) -> None:
        """Store a cross-signing key for ``user_id`` and ``usage``.

        If nothing is stored yet for that user and usage, a new entry is
        created with both ``key`` and ``first`` set to ``key``. Otherwise only
        the ``key`` field of the existing entry is updated and ``first`` is
        left as it was.
        """
        try:
            current = self._cross_signing_keys[user_id][usage]
        except KeyError:
            self._cross_signing_keys.setdefault(user_id, {})[usage] = TOFUSigningKey(
                key=key, first=key
            )
        else:
            current.key = key

    async def get_cross_signing_keys(
        self, user_id: UserID
    ) -> dict[CrossSigningUsage, TOFUSigningKey]:
        """Return the cross-signing keys of ``user_id`` by usage (empty dict if none are stored)."""
        return self._cross_signing_keys.get(user_id, {})

    async def put_signature(
        self, target: CrossSigner, signer: CrossSigner, signature: str
    ) -> None:
        """Store ``signature`` as the signature made by ``signer`` on ``target``."""
        self._signatures.setdefault(signer, {})[target] = signature

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        """Return whether a signature by ``signer`` on ``target`` is stored."""
        return target in self._signatures.get(signer, {})

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        """Remove every signature made by ``signer``.

        Returns:
            The number of signatures removed; 0 if ``signer`` had none stored.
        """
        # Default to an empty dict so an unknown signer yields 0 rather than
        # failing on len(None).
        deleted = self._signatures.pop(signer, {})
        return len(deleted)
