# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""SSH packet encoding and decoding functions"""

from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union

from .logging import SSHLogger
from .misc import plural


# A packet as accepted by SSHPacketLogger's logging methods: either raw
# bytes or an SSHPacket, in which case its full payload is what gets logged
_LoggedPacket = Union[bytes, 'SSHPacket']
# Signature of the functions stored in SSHPacketHandler._packet_handlers;
# they are invoked as handler(self, pkttype, pktid, packet)
_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], None]


class PacketDecodeError(ValueError):
    """Raised when SSH packet data cannot be decoded

    For example, when a field runs past the end of the packet or when
    unexpected data remains after all expected fields were read.
    """


def Byte(value: int) -> bytes:
    """Return *value* packed as a single unsigned octet

    Raises ValueError if *value* is outside the range 0-255.
    """

    return bytes([value])


def Boolean(value: bool) -> bytes:
    """Encode a boolean as one byte: 1 when *value* is truthy, else 0"""

    return Byte(1 if value else 0)


def UInt16(value: int) -> bytes:
    """Pack *value* as a 2-byte unsigned big-endian integer

    Raises OverflowError if *value* is negative or needs more than 16 bits.
    """

    return value.to_bytes(2, byteorder='big')


def UInt32(value: int) -> bytes:
    """Pack *value* as a 4-byte unsigned big-endian integer

    Raises OverflowError if *value* is negative or needs more than 32 bits.
    """

    return value.to_bytes(4, byteorder='big')


def UInt64(value: int) -> bytes:
    """Pack *value* as an 8-byte unsigned big-endian integer

    Raises OverflowError if *value* is negative or needs more than 64 bits.
    """

    return value.to_bytes(8, byteorder='big')


def String(value: Union[bytes, str]) -> bytes:
    """Encode an SSH string: a 4-byte big-endian length, then the data

    A ``str`` argument is first encoded as UTF-8 in strict mode, so text
    that cannot be encoded raises UnicodeEncodeError.
    """

    data = (value.encode('utf-8', errors='strict')
            if isinstance(value, str) else value)

    return len(data).to_bytes(4, 'big') + data


def MPInt(value: int) -> bytes:
    """Encode a multiple precision integer value

    The value is written as a 4-byte big-endian length followed by the
    minimal big-endian two's-complement representation. Zero is encoded
    with no data bytes, and a positive value whose top bit would be set
    gets a leading zero byte.
    """

    if value:
        # For negative numbers, ~value (== -value - 1) has exactly the
        # magnitude bits that must be kept below the sign bit; adding 8
        # before the floor division reserves room for that sign bit.
        magnitude = value if value >= 0 else ~value
        size = (magnitude.bit_length() + 8) // 8
    else:
        size = 0

    return size.to_bytes(4, 'big') + value.to_bytes(size, 'big', signed=True)


def NameList(value: Iterable[bytes]) -> bytes:
    """Encode byte strings as an SSH name-list

    The items are joined with commas and the result is encoded as a
    length-prefixed string.
    """

    joined = b','.join(value)
    return String(joined)


