# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import logging
import struct
from collections.abc import Callable
from enum import IntEnum

from bumble import core, l2cap
from bumble.colors import color

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module-level logger, named after this module (`__name__`).
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# PSM values for the AVCTP channels.
# NOTE(review): presumably the L2CAP PSMs assigned to AVCTP (control) and
# AVCTP browsing in the Bluetooth Assigned Numbers -- confirm against the spec.
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B


# -----------------------------------------------------------------------------
class MessageAssembler:
    """Reassembles AVCTP messages from the PDUs received on a channel.

    Every PDU starts with a one-byte header: transaction label (bits 7-4),
    packet type (bits 3-2), C/R (bit 1) and IPID (bit 0). What follows depends
    on the packet type:

    * SINGLE: 2-byte big-endian PID, then the message payload.
    * START: 1-byte total number of packets, 2-byte big-endian PID, then the
      first part of the payload.
    * CONTINUE / END: more payload, directly after the header byte. These
      packets carry no PID; the PID announced by the START packet applies.

    When a SINGLE or END packet completes a message, the callback is invoked
    with `(transaction_label, is_command, ipid, pid, payload)`. A malformed or
    out-of-sequence PDU is logged and the transaction in progress is dropped.
    """

    Callback = Callable[[int, bool, bool, int, bytes], None]

    # Number of bytes that precede the payload, for each kind of packet.
    _SINGLE_HEADER_SIZE = 3
    _START_HEADER_SIZE = 4
    _CONTINUATION_HEADER_SIZE = 1

    transaction_label: int  # -1 when no transaction is in progress
    pid: int
    c_r: int
    ipid: int
    payload: bytes
    number_of_packets: int
    packets_received: int
    packet_count: int  # Not read by this class; kept for existing users.

    def __init__(self, callback: Callback) -> None:
        """Create an assembler that reports complete messages to `callback`."""
        self.callback = callback
        self.reset()

    def reset(self) -> None:
        """Forget any transaction in progress and return to the idle state."""
        self.packets_received = 0
        self.transaction_label = -1
        self.pid = -1
        self.c_r = -1
        self.ipid = -1
        self.payload = b''
        self.number_of_packets = 0
        self.packet_count = 0

    def on_pdu(self, pdu: bytes) -> None:
        """Process one received PDU.

        Invokes the callback (through `on_message_complete`) when the PDU
        completes a message. Invalid PDUs are logged and reset the assembler;
        they never raise.
        """
        if not pdu:
            logger.warning("empty PDU")
            self.reset()
            return

        self.packets_received += 1

        header = pdu[0]
        transaction_label = header >> 4
        packet_type = Protocol.PacketType((header >> 2) & 3)
        c_r = (header >> 1) & 1
        ipid = header & 1

        # C/R == 0 denotes a command; commands must not have the IPID bit set.
        if c_r == 0 and ipid != 0:
            logger.warning("invalid IPID in command frame")
            self.reset()
            return

        starts_message = packet_type in (
            Protocol.PacketType.SINGLE,
            Protocol.PacketType.START,
        )
        if packet_type == Protocol.PacketType.START:
            header_size = self._START_HEADER_SIZE
        elif starts_message:
            header_size = self._SINGLE_HEADER_SIZE
        else:
            header_size = self._CONTINUATION_HEADER_SIZE

        # Reject truncated PDUs here rather than letting the parsing below
        # raise IndexError / struct.error into the caller.
        if len(pdu) < header_size:
            logger.warning("PDU too short for packet type")
            self.reset()
            return

        if starts_message:
            if self.transaction_label >= 0:
                # We are already in a transaction: drop it and start over
                # with this packet as the first one.
                logger.warning("received START or SINGLE fragment while in transaction")
                self.reset()
                self.packets_received = 1

            if packet_type == Protocol.PacketType.START:
                self.number_of_packets = pdu[1]

            self.transaction_label = transaction_label
            self.c_r = c_r
            self.ipid = ipid
            # The PID occupies the last two bytes of the header.
            self.pid = struct.unpack_from(">H", pdu, header_size - 2)[0]
        else:
            # CONTINUE or END: must belong to the transaction in progress.
            # (When idle, the stored label is -1 and can never match.)
            if transaction_label != self.transaction_label:
                logger.warning("transaction label does not match")
                self.reset()
                return

            if c_r != self.c_r:
                logger.warning("C/R does not match")
                self.reset()
                return

            if self.packets_received > self.number_of_packets:
                logger.warning("too many fragments in transaction")
                self.reset()
                return

            if (
                packet_type == Protocol.PacketType.END
                and self.packets_received != self.number_of_packets
            ):
                logger.warning("premature END")
                self.reset()
                return

        # Only accumulate the data once the packet has been validated.
        self.payload += pdu[header_size:]

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
            self.on_message_complete()

    def on_message_complete(self) -> None:
        """Deliver the assembled message to the callback, then reset.

        Exceptions raised by the callback are logged and not propagated.
        """
        try:
            self.callback(
                self.transaction_label,
                self.c_r == 0,
                self.ipid != 0,
                self.pid,
                self.payload,
            )
        except Exception:
            logger.exception(color("!!! exception in callback", "red"))

        self.reset()


