# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import asyncio
import collections
import enum
import functools
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import (
    Any,
    Protocol,
    TypeVar,
    overload,
)

import pyee
import pyee.asyncio
from typing_extensions import Self

from bumble.colors import color

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger named after this module (`__name__`); used below to report
# exceptions raised by background work and flow-control pause/resume events.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
def setup_event_forwarding(emitter, forwarder, event_name):
    """
    Re-emit every `event_name` event of `emitter` on `forwarder`.

    All positional and keyword arguments of the original event are passed
    through unchanged to `forwarder.emit(event_name, ...)`.
    """

    def relay(*event_args, **event_kwargs):
        # `forwarder.emit` is looked up at call time, not when forwarding is set up.
        forwarder.emit(event_name, *event_args, **event_kwargs)

    emitter.on(event_name, relay)


# -----------------------------------------------------------------------------
def wrap_async(function):
    """
    Return an async-callable version of `function`.

    The result is `async_call` partially applied to `function`: calling it with
    some arguments returns a coroutine that, once awaited, calls `function` with
    those arguments and returns its result.
    """
    async_version = functools.partial(async_call, function)
    return async_version


# -----------------------------------------------------------------------------
def deprecated(msg: str):
    """
    Build a decorator that issues a `DeprecationWarning` carrying `msg` each time
    the decorated function is called, before delegating to it.

    The warning is attributed to the caller of the decorated function
    (`stacklevel=2`), and the decorated function's metadata is preserved.
    """

    def decorate(function):
        @functools.wraps(function)
        def warn_then_call(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return function(*args, **kwargs)

        return warn_then_call

    return decorate


# -----------------------------------------------------------------------------
def experimental(msg: str):
    """
    Build a decorator that issues a `FutureWarning` carrying `msg` each time the
    decorated function is called, before delegating to it.

    The warning is attributed to the caller of the decorated function
    (`stacklevel=2`), and the decorated function's metadata is preserved.
    """

    def decorate(function):
        @functools.wraps(function)
        def warn_then_call(*args, **kwargs):
            warnings.warn(msg, category=FutureWarning, stacklevel=2)
            return function(*args, **kwargs)

        return warn_then_call

    return decorate


# -----------------------------------------------------------------------------
def composite_listener(cls):
    """
    Class decorator that attaches `_bumble_register_composite` and
    `_bumble_deregister_composite` methods to `cls`.

    Given an emitter, those methods (de)register every `on_<event_name>` method
    visible through `dir(cls)` (so inherited ones are included) as the listener
    for `<event_name>` on that emitter.
    """
    # pylint: disable=protected-access
    prefix = 'on_'

    def handlers_of(self):
        # Yield (event name, bound handler) pairs, computed lazily per call.
        for attribute in dir(cls):
            if attribute.startswith(prefix):
                yield attribute[len(prefix) :], getattr(self, attribute)

    def register(self, emitter):
        for event_name, handler in handlers_of(self):
            emitter.on(event_name, handler)

    def deregister(self, emitter):
        for event_name, handler in handlers_of(self):
            emitter.remove_listener(event_name, handler)

    cls._bumble_register_composite = register
    cls._bumble_deregister_composite = deregister
    return cls


# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)


class EventWatcher:
    '''A wrapper class to control the lifecycle of event handlers better.

    Usage:
    ```
    watcher = EventWatcher()

    def on_foo():
        ...
    watcher.on(emitter, 'foo', on_foo)

    @watcher.on(emitter, 'bar')
    def on_bar():
        ...

    # Close all event handlers watching through this watcher
    watcher.close()
    ```

    As context:
    ```
    with contextlib.closing(EventWatcher()) as context:
        @context.on(emitter, 'foo')
        def on_foo():
            ...
    # on_foo() has been removed here!
    ```
    '''

    # (emitter, event name, handler) for every registration made through this
    # watcher since it was created or last closed.
    handlers: list[tuple[pyee.EventEmitter, str, Callable[..., Any]]]

    def __init__(self) -> None:
        self.handlers = []

    @overload
    def on(
        self, emitter: pyee.EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def on(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler
    ) -> _Handler: ...

    def on(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
    ) -> _Handler | Callable[[_Handler], _Handler]:
        '''Watch an event until the watcher is closed.

        Args:
            emitter: EventEmitter to watch
            event: Event name
            handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
            The handler itself, or a decorator when no handler is passed.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.on(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    @overload
    def once(
        self, emitter: pyee.EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def once(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler
    ) -> _Handler: ...

    def once(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
    ) -> _Handler | Callable[[_Handler], _Handler]:
        '''Watch an event for once.

        Args:
            emitter: EventEmitter to watch
            event: Event name
            handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
            The handler itself, or a decorator when no handler is passed.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.once(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    def close(self) -> None:
        '''Remove every handler registered through this watcher.

        Handlers no longer listed by `emitter.listeners(event)` are skipped. The
        watcher then forgets all of its registrations, so it no longer keeps the
        emitters/handlers alive, and a later `close()` will not remove listeners
        that were re-registered elsewhere in the meantime.
        '''
        # Detach the list first so the watcher is empty even if removal raises.
        registrations, self.handlers = self.handlers, []
        for emitter, event, handler in registrations:
            if handler in emitter.listeners(event):
                emitter.remove_listener(event, handler)


# -----------------------------------------------------------------------------
_T = TypeVar('_T')


def cancel_on_event(
    emitter: pyee.EventEmitter, event: str, awaitable: Awaitable[_T]
) -> Awaitable[_T]:
    """Arrange for a coroutine or future to be cancelled when an event occurs.

    `awaitable` is wrapped with `asyncio.ensure_future`. If the result is already
    done it is returned as is and no listener is installed. Otherwise a listener
    for `event` is added to `emitter`; when it fires while the future is still
    pending, a Task is cancelled with a message, and any other future gets an
    `asyncio.CancelledError` set as its exception. The listener is removed from
    `emitter` as soon as the future completes, however it completes.

    Returns:
        The future (or task) wrapping `awaitable`.
    """
    future = asyncio.ensure_future(awaitable)
    if not future.done():

        def on_event(*_args, **_kwargs) -> None:
            if not future.done():
                reason = f'abort: {event} event occurred.'
                if isinstance(future, asyncio.Task):
                    future.cancel(reason)
                else:
                    future.set_exception(asyncio.CancelledError(reason))

        emitter.on(event, on_event)
        future.add_done_callback(lambda _: emitter.remove_listener(event, on_event))

    return future


# -----------------------------------------------------------------------------
class EventEmitter(pyee.asyncio.AsyncIOEventEmitter):
    """Base event emitter for Bumble, built on pyee's asyncio emitter."""

    @deprecated("Use `cancel_on_event` instead.")
    def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
        """Deprecated: cancel `awaitable` when this emitter emits `event`.

        Thin forwarder to `cancel_on_event` with this emitter as the source.
        """
        return cancel_on_event(emitter=self, event=event, awaitable=awaitable)


# -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter):
    """Event emitter with at most one attached "composite" listener object.

    Assigning an object to `listener` calls the `_bumble_register_composite`
    hook of every class in its MRO that defines one (as added by the
    `composite_listener` decorator). Replacing or clearing the listener first
    calls the matching `_bumble_deregister_composite` hooks of the previous one.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        """The currently attached composite listener, or None."""
        return self._listener

    @listener.setter
    def listener(self, listener):
        # pylint: disable=protected-access
        if self._listener is not None:
            # Call the deregistration method of each class in the MRO that defines
            # one itself (checking __dict__ skips copies inherited from a base).
            for cls in self._listener.__class__.mro():
                if '_bumble_deregister_composite' in cls.__dict__:
                    cls._bumble_deregister_composite(self._listener, self)
        self._listener = listener
        if listener is not None:
            # Call the registration method of each class in the MRO that defines
            # one itself.
            for cls in listener.__class__.mro():
                if '_bumble_register_composite' in cls.__dict__:
                    cls._bumble_register_composite(listener, self)


# -----------------------------------------------------------------------------
class AsyncRunner:
    """Helpers to run coroutines from synchronous code."""

    class WorkQueue:
        """Serial queue of coroutines, awaited one at a time by `run()`.

        Args:
            create_task: when True, the first `enqueue()` starts `run()` as an
                asyncio task; when False, the caller is responsible for running it.
        """

        def __init__(self, create_task=True):
            self.queue = None
            self.task = None
            self.create_task = create_task

        def _get_queue(self):
            # Lazy-create the coroutine queue. Presumably this is so the Queue is
            # created once an event loop exists: the shared `default_queue`
            # instance is built at class-definition time.
            if self.queue is None:
                self.queue = asyncio.Queue()
            return self.queue

        def enqueue(self, coroutine):
            """Queue `coroutine` to be awaited after everything queued before it."""
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Enqueue the work
            self._get_queue().put_nowait(coroutine)

        async def run(self):
            """Await queued coroutines forever, logging (not raising) their errors."""
            while True:
                # `run()` may be started (create_task=False) before the first
                # `enqueue()`, so the queue may not exist yet.
                item = await self._get_queue().get()
                try:
                    await item
                except Exception:
                    logger.exception(color("!!! Exception in work queue", "red"))

    # Shared default queue
    default_queue = WorkQueue()

    # Shared set of running tasks
    running_tasks: set[Awaitable] = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and either spawns it
        as a task (when `queue` is None; exceptions are logged) or enqueues it on
        `queue`. The call itself returns None.
        """

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)
                if queue is None:
                    # Spawn the coroutine as a task
                    async def run():
                        try:
                            await coroutine
                        except Exception:
                            logger.exception(color("!!! Exception in wrapper:", "red"))

                    AsyncRunner.spawn(run())
                else:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)

            return wrapper

        return decorator

    @staticmethod
    def spawn(coroutine):
        """
        Spawn a task to run a coroutine in a "fire and forget" mode.

        Using this method instead of just calling `asyncio.create_task(coroutine)`
        is necessary when you don't keep a reference to the task, because `asyncio`
        only keeps weak references to alive tasks.

        Returns:
            The created task (callers are free to ignore it).
        """
        task = asyncio.create_task(coroutine)
        AsyncRunner.running_tasks.add(task)
        task.add_done_callback(AsyncRunner.running_tasks.discard)
        return task


# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
    """
    Asyncio pipe with flow control. When writing to the pipe, the source is
    paused (by calling a function passed in when the pipe is created) if the
    amount of queued data exceeds a specified threshold.

    Packets are forwarded to the sink in the order they were written.

    Args:
        pause_source: called with no arguments to pause the data source.
        resume_source: called with no arguments to resume the data source.
        write_to_sink: called with each packet; nothing is pumped while None.
        drain_sink: optional coroutine function awaited after each packet write.
        threshold: queued byte count above which the source is paused.
    """

    def __init__(
        self,
        pause_source,
        resume_source,
        write_to_sink=None,
        drain_sink=None,
        threshold=0,
    ):
        self.pause_source = pause_source
        self.resume_source = resume_source
        self.write_to_sink = write_to_sink
        self.drain_sink = drain_sink
        self.threshold = threshold
        self.queue = collections.deque()  # Queue of packets, oldest on the left
        self.queued_bytes = 0  # Number of bytes in the queue
        self.ready_to_pump = asyncio.Event()
        self.paused = False
        self.source_paused = False
        self.pump_task = None

    def start(self):
        """Start the pump task (if not already running) and re-check pump state."""
        if self.pump_task is None:
            self.pump_task = asyncio.create_task(self.pump())

        self.check_pump()

    def stop(self):
        """Cancel the pump task, if any. Queued packets are kept."""
        if self.pump_task is not None:
            self.pump_task.cancel()
            self.pump_task = None

    def write(self, packet):
        """Queue `packet`, pausing the source if over the threshold."""
        self.queued_bytes += len(packet)
        self.queue.append(packet)

        # Pause the source if we're over the threshold
        if self.queued_bytes > self.threshold and not self.source_paused:
            logger.debug(f'pausing source (queued={self.queued_bytes})')
            self.pause_source()
            self.source_paused = True

        self.check_pump()

    def pause(self):
        """Stop pumping to the sink and pause the source."""
        if not self.paused:
            self.paused = True
            if not self.source_paused:
                self.pause_source()
                self.source_paused = True
            self.check_pump()

    def resume(self):
        """Resume pumping; resume the source unless the queue is over threshold."""
        if self.paused:
            self.paused = False
            # If the queue is still over the threshold, leave the source paused:
            # `pump()` resumes it once enough data has drained.
            if self.source_paused and self.queued_bytes <= self.threshold:
                self.resume_source()
                self.source_paused = False
            self.check_pump()

    def can_pump(self):
        """Return True if there is a packet to send and somewhere to send it."""
        return bool(self.queue) and not self.paused and self.write_to_sink is not None

    def check_pump(self):
        """Set or clear `ready_to_pump` according to `can_pump()`."""
        if self.can_pump():
            self.ready_to_pump.set()
        else:
            self.ready_to_pump.clear()

    async def pump(self):
        """Forward queued packets to the sink in FIFO order, forever."""
        while True:
            # Wait until we can try to pump packets
            await self.ready_to_pump.wait()

            # Try to pump the oldest packet (`write()` appends on the right)
            if self.can_pump():
                packet = self.queue.popleft()
                self.write_to_sink(packet)
                self.queued_bytes -= len(packet)

                # Drain the sink if we can
                if self.drain_sink:
                    await self.drain_sink()

                # Check if we can accept more
                if self.queued_bytes <= self.threshold and self.source_paused:
                    logger.debug(f'resuming source (queued={self.queued_bytes})')
                    self.source_paused = False
                    self.resume_source()

            self.check_pump()


# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` from inside a coroutine and return its result.

    This exists only to give a synchronous callable an async signature. Rust's
    `pyo3_asyncio` library needs functions to be marked async to properly inject
    a running loop.

    result = await async_call(some_function, ...)
    """
    result = function(*args, **kwargs)
    return result


# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
    """
    `enum.IntEnum` that also accepts integers outside its declared members.

    Looking up an undeclared integer returns a fresh pseudo-member named
    `<ClassName>[<value>]` instead of raising. Handy for protocol constants whose
    set of values may grow over time. Non-integer lookups still fail as usual.
    """

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, int):
            pseudo_member = int.__new__(cls, value)
            pseudo_member._value_ = value
            pseudo_member._name_ = f"{cls.__name__}[{value}]"
            return pseudo_member
        return None


# -----------------------------------------------------------------------------
class CompatibleIntFlag(enum.IntFlag):
    """
    Subclass of `enum.IntFlag` with a `composite_name` property that behaves like the
    `name` property of the `enum.IntFlag` implementation for python versions >= 3.11
    """

    @property
    def composite_name(self) -> str:
        """'|'-joined names of the class members whose bits are all set in self.

        A member only contributes when *all* of its bits are present, so a
        multi-bit member that merely overlaps this value (Python < 3.11 yields
        such members when iterating the class) is not reported. Zero-valued
        members never match. Returns '' when no member matches.
        """
        value = self.value
        return '|'.join(
            name
            for flag in self.__class__
            if flag.value
            and (value & flag.value) == flag.value
            and (name := flag.name) is not None
        )


# -----------------------------------------------------------------------------
class ByteSerializable(Protocol):
    """
    Structural type for objects that can be parsed from `bytes` through the
    `from_bytes` class method and serialized back with `bytes(obj)`.
    """

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        """Build an instance from its serialized form ``data``."""

    def __bytes__(self) -> bytes:
        """Return the serialized form of this object."""


# -----------------------------------------------------------------------------
class IntConvertible(Protocol):
    """
    Type protocol for classes that can be instantiated from int and converted to int.
    """

    def __init__(self, value: int) -> None: ...
    def __int__(self) -> int: ...


# -----------------------------------------------------------------------------
def crc_16(data: bytes) -> int:
    """Calculate CRC-16-IBM of given data.

    Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed)

    Bits are processed least-significant first using the reversed polynomial,
    starting from 0x0000 with no final XOR. Returns a value in [0, 0xFFFF].
    """
    register = 0x0000
    for octet in data:
        register ^= octet
        for _bit in range(8):
            lsb_set = register & 0x0001
            register >>= 1
            if lsb_set:
                register ^= 0xA001
    return register
