# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Awaitable, Callable, Optional
from abc import ABC

from mautrix import __optional_imports__
from mautrix.bridge.portal import BasePortal
from mautrix.crypto import StateStore
from mautrix.types import RoomEncryptionStateEventContent, RoomID, UserID
from mautrix.util.async_db import Database

# Async callback that looks up the bridge portal for a Matrix room ID. It may
# resolve to None (or another falsy value) when the room has no portal;
# BaseCryptoStateStore.is_encrypted treats that case as "not encrypted".
GetPortalFunc = Callable[[RoomID], Awaitable[Optional[BasePortal]]]


class BaseCryptoStateStore(StateStore, ABC):
    """Crypto state store whose encryption check is delegated to bridge portals.

    This base class only implements :meth:`is_encrypted`; it answers that
    question by looking up the portal for the room through ``get_portal``.
    """

    # Async portal lookup; a falsy result is treated as "room has no portal".
    get_portal: GetPortalFunc

    def __init__(self, get_portal: GetPortalFunc):
        """Keep the portal lookup callback for later use by :meth:`is_encrypted`."""
        self.get_portal = get_portal

    async def is_encrypted(self, room_id: RoomID) -> bool:
        """Tell whether the portal for ``room_id`` is flagged as encrypted.

        :param room_id: Matrix room to check.
        :return: The portal's ``encrypted`` attribute, or ``False`` when the
            lookup yields no portal.
        """
        found = await self.get_portal(room_id)
        if not found:
            # Rooms that are not portals are never considered encrypted here.
            return False
        return found.encrypted


class PgCryptoStateStore(BaseCryptoStateStore):
    """Database-backed crypto state store.

    Room membership is read from ``mx_user_profile``, stored room encryption
    state from ``mx_room_state`` and the per-portal ``encrypted`` flag from
    ``portal``. Encryption checks are inherited from the base class.
    """

    # Database handle every query in this class is executed against.
    db: Database

    def __init__(self, db: Database, get_portal: GetPortalFunc) -> None:
        """Store the database handle and pass ``get_portal`` to the base class.

        :param db: Database containing the ``mx_*`` and ``portal`` tables.
        :param get_portal: Async callback resolving a room ID to its portal.
        """
        super().__init__(get_portal)
        self.db = db

    async def find_shared_rooms(self, user_id: UserID) -> list[RoomID]:
        """Return the encrypted portal rooms that ``user_id`` has a profile row in.

        Only rooms that have a matching ``portal`` row whose ``encrypted`` flag
        is true are returned.

        :param user_id: Matrix user to look up.
        :return: List of room IDs (possibly empty).
        """
        # An inner join expresses the intent directly: the filter on
        # portal.encrypted can only match rows that have a portal anyway.
        # Columns are table-qualified so the query stays unambiguous, and the
        # selected column is aliased because row["room_id"] relies on its name.
        rows = await self.db.fetch(
            "SELECT mx_user_profile.room_id AS room_id FROM mx_user_profile "
            "INNER JOIN portal ON portal.mxid=mx_user_profile.room_id "
            "WHERE mx_user_profile.user_id=$1 AND portal.encrypted=true",
            user_id,
        )
        return [row["room_id"] for row in rows]

    async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEventContent | None:
        """Return the stored encryption state event content for ``room_id``.

        :param room_id: Matrix room to look up.
        :return: The parsed content, or ``None`` when the query yields a falsy
            value (e.g. no row, NULL or an empty string).
        """
        val = await self.db.fetchval(
            "SELECT encryption FROM mx_room_state WHERE room_id=$1", room_id
        )
        if not val:
            return None
        return RoomEncryptionStateEventContent.parse_json(val)