# -----------------------------------------------------------------------------
class Protocol:
    CommandHandler = Callable[[int, bytes], None]
    command_handlers: dict[int, CommandHandler]  # Command handlers, by PID
    ResponseHandler = Callable[[int, bytes | None], None]
    response_handlers: dict[int, ResponseHandler]  # Response handlers, by PID
    next_transaction_label: int
    message_assembler: MessageAssembler

    class PacketType(IntEnum):
        SINGLE = 0b00
        START = 0b01
        CONTINUE = 0b10
        END = 0b11

    def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
        self.command_handlers = {}
        self.response_handlers = {}
        self.l2cap_channel = l2cap_channel
        self.message_assembler = MessageAssembler(self.on_message)

        # Register to receive PDUs from the channel
        l2cap_channel.sink = self.on_pdu
        l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
        l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)

    def on_l2cap_channel_open(self):
        logger.debug(color("<<< AVCTP channel open", "magenta"))

    def on_l2cap_channel_close(self):
        logger.debug(color("<<< AVCTP channel closed", "magenta"))

    def on_pdu(self, pdu: bytes) -> None:
        self.message_assembler.on_pdu(pdu)

    def on_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ) -> None:
        logger.debug(
            f"<<< AVCTP Message: pid={pid}, "
            f"transaction_label={transaction_label}, "
            f"is_command={is_command}, "
            f"ipid={ipid}, "
            f"payload={payload.hex()}"
        )

        # Check for invalid PID responses.
        if ipid:
            logger.debug(f"received IPID for PID={pid}")

        # Find the appropriate handler.
        if is_command:
            if pid not in self.command_handlers:
                logger.warning(f"no command handler for PID {pid}")
                self.send_ipid(transaction_label, pid)
                return

            self.command_handlers[pid](transaction_label, payload)
        else:
            if pid not in self.response_handlers:
                logger.warning(f"no response handler for PID {pid}")
                return

            # By convention, for an ipid, send a None payload to the response handler.
            response_payload = None if ipid else payload
            self.response_handlers[pid](transaction_label, response_payload)

    def send_message(
        self,
        transaction_label: int,
        is_command: bool,
        ipid: bool,
        pid: int,
        payload: bytes,
    ):
        # TODO: fragment large messages
        packet_type = Protocol.PacketType.SINGLE
        pdu = (
            struct.pack(
                ">BH",
                transaction_label << 4
                | packet_type << 2
                | (0 if is_command else 1) << 1
                | (1 if ipid else 0),
                pid,
            )
            + payload
        )
        self.l2cap_channel.write(pdu)

    def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
        logger.debug(
            ">>> AVCTP command: "
            f"transaction_label={transaction_label}, "
            f"pid={pid}, "
            f"payload={payload.hex()}"
        )
        self.send_message(transaction_label, True, False, pid, payload)

    def send_response(self, transaction_label: int, pid: int, payload: bytes):
        logger.debug(
            ">>> AVCTP response: "
            f"transaction_label={transaction_label}, "
            f"pid={pid}, "
            f"payload={payload.hex()}"
        )
        self.send_message(transaction_label, False, False, pid, payload)

    def send_ipid(self, transaction_label: int, pid: int) -> None:
        logger.debug(
            f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}"
        )
        self.send_message(transaction_label, False, True, pid, b'')

    def register_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        self.command_handlers[pid] = handler

    def unregister_command_handler(
        self, pid: int, handler: Protocol.CommandHandler
    ) -> None:
        if pid not in self.command_handlers or self.command_handlers[pid] != handler:
            raise core.InvalidArgumentError("command handler not registered")
        del self.command_handlers[pid]

    def register_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        self.response_handlers[pid] = handler

    def unregister_response_handler(
        self, pid: int, handler: Protocol.ResponseHandler
    ) -> None:
        if pid not in self.response_handlers or self.response_handlers[pid] != handler:
            raise core.InvalidArgumentError("response handler not registered")
        del self.response_handlers[pid]