class SSHPacket:
    """Decoder class for SSH packets

    An SSHPacket wraps a packet's raw bytes together with a read cursor.
    Each ``get_*`` method decodes one field starting at the cursor and
    advances the cursor past it. The object is truthy while unread data
    remains, which :meth:`check_end` uses to verify full consumption.
    """

    def __init__(self, packet: bytes):
        self._packet = packet
        self._idx = 0               # offset of the next byte to decode
        self._len = len(packet)

    def __bool__(self) -> bool:
        """Return whether any bytes remain to be decoded"""

        return self._idx != self._len

    def check_end(self) -> None:
        """Confirm that all of the data in the packet has been consumed

        :raises PacketDecodeError: if any bytes remain unread
        """

        if self:
            raise PacketDecodeError('Unexpected data at end of packet')

    def get_consumed_payload(self) -> bytes:
        """Return the portion of the packet consumed so far"""

        return self._packet[:self._idx]

    def get_remaining_payload(self) -> bytes:
        """Return the portion of the packet not yet consumed"""

        return self._packet[self._idx:]

    def get_full_payload(self) -> bytes:
        """Return the full packet, regardless of how much was consumed"""

        return self._packet

    def get_bytes(self, size: int) -> bytes:
        """Extract the requested number of bytes from the packet

        :param size: The number of bytes to extract
        :raises PacketDecodeError: if *size* is negative or exceeds the
            number of bytes remaining in the packet
        """

        # Without this check a negative size would pass the bounds test,
        # slice from the wrong place, and move the cursor backwards.
        if size < 0:
            raise PacketDecodeError('Invalid field size')

        if self._idx + size > self._len:
            raise PacketDecodeError('Incomplete packet')

        value = self._packet[self._idx:self._idx+size]
        self._idx += size
        return value

    def get_byte(self) -> int:
        """Extract a single byte from the packet as an int (0-255)"""

        return self.get_bytes(1)[0]

    def get_boolean(self) -> bool:
        """Extract a boolean from the packet (any non-zero byte is True)"""

        return bool(self.get_byte())

    def get_uint16(self) -> int:
        """Extract an unsigned big-endian 16-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(2), 'big')

    def get_uint32(self) -> int:
        """Extract an unsigned big-endian 32-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(4), 'big')

    def get_uint64(self) -> int:
        """Extract an unsigned big-endian 64-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(8), 'big')

    def get_string(self) -> bytes:
        """Extract a length-prefixed byte string from the packet

        The string is a 32-bit length followed by that many bytes. The
        bytes are returned as-is; no text decoding is performed.
        """

        return self.get_bytes(self.get_uint32())

    def get_mpint(self) -> int:
        """Extract a multiple precision integer from the packet

        The value is read as a big-endian two's-complement string. Minimal
        encoding is not enforced, so redundant leading bytes are accepted.
        """

        return int.from_bytes(self.get_string(), 'big', signed=True)

    def get_namelist(self) -> Sequence[bytes]:
        """Extract a comma-separated list of byte strings from the packet

        An empty string decodes to an empty list rather than ``[b'']``.
        """

        namelist = self.get_string()
        return namelist.split(b',') if namelist else []


class SSHPacketLogger:
    """Parent class for SSH packet loggers

    Subclasses must provide :attr:`logger`. They may also fill in
    ``_handler_names``, which maps packet type numbers to the names
    shown in log output.
    """

    _handler_names: Mapping[int, str] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger to use for packet logging"""

        raise NotImplementedError

    def _log_packet(self, msg: str, pkttype: int, pktid: Optional[int],
                    packet: _LoggedPacket, note: str) -> None:
        """Emit one log entry describing a sent or received packet

        *msg* is the direction label passed by the public wrappers
        ('Sent' or 'Received'). A non-empty *note* is appended in
        parentheses.
        """

        payload = (packet.get_full_payload()
                   if isinstance(packet, SSHPacket) else packet)

        try:
            handler_name = self._handler_names[pkttype]
        except KeyError:
            name = 'packet type %d' % pkttype
        else:
            name = '%s (%d)' % (handler_name, pkttype)

        count = plural(len(payload), 'byte')
        suffix = ' (%s)' % note if note else note

        self.logger.packet(pktid, payload, '%s %s, %s%s',
                           msg, name, count, suffix)

    def log_sent_packet(self, pkttype: int, pktid: Optional[int],
                        packet: _LoggedPacket, note: str = '') -> None:
        """Log a packet that was sent"""

        self._log_packet('Sent', pkttype, pktid, packet, note)

    def log_received_packet(self, pkttype: int, pktid: Optional[int],
                            packet: _LoggedPacket, note: str = '') -> None:
        """Log a packet that was received"""

        self._log_packet('Received', pkttype, pktid, packet, note)


class SSHPacketHandler(SSHPacketLogger):
    """Parent class for SSH packet handlers

    Subclasses are expected to fill in ``_packet_handlers``, which maps a
    packet type to the function that handles it. :meth:`process_packet`
    calls that function as ``handler(self, pkttype, pktid, packet)``.
    """

    _packet_handlers: Mapping[int, _PacketHandler] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger associated with this packet handler"""

        raise NotImplementedError

    def process_packet(self, pkttype: int, pktid: int,
                       packet: SSHPacket) -> bool:
        """Dispatch a received packet to the handler for its type

        No logging is done here. Any exception raised by the handler
        propagates to the caller.

        :returns: True if a handler is registered for *pkttype* and was
            called, False if there is no handler for it
        """

        handlers = self._packet_handlers

        if pkttype not in handlers:
            return False

        handlers[pkttype](self, pkttype, pktid, packet)
        return True
