# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

import json

from aiohttp import ClientError, ClientSession, ContentTypeError
from yarl import URL

from mautrix.api import HTTPAPI, Method, Path
from mautrix.errors import (
    WellKnownInvalidVersionsResponse,
    WellKnownMissingHomeserver,
    WellKnownNotJSON,
    WellKnownNotURL,
    WellKnownUnexpectedStatus,
    WellKnownUnsupportedScheme,
)
from mautrix.types import DeviceID, SerializerError, UserID, VersionsResponse
from mautrix.util.logging import TraceLogger


class BaseClientAPI:
    """
    BaseClientAPI is the base class for :class:`ClientAPI`. This is separate from the main
    ClientAPI class so that the ClientAPI methods can be split into multiple classes (that
    inherit this class).All those section-specific method classes are inherited by the main
    ClientAPI class to create the full class.
    """

    localpart: str
    domain: str
    _mxid: UserID
    device_id: DeviceID
    api: HTTPAPI
    log: TraceLogger
    versions_cache: VersionsResponse | None

    def __init__(
        self, mxid: UserID = "", device_id: DeviceID = "", api: HTTPAPI | None = None, **kwargs
    ) -> None:
        """
        Initialize a ClientAPI. You must either provide the ``api`` parameter with an existing
        :class:`mautrix.api.HTTPAPI` instance, or provide the ``base_url`` and other arguments for
        creating it as kwargs.

        Args:
            mxid: The Matrix ID of the user. This is used for things like setting profile metadata.
                  Additionally, the homeserver domain is extracted from this string and used for
                  setting aliases and such. This can be changed later using `set_mxid`.
            device_id: The device ID corresponding to the access token used.
            api: The :class:`mautrix.api.HTTPAPI` instance to use. You can also pass the ``kwargs``
                 to create a HTTPAPI instance rather than creating the instance yourself.
            kwargs: If ``api`` is not specified, then the arguments to pass when creating a HTTPAPI.
        """
        if mxid:
            self.mxid = mxid
        else:
            self._mxid = None
            self.localpart = None
            self.domain = None
        self.fill_member_event_callback = None
        self.versions_cache = None
        self.device_id = device_id
        self.api = api or HTTPAPI(**kwargs)
        self.log = self.api.log

    @classmethod
    def parse_user_id(cls, mxid: UserID) -> tuple[str, str]:
        """
        Parse the localpart and server name from a Matrix user ID.

        Args:
            mxid: The Matrix user ID.

        Returns:
            A tuple of (localpart, server_name).

        Raises:
            ValueError: if the given user ID is invalid.
        """
        if len(mxid) == 0:
            raise ValueError("User ID is empty")
        elif mxid[0] != "@":
            raise ValueError("User IDs start with @")
        try:
            sep = mxid.index(":")
        except ValueError as e:
            raise ValueError("User ID must contain domain separator") from e
        if sep == len(mxid) - 1:
            raise ValueError("User ID must contain domain")
        return mxid[1:sep], mxid[sep + 1 :]

    @property
    def mxid(self) -> UserID:
        return self._mxid

    @mxid.setter
    def mxid(self, mxid: UserID) -> None:
        self.localpart, self.domain = self.parse_user_id(mxid)
        self._mxid = mxid

    async def versions(self, no_cache: bool = False) -> VersionsResponse:
        """
        Get client-server spec versions supported by the server.

        Args:
            no_cache: If true, the versions will always be fetched from the server
                      rather than using cached results when availab.e.

        Returns:
            The supported Matrix spec versions and unstable features.
        """
        if no_cache or not self.versions_cache:
            resp = await self.api.request(Method.GET, Path.versions)
            self.versions_cache = VersionsResponse.deserialize(resp)
        return self.versions_cache

    @classmethod
    async def discover(cls, domain: str, session: ClientSession | None = None) -> URL | None:
        """
        Follow the server discovery spec to find the actual URL when given a Matrix server name.

        Args:
            domain: The server name (end of user ID) to discover.
            session: Optionally, the aiohttp ClientSession object to use.

        Returns:
            The parsed URL if the discovery succeeded.
            ``None`` if the request returned a 404 status.

        Raises:
            WellKnownError: for other errors
        """
        if session is None:
            async with ClientSession(headers={"User-Agent": HTTPAPI.default_ua}) as sess:
                return await cls._discover(domain, sess)
        else:
            return await cls._discover(domain, session)

    @classmethod
    async def _discover(cls, domain: str, session: ClientSession) -> URL | None:
        well_known = URL.build(scheme="https", host=domain, path="/.well-known/matrix/client")
        async with session.get(well_known) as resp:
            if resp.status == 404:
                return None
            elif resp.status != 200:
                raise WellKnownUnexpectedStatus(resp.status)
            try:
                data = await resp.json(content_type=None)
            except (json.JSONDecodeError, ContentTypeError) as e:
                raise WellKnownNotJSON() from e

        try:
            homeserver_url = data["m.homeserver"]["base_url"]
        except KeyError as e:
            raise WellKnownMissingHomeserver() from e
        parsed_url = URL(homeserver_url)
        if not parsed_url.is_absolute():
            raise WellKnownNotURL()
        elif parsed_url.scheme not in ("http", "https"):
            raise WellKnownUnsupportedScheme(parsed_url.scheme)

        try:
            async with session.get(parsed_url / "_matrix/client/versions") as resp:
                data = VersionsResponse.deserialize(await resp.json())
                if len(data.versions) == 0:
                    raise ValueError("no versions defined in /_matrix/client/versions response")
        except (ClientError, json.JSONDecodeError, SerializerError, ValueError) as e:
            raise WellKnownInvalidVersionsResponse() from e

        return parsed_url
