# Copyright 2017-2020 Palantir Technologies, Inc.
# Copyright 2021- Python Language Server Contributors.

import functools
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
import threading
import time
from typing import Optional

import docstring_to_markdown
import jedi

# Version string reported by the imported jedi package.
JEDI_VERSION = jedi.__version__

# End-of-line sequences accepted by the LSP protocol.
# The ordering affects performance, and it also matters for matching: the
# regex below tries its alternatives from left to right, so "\r\n" has to be
# listed before "\r" for a CRLF pair to be matched as one terminator.
EOL_CHARS = ["\r\n", "\r", "\n"]
EOL_REGEX = re.compile(f"({'|'.join(EOL_CHARS)})")

# Module-level logger used by the helpers in this file.
log = logging.getLogger(__name__)


def debounce(interval_s, keyed_by=None):
    """Debounce calls to this function until interval_s seconds have passed.

    Every call (re)starts a timer. The wrapped function only runs, on the
    timer's thread, once `interval_s` seconds pass without a newer call for
    the same key; it is then invoked with the arguments of that latest call.

    Args:
        interval_s (float): Quiet period, in seconds, to wait before running.
        keyed_by (str, optional): Name of a parameter of the wrapped function.
            Calls are debounced independently for each distinct value of that
            parameter. When falsy, all calls share a single timer.

    Returns:
        A decorator. The decorated function returns ``None`` immediately; the
        return value of the wrapped function is discarded by the timer.
    """

    def wrapper(func):
        timers = {}
        lock = threading.Lock()
        # Resolved once at decoration time instead of on every call.
        signature = inspect.signature(func)

        @functools.wraps(func)
        def debounced(*args, **kwargs):
            # Binding also validates the arguments synchronously.
            bound = signature.bind(*args, **kwargs)
            if keyed_by:
                # Fill in defaults so the key is found even when the caller
                # omitted the keyed parameter.
                bound.apply_defaults()
                key = bound.arguments[keyed_by]
            else:
                key = None

            def run():
                with lock:
                    if timers.get(key) is not timer:
                        # A newer call replaced this timer after it had
                        # already fired (cancel() came too late). Leave the
                        # newer entry alone and let that call run instead.
                        return None
                    del timers[key]
                return func(*args, **kwargs)

            with lock:
                old_timer = timers.get(key)
                if old_timer:
                    old_timer.cancel()

                timer = threading.Timer(interval_s, run)
                timers[key] = timer
                timer.start()

        return debounced

    return wrapper


