#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Dummy router supporting IGD."""
# Instructions:
# - Change `SOURCE` to the IP address of this machine. When using IPv6, be sure to set the
#   scope_id, the last value in the tuple.
# - Run this module.
# - Run upnp-client (replace 0.0.0.0 with the IP address you put in `SOURCE`):
#    upnp-client call-action 'http://0.0.0.0:8000/device.xml' \
#                WANCIC/GetTotalPacketsReceived

import asyncio
import logging
import xml.etree.ElementTree as ET
from time import time
from typing import Dict, Mapping, Sequence, Tuple, Type, cast

from async_upnp_client.client import UpnpRequester, UpnpStateVariable
from async_upnp_client.const import (
    STATE_VARIABLE_TYPE_MAPPING,
    DeviceInfo,
    EventableStateVariableTypeInfo,
    ServiceInfo,
    StateVariableTypeInfo,
)
from async_upnp_client.exceptions import UpnpActionError
from async_upnp_client.profiles.igd import Pinhole, PortMappingEntry
from async_upnp_client.server import (
    UpnpServer,
    UpnpServerDevice,
    UpnpServerService,
    callable_action,
)

# Address the server binds to -- put the IP address of this machine here.
# For IPv6 use the 4-tuple form and set the scope_id (the last element), e.g.:
# SOURCE = ("fe80::215:5dff:fe3e:6d23", 0, 0, 6)  # Your IP here!
SOURCE = ("192.168.178.54", 0)  # Your IP here!
# Port of the HTTP server that serves /device.xml (see the instructions above).
HTTP_PORT = 8000

# Verbose logging for this script, but keep the "async_upnp_client.traffic"
# logger at WARNING so its output does not drown everything else.
logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger("dummy_router")
LOGGER_SSDP_TRAFFIC = logging.getLogger("async_upnp_client.traffic")
LOGGER_SSDP_TRAFFIC.setLevel(level=logging.WARNING)