def throttle(seconds=1):
    """Limit a function to at most one accepted call per `seconds` seconds.

    A call arriving before `seconds` have elapsed since the last accepted call
    is dropped: the wrapped function is not run and ``None`` is returned. The
    wall-clock time of the last accepted call is stored on the wrapper as
    ``last_call``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                previous = wrapper.last_call
            except AttributeError:
                # First call: no timestamp has been recorded yet.
                previous = wrapper.last_call = 0

            if time.time() - previous >= seconds:
                wrapper.last_call = time.time()
                return func(*args, **kwargs)
            return None

        return wrapper

    return decorator


def find_parents(root, path, names):
    """Look for any of `names` in the directories between `path` and `root`.

    Starting at the directory containing `path`, each directory up to and
    including `root` is checked in turn. The search stops at the first
    directory that holds at least one of the requested names.

    Args:
        root (str): The directory at which to stop walking upwards.
        path (str): The file path to start searching up from.
        names (list[str]): The file/directory names to look for.

    Returns:
        list[str]: The paths of every requested name that exists in the first
        matching directory. Empty if `root` is falsy, if `path` shares no
        common prefix with `root`, or if nothing is found.

    Note:
        The path MUST be within the root. The check performed here is only a
        character-wise common-prefix test (``os.path.commonprefix``).
    """
    if not root:
        return []

    if not os.path.commonprefix((root, path)):
        log.warning("Path %r not in %r", path, root)
        return []

    # Express the containing directory relative to root and split it into its
    # components, so the same walk works for unix and windows paths alike.
    # e.g. /a/b and /a/b/c/d/e.py -> ['/a/b', 'c', 'd']
    relative_dir = os.path.relpath(os.path.dirname(path), root)
    components = [root, *relative_dir.split(os.path.sep)]

    # Check the deepest directory first, then drop one trailing component at a
    # time: /a/b/c/d, /a/b/c, /a/b
    for depth in range(len(components), 0, -1):
        candidate_dir = os.path.join(*components[:depth])
        found = [
            candidate
            for candidate in (os.path.join(candidate_dir, name) for name in names)
            if os.path.exists(candidate)
        ]
        if found:
            return found

    # Nothing matched anywhere on the way up.
    return []


def path_to_dot_name(path):
    """Given a path to a module, derive its dot-separated full name.

    The module name is the file name without its extension. It is prefixed
    with the name of every enclosing directory that contains an
    ``__init__.py``, walking upwards until a directory without one is reached.

    Args:
        path (str): Path to the module file. Relative paths are resolved
            against the current working directory.

    Returns:
        str: The dotted name, e.g. ``"pkg.sub.mod"``.
    """
    # Work on an absolute path: with a relative one the walk bottoms out at
    # "" (whose dirname is again ""), which used to loop forever whenever the
    # current directory contained an __init__.py.
    directory = os.path.dirname(os.path.abspath(path))
    module_name, _ = os.path.splitext(os.path.basename(path))

    # Collected innermost-first and reversed once at the end.
    parts = [module_name]
    while os.path.exists(os.path.join(directory, "__init__.py")):
        parent = os.path.dirname(directory)
        if parent == directory:
            # Reached the filesystem root; there is nothing left to climb.
            break
        parts.append(os.path.basename(directory))
        directory = parent
    return ".".join(reversed(parts))


def match_uri_to_workspace(uri, workspaces):
    """Pick the workspace that shares the longest leading path with `uri`.

    Both `uri` and each workspace are split into components with
    ``pathlib.Path(...).parts``. Components are compared from the start and
    counting stops at the first difference, so only a common prefix counts.

    Args:
        uri (str | None): URI of the document to place.
        workspaces (Iterable[str]): Candidate workspace URIs.

    Returns:
        The workspace with the longest common prefix, or ``None`` if `uri` is
        ``None`` or no candidate shares at least its first component.
        Workspaces with more components than `uri` are never chosen. On a tie
        the earliest candidate wins.
    """
    if uri is None:
        return None

    uri_parts = pathlib.Path(uri).parts
    best_len, chosen_workspace = 0, None
    for workspace in workspaces:
        workspace_parts = pathlib.Path(workspace).parts
        if len(workspace_parts) > len(uri_parts):
            # A workspace deeper than the document cannot contain it.
            continue

        match_len = 0
        for workspace_part, uri_part in zip(workspace_parts, uri_parts):
            if workspace_part != uri_part:
                # Equal components after a mismatch are coincidental and
                # must not make an unrelated workspace look like a better fit.
                break
            match_len += 1

        if match_len > best_len:
            best_len, chosen_workspace = match_len, workspace
    return chosen_workspace


def list_to_string(value):
    """Join a list into a comma-separated string; return anything else as is."""
    if not isinstance(value, list):
        return value
    return ",".join(value)


def merge_dicts(dict_a, dict_b):
    """Recursively merge dictionary b into dictionary a, returning a new dict.

    Neither input is modified, but values that are not merged are shared with
    the inputs rather than copied.

    For a key present in both dictionaries:
      * two dicts are merged recursively;
      * two lists are concatenated (a's items, then b's) with duplicates
        dropped, keeping the first occurrence of each item;
      * otherwise b's value wins, unless it is ``None``, in which case a's
        value is kept.

    A key present only in a is kept as is. A key present only in b is added
    unless its value is ``None``.

    Args:
        dict_a (dict): Base dictionary.
        dict_b (dict): Dictionary whose values take precedence.

    Returns:
        dict: The merged dictionary, with a's keys first and then b's new keys.
    """

    def _unique(items):
        # Order-preserving de-duplication. Falls back to equality comparison
        # when the items (e.g. dicts) are not hashable.
        try:
            return list(dict.fromkeys(items))
        except TypeError:
            unique = []
            for item in items:
                if item not in unique:
                    unique.append(item)
            return unique

    def _merge(a, b):
        merged = {}
        # Deterministic key order: everything from a, then what only b has.
        for key in [*a, *(k for k in b if k not in a)]:
            if key not in b:
                merged[key] = a[key]
            elif key not in a:
                if b[key] is not None:
                    merged[key] = b[key]
            elif isinstance(a[key], dict) and isinstance(b[key], dict):
                merged[key] = _merge(a[key], b[key])
            elif isinstance(a[key], list) and isinstance(b[key], list):
                merged[key] = _unique(a[key] + b[key])
            elif b[key] is not None:
                merged[key] = b[key]
            else:
                merged[key] = a[key]
        return merged

    return _merge(dict_a, dict_b)


def escape_plain_text(contents: str) -> str:
    """
    Keep indentation visible where ordinary whitespace would be collapsed.

    Each tab becomes four non-breaking spaces and each pair of consecutive
    spaces becomes two non-breaking spaces; single spaces are left alone.
    """
    nbsp = "\u00a0"
    # Tabs are handled first, then double spaces, one replacement pass each.
    for needle, width in (("\t", 4), ("  ", 2)):
        contents = contents.replace(needle, nbsp * width)
    return contents


def escape_markdown(contents: str) -> str:
    """
    Make plain text safe to show in a Markdown environment.

    Backslashes, asterisks, underscores, hashes and square brackets are
    backslash-escaped, then whitespace is preserved via `escape_plain_text`.
    """
    # Prefix every character Markdown would treat as syntax with a backslash.
    escaped = re.sub(r"([\\*_#[\]])", r"\\\1", contents)
    # Keep tabs and runs of spaces from collapsing when rendered.
    return escape_plain_text(escaped)


def wrap_signature(signature):
    """Wrap `signature` in a fenced Markdown code block tagged as Python."""
    fence_open, fence_close = "```python\n", "\n```\n"
    return fence_open + signature + fence_close


# Markup kinds this server is able to produce.
SERVER_SUPPORTED_MARKUP_KINDS = {"markdown", "plaintext"}


def choose_markup_kind(client_supported_markup_kinds: list[str]):
    """Pick the first client-preferred markup kind that the server supports.

    The client's list is scanned in order, so kinds listed earlier take
    priority. If none of them is supported by the server, "markdown" is used.
    """
    mutually_supported = (
        kind
        for kind in client_supported_markup_kinds
        if kind in SERVER_SUPPORTED_MARKUP_KINDS
    )
    return next(mutually_supported, "markdown")


class Formatter:
    """Base class for code formatters that are run as ``python -m <command>``.

    Subclasses set `command` to the module name followed by any sub-command
    arguments.
    """

    # Module (and optional sub-command) handed to ``python -m``.
    command: list[str]

    @property
    def is_installed(self) -> bool:
        """Returns whether formatter is available.

        The command line is probed on first access only; the answer is then
        cached on the instance as ``_is_installed``.
        """
        if hasattr(self, "_is_installed"):
            return self._is_installed
        self._is_installed = self._is_available_via_cli()
        return self._is_installed

    def format(self, code: str, line_length: int) -> str:
        """Formats code.

        `code` is fed to the formatter on stdin and the formatted result is
        read from stdout, with surrounding whitespace stripped.

        Raises:
            subprocess.CalledProcessError: If the formatter exits non-zero.
        """
        argv = [
            sys.executable,
            "-m",
            *self.command,
            "--line-length",
            str(line_length),
            "-",
        ]
        formatted = subprocess.check_output(argv, input=code, text=True)
        return formatted.strip()

    def _is_available_via_cli(self) -> bool:
        """Return True if running the command with ``--help`` exits successfully."""
        probe = [sys.executable, "-m", *self.command, "--help"]
        try:
            subprocess.check_output(probe)
        except subprocess.CalledProcessError:
            return False
        return True


class RuffFormatter(Formatter):
    """Formatter that runs ``ruff format`` through ``python -m``."""

    command = ["ruff", "format"]


class BlackFormatter(Formatter):
    """Formatter that runs ``black`` through ``python -m``."""

    command = ["black"]


# Formatter instances keyed by the name used to look them up in
# format_signature.
formatters = {"ruff": RuffFormatter(), "black": BlackFormatter()}


def format_signature(signature: str, config: dict, signature_formatter: str) -> str:
    """Formats signature using ruff or black if either is available.

    The signature is temporarily turned into a stub function definition
    (``def <signature>:`` with a ``pass`` body) so that the formatter accepts
    it; that scaffolding is stripped from the formatter's output again.

    Args:
        signature: The call signature to format, e.g. ``"f(a, b=1)"``.
        config: Signature settings; ``line_length`` (default 88) is passed on
            to the formatter.
        signature_formatter: Name of the formatter to use, a key of
            `formatters`.

    Returns:
        The formatted signature, or `signature` unchanged (after logging a
        warning) when the formatter is unknown, not installed, or fails.
    """
    formatter = formatters.get(signature_formatter)
    if formatter is None:
        # A misconfigured formatter name must not break the caller; fall back
        # to the raw signature like the other failure modes do.
        log.warning(
            "Formatter %s was requested but it is not supported",
            signature_formatter,
        )
        return signature

    if not formatter.is_installed:
        log.warning(
            "Formatter %s was requested but it does not appear to be installed",
            signature_formatter,
        )
        return signature

    as_func = f"def {signature.strip()}:\n    pass"
    line_length = config.get("line_length", 88)
    try:
        return (
            formatter.format(as_func, line_length=line_length)
            .removeprefix("def ")
            .removesuffix(":\n    pass")
        )
    except subprocess.CalledProcessError as e:
        log.warning("Signature formatter failed %s", e)
    return signature


def convert_signatures_to_markdown(signatures: list[str], config: dict) -> str:
    """Render signatures as one fenced Python code block.

    Each signature is first run through the formatter named by the
    ``formatter`` entry of `config` (default "black"); a falsy entry skips
    formatting. The signatures are then joined with newlines and wrapped.
    """
    formatter_name = config.get("formatter", "black")
    if formatter_name:
        apply_format = functools.partial(
            format_signature, signature_formatter=formatter_name, config=config
        )
        signatures = list(map(apply_format, signatures))
    joined = "\n".join(signatures)
    return wrap_signature(joined)


def format_docstring(
    contents: str,
    markup_kind: str,
    signatures: Optional[list[str]] = None,
    signature_config: Optional[dict] = None,
):
    """Build a MarkupContent dictionary from a docstring.

    With `markup_kind` equal to 'markdown' the docstring is converted with
    `docstring-to-markdown`; if its format is not recognised, the Markdown
    syntax in it is escaped instead. Any other `markup_kind` yields plain
    text. When `signatures` (call signatures of functions or equivalent code
    summaries) are given, they are placed in front of the docstring.

    Args:
        contents: The docstring; anything that is not a string counts as empty.
        markup_kind: Requested markup kind.
        signatures: Optional signatures to prepend.
        signature_config: Optional settings used when formatting signatures
            for markdown output.

    Returns:
        dict: ``{"kind": "markdown" | "plaintext", "value": <text>}``.
    """
    if not isinstance(contents, str):
        contents = ""

    if markup_kind != "markdown":
        text = contents
        if signatures:
            text = "\n".join(signatures) + "\n\n" + text
        return {"kind": "plaintext", "value": escape_plain_text(text)}

    wrapped_signatures = convert_signatures_to_markdown(
        signatures if signatures is not None else [], config=signature_config or {}
    )

    if contents == "":
        body = contents
    else:
        try:
            body = docstring_to_markdown.convert(contents)
        except docstring_to_markdown.UnknownFormatError:
            # try to escape the Markdown syntax instead:
            body = escape_markdown(contents)

    if not signatures:
        value = body
    elif contents == "":
        # Nothing to separate the signatures from.
        value = wrapped_signatures
    else:
        value = wrapped_signatures + "\n\n" + body

    return {"kind": "markdown", "value": value}


def clip_column(column, lines, line_number):
    """
    Clamp a column to the length of its line, since the LSP allows character
    positions beyond the end of the line.

    The line length excludes trailing carriage returns and newlines. A
    `line_number` past the last line is treated as an empty line.

    https://microsoft.github.io/language-server-protocol/specification#position
    """
    if len(lines) > line_number:
        line_length = len(lines[line_number].rstrip("\r\n"))
    else:
        line_length = 0
    return min(column, line_length)


def position_to_jedi_linecolumn(document, position):
    """
    Translate an LSP position ('line', 'character') into Jedi's 'line' and
    'column' keys.

    The line is shifted by one and the column is clipped to the length of that
    line in `document.lines`. A falsy `position` gives an empty dictionary.

    https://microsoft.github.io/language-server-protocol/specification#position
    """
    if not position:
        return {}

    lsp_line = position["line"]
    column = clip_column(position["character"], document.lines, lsp_line)
    return {"line": lsp_line + 1, "column": column}


if os.name == "nt":
    import ctypes

    kernel32 = ctypes.windll.kernel32
    # 0x1000 is the Win32 PROCESS_QUERY_LIMITED_INFORMATION access right.
    PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
    # Misspelled legacy name, kept so that existing references keep working.
    PROCESS_QUERY_INFROMATION = PROCESS_QUERY_LIMITED_INFORMATION

    def is_process_alive(pid):
        """Check whether the process with the given pid is still alive.

        Running `os.kill()` on Windows always exits the process, so it can't be used to check for an alive process.
        see: https://docs.python.org/3/library/os.html?highlight=os%20kill#os.kill

        Hence ctypes is used to check for the process directly via windows API avoiding any other 3rd-party dependency.

        Args:
            pid (int): process ID

        Returns:
            bool: True if a handle to the process could be opened. False if it
            could not, i.e. the process is not alive or we don't have
            permission to query it.
        """
        process = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
        if process != 0:
            # Only the fact that the handle could be opened matters.
            kernel32.CloseHandle(process)
            return True
        return False

else:
    import errno

    def is_process_alive(pid):
        """Check whether the process with the given pid is still alive.

        Signal 0 is sent with `os.kill()`, which performs the existence and
        permission checks without delivering a signal.

        Args:
            pid (int): process ID

        Returns:
            bool: True if the signal could be sent, or if it was refused with
            EPERM (the process exists but we may not signal it). False for a
            negative pid or any other error.
        """
        if pid < 0:
            return False
        try:
            os.kill(pid, 0)
        except OSError as e:
            # EPERM means the process exists but belongs to someone else.
            return e.errno == errno.EPERM
        return True


def get_eol_chars(text):
    """Return the first end-of-line sequence found in `text`, or None if there is none."""
    found = EOL_REGEX.search(text)
    return found.group(0) if found else None