class WANIPv6FirewallControlService(UpnpServerService):
    """WANIPv6FirewallControl service.

    Pinholes are only kept in an in-memory table keyed by the ``UniqueID``
    handed out by ``AddPinhole``; no real firewall is configured.
    """

    SERVICE_DEFINITION = ServiceInfo(
        service_id="urn:upnp-org:serviceId:WANIPv6FirewallControl1",
        service_type="urn:schemas-upnp-org:service:WANIPv6FirewallControl:1",
        control_url="/upnp/control/WANIPv6FirewallControl1",
        event_sub_url="/upnp/event/WANIPv6FirewallControl1",
        scpd_url="/WANIPv6FirewallControl_1.xml",
        xml=ET.Element("server_service"),
    )

    STATE_VARIABLE_DEFINITIONS = {
        "FirewallEnabled": EventableStateVariableTypeInfo(
            data_type="boolean",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["boolean"],
            default_value="1",
            allowed_value_range={},
            allowed_values=None,
            max_rate=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "InboundPinholeAllowed": EventableStateVariableTypeInfo(
            data_type="boolean",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["boolean"],
            default_value="1",
            allowed_value_range={},
            allowed_values=None,
            max_rate=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "A_ARG_TYPE_IPv6Address": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "A_ARG_TYPE_Port": StateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "A_ARG_TYPE_Protocol": StateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "A_ARG_TYPE_LeaseTime": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value=None,
            allowed_value_range={
                "min": "1",
                "max": "86400",
            },
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "A_ARG_TYPE_UniqueID": StateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
    }

    def __init__(self, *args, **kwargs) -> None:
        """Initialize with an empty pinhole table."""
        super().__init__(*args, **kwargs)
        self._pinholes: Dict[int, Pinhole] = {}
        self._next_pinhole_id = 0

    @staticmethod
    def _no_such_entry() -> UpnpActionError:
        """Build the UPnP error for an unknown pinhole UniqueID (704, NoSuchEntry)."""
        return UpnpActionError(error_code=704, error_desc="NoSuchEntry")

    @callable_action(
        name="GetFirewallStatus",
        in_args={},
        out_args={
            "FirewallEnabled": "FirewallEnabled",
            "InboundPinholeAllowed": "InboundPinholeAllowed",
        },
    )
    async def get_firewall_status(self) -> Dict[str, UpnpStateVariable]:
        """Get firewall status."""
        return {
            "FirewallEnabled": self.state_variable("FirewallEnabled"),
            "InboundPinholeAllowed": self.state_variable("InboundPinholeAllowed"),
        }

    @callable_action(
        name="AddPinhole",
        in_args={
            "RemoteHost": "A_ARG_TYPE_IPv6Address",
            "RemotePort": "A_ARG_TYPE_Port",
            "InternalClient": "A_ARG_TYPE_IPv6Address",
            "InternalPort": "A_ARG_TYPE_Port",
            "Protocol": "A_ARG_TYPE_Protocol",
            "LeaseTime": "A_ARG_TYPE_LeaseTime",
        },
        out_args={
            "UniqueID": "A_ARG_TYPE_UniqueID",
        },
    )
    async def add_pinhole(
        self,
        RemoteHost: str,
        RemotePort: int,
        InternalClient: str,
        InternalPort: int,
        Protocol: int,
        LeaseTime: int,
    ) -> Dict[str, UpnpStateVariable]:
        """Add a pinhole and return its newly allocated UniqueID."""
        # pylint: disable=invalid-name
        pinhole_id = self._next_pinhole_id
        self._next_pinhole_id += 1
        self._pinholes[pinhole_id] = Pinhole(
            remote_host=RemoteHost,
            remote_port=RemotePort,
            internal_client=InternalClient,
            internal_port=InternalPort,
            protocol=Protocol,
            lease_time=LeaseTime,
        )
        # NOTE(review): a plain int is returned (not a UpnpStateVariable);
        # presumably the server converts it via A_ARG_TYPE_UniqueID -- confirm.
        return {
            "UniqueID": pinhole_id,
        }

    @callable_action(
        name="UpdatePinhole",
        in_args={
            "UniqueID": "A_ARG_TYPE_UniqueID",
            "LeaseTime": "A_ARG_TYPE_LeaseTime",
        },
        out_args={},
    )
    async def update_pinhole(self, UniqueID: int, LeaseTime: int) -> Dict[str, UpnpStateVariable]:
        """Update the lease time of an existing pinhole.

        Raises UpnpActionError (704, NoSuchEntry) when UniqueID is unknown.
        """
        # pylint: disable=invalid-name
        try:
            pinhole = self._pinholes[UniqueID]
        except KeyError:
            raise self._no_such_entry() from None
        # NOTE(review): assumes Pinhole supports attribute assignment -- confirm
        # against async_upnp_client.profiles.igd.
        pinhole.lease_time = LeaseTime
        return {}

    @callable_action(
        name="DeletePinhole",
        in_args={
            "UniqueID": "A_ARG_TYPE_UniqueID",
        },
        out_args={},
    )
    async def delete_pinhole(self, UniqueID: int) -> Dict[str, UpnpStateVariable]:
        """Delete a pinhole.

        Raises UpnpActionError (704, NoSuchEntry) when UniqueID is unknown.
        """
        # pylint: disable=invalid-name
        try:
            del self._pinholes[UniqueID]
        except KeyError:
            raise self._no_such_entry() from None
        return {}


class WANIPConnectionService(UpnpServerService):
    """WANIPConnection service.

    Port mappings are only kept in an in-memory table keyed by
    ``(remote host, external port, protocol)``; no real NAT is configured.
    The PortMappingNumberOfEntries state variable mirrors the table size.
    """

    SERVICE_DEFINITION = ServiceInfo(
        service_id="urn:upnp-org:serviceId:WANIPConnection1",
        service_type="urn:schemas-upnp-org:service:WANIPConnection:1",
        control_url="/upnp/control/WANIPConnection1",
        event_sub_url="/upnp/event/WANIPConnection1",
        scpd_url="/WANIPConnection_1.xml",
        xml=ET.Element("server_service"),
    )

    STATE_VARIABLE_DEFINITIONS = {
        "ExternalIPAddress": EventableStateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value="1.2.3.0",
            allowed_value_range={},
            allowed_values=None,
            max_rate=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "ConnectionStatus": EventableStateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value="Unconfigured",
            allowed_value_range={},
            allowed_values=[
                "Unconfigured",
                "Authenticating",
                "Connecting",
                "Connected",
                "PendingDisconnect",
                "Disconnecting",
                "Disconnected",
            ],
            max_rate=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "LastConnectionError": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value="ERROR_NONE",
            allowed_value_range={},
            allowed_values=[
                "ERROR_NONE",
            ],
            xml=ET.Element("server_stateVariable"),
        ),
        "Uptime": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value="0",
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "RemoteHost": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "ExternalPort": StateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "PortMappingProtocol": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value=None,
            allowed_value_range={},
            allowed_values=["TCP", "UDP"],
            xml=ET.Element("server_stateVariable"),
        ),
        "InternalPort": StateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "InternalClient": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "PortMappingEnabled": StateVariableTypeInfo(
            data_type="boolean",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["boolean"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "PortMappingDescription": StateVariableTypeInfo(
            data_type="string",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["string"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "PortMappingLeaseDuration": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value=None,
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "PortMappingNumberOfEntries": EventableStateVariableTypeInfo(
            data_type="ui2",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui2"],
            default_value="0",
            allowed_value_range={
                "min": "0",
                "max": "65535",
                "step": "1",
            },
            allowed_values=None,
            max_rate=0,
            xml=ET.Element("server_stateVariable"),
        ),
    }

    def __init__(self, *args, **kwargs) -> None:
        """Initialize with an empty port mapping table."""
        super().__init__(*args, **kwargs)
        # Keyed by (NewRemoteHost, NewExternalPort, NewProtocol).
        self._port_mappings: Dict[Tuple[str, int, str], PortMappingEntry] = {}

    def _publish_port_mapping_count(self) -> None:
        """Set PortMappingNumberOfEntries to the table size, writing only on change."""
        # Derived from the table rather than incremented/decremented, so the
        # counter cannot drift (or go negative) if something else overwrote it.
        state_var = self.state_variable("PortMappingNumberOfEntries")
        count = len(self._port_mappings)
        if state_var.value != count:
            state_var.value = count

    @callable_action(
        name="GetStatusInfo",
        in_args={},
        out_args={
            "NewConnectionStatus": "ConnectionStatus",
            "NewLastConnectionError": "LastConnectionError",
            "NewUptime": "Uptime",
        },
    )
    async def get_status_info(self) -> Dict[str, UpnpStateVariable]:
        """Get status info."""
        # Uncomment to simulate a failing action:
        # from async_upnp_client.exceptions import UpnpActionError, UpnpActionErrorCode
        # raise UpnpActionError(
        #     error_code=UpnpActionErrorCode.INVALID_ACTION, error_desc="Invalid action"
        # )
        return {
            "NewConnectionStatus": self.state_variable("ConnectionStatus"),
            "NewLastConnectionError": self.state_variable("LastConnectionError"),
            "NewUptime": self.state_variable("Uptime"),
        }

    @callable_action(
        name="GetExternalIPAddress",
        in_args={},
        out_args={
            "NewExternalIPAddress": "ExternalIPAddress",
        },
    )
    async def get_external_ip_address(self) -> Dict[str, UpnpStateVariable]:
        """Get external IP address."""
        # Uncomment to simulate a failing action:
        # from async_upnp_client.exceptions import UpnpActionError, UpnpActionErrorCode
        # raise UpnpActionError(
        #     error_code=UpnpActionErrorCode.INVALID_ACTION, error_desc="Invalid action"
        # )
        return {
            "NewExternalIPAddress": self.state_variable("ExternalIPAddress"),
        }

    @callable_action(
        name="AddPortMapping",
        in_args={
            "NewRemoteHost": "RemoteHost",
            "NewExternalPort": "ExternalPort",
            "NewProtocol": "PortMappingProtocol",
            "NewInternalPort": "InternalPort",
            "NewInternalClient": "InternalClient",
            "NewEnabled": "PortMappingEnabled",
            "NewPortMappingDescription": "PortMappingDescription",
            "NewLeaseDuration": "PortMappingLeaseDuration",
        },
        out_args={},
    )
    async def add_port_mapping(
        self,
        NewRemoteHost: str,
        NewExternalPort: int,
        NewProtocol: str,
        NewInternalPort: int,
        NewInternalClient: str,
        NewEnabled: bool,
        NewPortMappingDescription: str,
        NewLeaseDuration: int,
    ) -> Dict[str, UpnpStateVariable]:
        """Add a port mapping; an existing entry with the same key is overwritten."""
        # pylint: disable=invalid-name
        key = (NewRemoteHost, NewExternalPort, NewProtocol)
        self._port_mappings[key] = PortMappingEntry(
            remote_host=NewRemoteHost,
            external_port=NewExternalPort,
            protocol=NewProtocol,
            internal_client=NewInternalClient,
            internal_port=NewInternalPort,
            enabled=NewEnabled,
            description=NewPortMappingDescription,
            lease_duration=NewLeaseDuration,
        )
        self._publish_port_mapping_count()
        return {}

    @callable_action(
        name="DeletePortMapping",
        in_args={
            "NewRemoteHost": "RemoteHost",
            "NewExternalPort": "ExternalPort",
            "NewProtocol": "PortMappingProtocol",
        },
        out_args={},
    )
    async def delete_port_mapping(
        self, NewRemoteHost: str, NewExternalPort: int, NewProtocol: str
    ) -> Dict[str, UpnpStateVariable]:
        """Delete an existing port mapping entry.

        Raises UpnpActionError (714, NoSuchEntryInArray) when no mapping
        matches (NewRemoteHost, NewExternalPort, NewProtocol).
        """
        # pylint: disable=invalid-name
        key = (NewRemoteHost, NewExternalPort, NewProtocol)
        try:
            del self._port_mappings[key]
        except KeyError:
            raise UpnpActionError(error_code=714, error_desc="NoSuchEntryInArray") from None
        self._publish_port_mapping_count()
        return {}


class WanConnectionDevice(UpnpServerDevice):
    """WANConnectionDevice hosting the WANIPConnection and WANIPv6FirewallControl services."""

    DEVICE_DEFINITION = DeviceInfo(
        device_type="urn:schemas-upnp-org:device:WANConnectionDevice:1",
        udn="uuid:51e00c19-c8f3-4b28-9ef1-7f562f204c82",
        friendly_name="Dummy Router WAN Connection Device",
        manufacturer="Steven",
        manufacturer_url=None,
        model_description="Dummy Router IGD",
        model_name="DummyRouter v1",
        model_number="v0.0.1",
        model_url=None,
        serial_number="0000001",
        upc=None,
        presentation_url=None,
        url="/device.xml",
        icons=[],
        xml=ET.Element("server_device"),
    )
    # Leaf of the device tree: no devices are embedded below this one.
    EMBEDDED_DEVICES: Sequence[Type[UpnpServerDevice]] = []
    SERVICES = [WANIPConnectionService, WANIPv6FirewallControlService]

    def __init__(self, requester: UpnpRequester, base_uri: str, boot_id: int, config_id: int) -> None:
        """Pass every construction argument, by keyword, on to ``UpnpServerDevice``."""
        super().__init__(
            requester=requester, base_uri=base_uri, boot_id=boot_id, config_id=config_id
        )


class WANCommonInterfaceConfigService(UpnpServerService):
    """WANCommonInterfaceConfig service.

    Traffic counters are fake: each getter sets its counter from the current
    time just before returning it.
    """

    SERVICE_DEFINITION = ServiceInfo(
        service_id="urn:upnp-org:serviceId:WANCommonInterfaceConfig1",
        service_type="urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1",
        control_url="/upnp/control/WANCommonInterfaceConfig1",
        event_sub_url="/upnp/event/WANCommonInterfaceConfig1",
        scpd_url="/WANCommonInterfaceConfig_1.xml",
        xml=ET.Element("server_service"),
    )

    STATE_VARIABLE_DEFINITIONS = {
        "TotalBytesReceived": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value="0",
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "TotalBytesSent": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value="0",
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "TotalPacketsReceived": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value="0",
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
        "TotalPacketsSent": StateVariableTypeInfo(
            data_type="ui4",
            data_type_mapping=STATE_VARIABLE_TYPE_MAPPING["ui4"],
            default_value="0",
            allowed_value_range={},
            allowed_values=None,
            xml=ET.Element("server_stateVariable"),
        ),
    }

    # The counters are ui4, so values wrap around at 2**32.
    MAX_COUNTER = 2**32

    def _update_bytes(self, state_var_name: str) -> None:
        """Set a byte counter to milliseconds since the epoch, wrapped to MAX_COUNTER."""
        new_bytes = int(time() * 1000) % self.MAX_COUNTER
        self.state_variable(state_var_name).value = new_bytes

    def _update_packets(self, state_var_name: str) -> None:
        """Set a packet counter to seconds since the epoch, wrapped to MAX_COUNTER."""
        new_packets = int(time()) % self.MAX_COUNTER
        self.state_variable(state_var_name).value = new_packets

    @callable_action(
        name="GetTotalBytesReceived",
        in_args={},
        out_args={
            "NewTotalBytesReceived": "TotalBytesReceived",
        },
    )
    async def get_total_bytes_received(self) -> Dict[str, UpnpStateVariable]:
        """Get total bytes received."""
        self._update_bytes("TotalBytesReceived")
        return {
            "NewTotalBytesReceived": self.state_variable("TotalBytesReceived"),
        }

    @callable_action(
        name="GetTotalBytesSent",
        in_args={},
        out_args={
            "NewTotalBytesSent": "TotalBytesSent",
        },
    )
    async def get_total_bytes_sent(self) -> Dict[str, UpnpStateVariable]:
        """Get total bytes sent."""
        self._update_bytes("TotalBytesSent")
        return {
            "NewTotalBytesSent": self.state_variable("TotalBytesSent"),
        }

    @callable_action(
        name="GetTotalPacketsReceived",
        in_args={},
        out_args={
            "NewTotalPacketsReceived": "TotalPacketsReceived",
        },
    )
    async def get_total_packets_received(self) -> Dict[str, UpnpStateVariable]:
        """Get total packets received."""
        self._update_packets("TotalPacketsReceived")
        return {
            "NewTotalPacketsReceived": self.state_variable("TotalPacketsReceived"),
        }

    @callable_action(
        name="GetTotalPacketsSent",
        in_args={},
        out_args={
            "NewTotalPacketsSent": "TotalPacketsSent",
        },
    )
    async def get_total_packets_sent(self) -> Dict[str, UpnpStateVariable]:
        """Get total packets sent."""
        self._update_packets("TotalPacketsSent")
        return {
            "NewTotalPacketsSent": self.state_variable("TotalPacketsSent"),
        }


class WanDevice(UpnpServerDevice):
    """WANDevice: hosts WANCommonInterfaceConfig and embeds a WanConnectionDevice."""

    DEVICE_DEFINITION = DeviceInfo(
        device_type="urn:schemas-upnp-org:device:WANDevice:1",
        udn="uuid:51e00c19-c8f3-4b28-9ef1-7f562f204c81",
        friendly_name="Dummy Router WAN Device",
        manufacturer="Steven",
        manufacturer_url=None,
        model_description="Dummy Router IGD",
        model_name="DummyRouter v1",
        model_number="v0.0.1",
        model_url=None,
        serial_number="0000001",
        upc=None,
        presentation_url=None,
        url="/device.xml",
        icons=[],
        xml=ET.Element("server_device"),
    )
    EMBEDDED_DEVICES = [WanConnectionDevice]
    SERVICES = [WANCommonInterfaceConfigService]

    def __init__(self, requester: UpnpRequester, base_uri: str, boot_id: int, config_id: int) -> None:
        """Pass every construction argument, by keyword, on to ``UpnpServerDevice``."""
        super().__init__(
            requester=requester, base_uri=base_uri, boot_id=boot_id, config_id=config_id
        )


class Layer3ForwardingService(UpnpServerService):
    """Layer3Forwarding service.

    A placeholder: it declares no state variables and no actions.
    """

    SERVICE_DEFINITION = ServiceInfo(
        service_type="urn:schemas-upnp-org:service:Layer3Forwarding:1",
        service_id="urn:upnp-org:serviceId:Layer3Forwarding1",
        scpd_url="/Layer3Forwarding_1.xml",
        control_url="/upnp/control/Layer3Forwarding1",
        event_sub_url="/upnp/event/Layer3Forwarding1",
        xml=ET.Element("server_service"),
    )

    # Intentionally empty; annotated so type checkers accept the empty literal.
    STATE_VARIABLE_DEFINITIONS: Mapping[str, StateVariableTypeInfo] = {}


class IgdDevice(UpnpServerDevice):
    """Root InternetGatewayDevice: hosts Layer3Forwarding and embeds a WanDevice."""

    DEVICE_DEFINITION = DeviceInfo(
        device_type="urn:schemas-upnp-org:device:InternetGatewayDevice:1",
        udn="uuid:51e00c19-c8f3-4b28-9ef1-7f562f204c80",
        friendly_name="Dummy Router",
        manufacturer="Steven",
        manufacturer_url=None,
        model_description="Dummy Router IGD",
        model_name="DummyRouter v1",
        model_number="v0.0.1",
        model_url=None,
        serial_number="0000001",
        upc=None,
        presentation_url=None,
        url="/device.xml",
        icons=[],
        xml=ET.Element("server_device"),
    )
    EMBEDDED_DEVICES = [WanDevice]
    SERVICES = [Layer3ForwardingService]

    def __init__(self, requester: UpnpRequester, base_uri: str, boot_id: int, config_id: int) -> None:
        """Pass every construction argument, by keyword, on to ``UpnpServerDevice``."""
        super().__init__(
            requester=requester, base_uri=base_uri, boot_id=boot_id, config_id=config_id
        )


async def async_main(server: UpnpServer) -> None:
    """Start ``server``, then periodically change evented state to exercise eventing.

    Every 30 seconds the ExternalIPAddress state variable of the WANIPConnection
    service cycles through 1.2.3.1 .. 1.2.3.255. Runs until cancelled.
    """
    await server.async_start()

    # The device tree is presumably static once built, so look the service and
    # state variable up once instead of on every iteration.
    # pylint: disable=protected-access
    upnp_service = server._device.find_service("urn:schemas-upnp-org:service:WANIPConnection:1")
    wanipc_service = cast(WANIPConnectionService, upnp_service)
    external_ip_address_var = wanipc_service.state_variable("ExternalIPAddress")

    # PortMappingNumberOfEntries is deliberately NOT touched here: it has to
    # reflect the mappings actually added/deleted through the service, and
    # overwriting it with an unrelated value made it disagree with that table.
    loop_no = 0
    while True:
        external_ip_address_var.value = f"1.2.3.{(loop_no % 255) + 1}"
        await asyncio.sleep(30)
        loop_no += 1

async def async_stop(server: UpnpServer) -> None:
    """Stop ``server``.

    Await this on the event loop that ran ``server.async_start()``; the
    server's transports are presumably bound to that loop.
    """
    # The former trailing ``asyncio.get_event_loop().run_until_complete()``
    # was removed: it lacked its required argument and cannot be called on a
    # loop that is already running this coroutine.
    await server.async_stop()


if __name__ == "__main__":
    boot_id = int(time())
    config_id = 1
    server = UpnpServer(IgdDevice, SOURCE, http_port=HTTP_PORT, boot_id=boot_id, config_id=config_id)

    async def _async_run() -> None:
        """Run the dummy router; always stop the server on the same event loop."""
        try:
            await async_main(server)
        finally:
            # Runs on Ctrl-C too: asyncio.run cancels the main task, so this
            # executes inside the loop that owns the server's transports.
            await server.async_stop()

    try:
        asyncio.run(_async_run())
    except KeyboardInterrupt:
        LOGGER.info("Interrupted, server stopped")
