File: state.py

package info (click to toggle)
python-asusrouter 1.21.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,856 kB
  • sloc: python: 20,497; makefile: 3
file content (226 lines) | stat: -rw-r--r-- 6,239 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""State module."""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from enum import Enum
import importlib
import logging
from types import ModuleType
from typing import Any

from asusrouter.modules.aura import AsusAura
from asusrouter.modules.connection import ConnectionState
from asusrouter.modules.data import AsusData, AsusDataState
from asusrouter.modules.ddns import AsusDDNS
from asusrouter.modules.parental_control import (
    AsusBlockAll,
    AsusParentalControl,
    ParentalControlRule,
)
from asusrouter.modules.port_forwarding import AsusPortForwarding
from asusrouter.modules.system import AsusSystem
from asusrouter.modules.vpnc import AsusVPNC
from asusrouter.modules.wireguard import (
    AsusWireGuardClient,
    AsusWireGuardServer,
)
from asusrouter.modules.wlan import AsusWLAN
from asusrouter.tools.converters import get_enum_key_by_value

from .led import AsusLED
from .openvpn import AsusOVPNClient, AsusOVPNServer

_LOGGER = logging.getLogger(__name__)


class AsusStateNone(int, Enum):
    """Placeholder state type holding a single ``NONE`` member (value ``0``).

    It is registered as the value of ``AsusState.NONE``.
    """

    NONE = 0


# AsusState = Union[AsusLED, AsusOVPNClient, AsusOVPNServer, AsusStateNone]


class AsusState(Enum):
    """Asus state.

    Registry of the supported kinds of state. Each member's value is the
    class whose instances represent that kind of state. ``get_datatype``
    identifies a state value by looking up ``type(state)`` among these
    values, so member values must stay distinct classes.
    """

    # Fallback member: ``get_datatype`` uses it as the default when the type
    # of a state value is not registered here.
    NONE = AsusStateNone
    AURA = AsusAura
    BLOCK_ALL = AsusBlockAll
    CONNECTION = ConnectionState
    DDNS = AsusDDNS
    LED = AsusLED
    OPENVPN_CLIENT = AsusOVPNClient
    OPENVPN_SERVER = AsusOVPNServer
    PARENTAL_CONTROL = AsusParentalControl
    PC_RULE = ParentalControlRule
    PORT_FORWARDING = AsusPortForwarding
    SYSTEM = AsusSystem
    VPNC = AsusVPNC
    WIREGUARD_CLIENT = AsusWireGuardClient
    WIREGUARD_SERVER = AsusWireGuardServer
    WLAN = AsusWLAN


# Maps each kind of state to its ``AsusData`` category.
# - ``save_state`` uses the category to pick which ``library`` entry to update.
# - ``_get_module`` turns the category's value into a module name under
#   ``asusrouter.modules`` (after stripping a ``_client``/``_server`` suffix).
# - ``None`` means "no data category": no handler module is looked up and
#   nothing is saved.
# Entries can be added or overridden at runtime via ``add_conditional_state``.
AsusStateMap: dict[AsusState, AsusData | None] = {
    AsusState.NONE: None,
    AsusState.AURA: AsusData.AURA,
    # Block-all, rules and the feature itself all share parental control data
    AsusState.BLOCK_ALL: AsusData.PARENTAL_CONTROL,
    AsusState.CONNECTION: None,
    AsusState.DDNS: None,
    AsusState.LED: AsusData.LED,
    AsusState.OPENVPN_CLIENT: AsusData.OPENVPN_CLIENT,
    AsusState.OPENVPN_SERVER: AsusData.OPENVPN_SERVER,
    AsusState.PARENTAL_CONTROL: AsusData.PARENTAL_CONTROL,
    AsusState.PC_RULE: AsusData.PARENTAL_CONTROL,
    AsusState.PORT_FORWARDING: AsusData.PORT_FORWARDING,
    AsusState.SYSTEM: AsusData.SYSTEM,
    AsusState.VPNC: AsusData.VPNC,
    AsusState.WIREGUARD_CLIENT: AsusData.WIREGUARD_CLIENT,
    AsusState.WIREGUARD_SERVER: AsusData.WIREGUARD_SERVER,
    AsusState.WLAN: AsusData.WLAN,
}


def add_conditional_state(state: AsusState, data: AsusData) -> None:
    """Register, or override, the ``AsusData`` mapped to ``state``.

    Both arguments must be members of their respective enums. Anything else
    is logged at debug level and ignored, leaving ``AsusStateMap`` untouched.
    """

    if isinstance(state, AsusState) and isinstance(data, AsusData):
        AsusStateMap[state] = data
        _LOGGER.debug("Added conditional state rule: %s -> %s", state, data)
    else:
        _LOGGER.debug("Invalid state or data type: %s -> %s", state, data)


def get_datatype(state: Any | None) -> AsusData | None:
    """Resolve the ``AsusData`` category for a state value.

    The lookup is keyed on the *type* of ``state``. That type is matched
    against the values of ``AsusState`` (with ``AsusState.NONE`` passed as
    the default), and the resulting member is looked up in ``AsusStateMap``.

    Returns:
        The mapped ``AsusData``, or ``None`` if the matched member maps to
        ``None`` or is missing from the map.
    """

    state_class = type(state)
    matched_state = get_enum_key_by_value(
        AsusState, state_class, default=AsusState.NONE
    )
    return AsusStateMap.get(matched_state)


def _get_module_name(state: AsusState) -> str | None:
    """Return the raw module name (the ``AsusData`` value) for ``state``.

    ``None`` is returned when ``get_datatype`` yields no (truthy) datatype.
    """

    datatype = get_datatype(state)
    return datatype.value if datatype else None


def _get_module(state: AsusState) -> ModuleType | None:
    """Import the ``asusrouter.modules`` submodule that handles ``state``.

    The module name is the value of the state's ``AsusData`` datatype. A
    trailing ``_client`` or ``_server`` is removed, so client and server
    datatypes resolve to the same module.

    Returns:
        The imported module. ``None`` is returned when the state has no
        datatype, or when the import raises ``ModuleNotFoundError``.
    """

    # Module name
    module_name = _get_module_name(state)
    if not module_name:
        return None

    # Client/server variants share one module: strip at most one suffix
    for suffix in ("_client", "_server"):
        if module_name.endswith(suffix):
            module_name = module_name[: -len(suffix)]
            break

    # Module path
    module_path = f"asusrouter.modules.{module_name}"

    try:
        # Import and return the module
        return importlib.import_module(module_path)
    except ModuleNotFoundError as err:
        # ``err.name`` is the module the import system could not find. Only
        # when that is the target itself is the module really absent.
        # Otherwise the target was found but one of its own imports is
        # missing, and reporting "no module" would hide the real cause.
        if err.name == module_path:
            _LOGGER.debug("No module found for state %s", state)
        else:
            _LOGGER.debug(
                "Module %s for state %s could not import its dependency %s",
                module_path,
                state,
                err.name,
            )
        return None


def _has_method(module: ModuleType, method: str) -> bool:
    """Check if the module has the method."""

    return hasattr(module, method) and callable(getattr(module, method))


async def set_state(
    callback: Callable[..., Awaitable[bool]],
    state: AsusState,
    **kwargs: Any,
) -> bool:
    """Apply ``state`` through the submodule responsible for it.

    The submodule is resolved with ``_get_module``. If it defines a callable
    ``set_state``, that coroutine is awaited with ``callback``, ``state`` and
    the forwarded keyword arguments, and its result is returned.

    Modules flagged with ``REQUIRE_STATE`` or ``REQUIRE_IDENTITY`` also
    receive ``extra_param``, taken from ``router_state`` or ``identity``
    respectively. When both flags are set, ``identity`` wins.

    Returns:
        The submodule's result, or ``False`` if no module or no callable
        ``set_state`` was found.
    """

    submodule = _get_module(state)

    if submodule is None:
        # Prefer the readable ``EnumClass.MEMBER`` form when possible
        if isinstance(state, Enum):
            _LOGGER.debug(
                "No module found for state %s.%s",
                type(state).__name__,
                state.name,
            )
        else:
            _LOGGER.debug("No module found for state %r", state)
        return False

    if not _has_method(submodule, "set_state"):
        return False

    # Checked in order; a later matching flag overrides an earlier one
    extra_sources = (
        ("REQUIRE_STATE", "router_state"),
        ("REQUIRE_IDENTITY", "identity"),
    )
    for flag, source_key in extra_sources:
        if getattr(submodule, flag, False):
            kwargs["extra_param"] = kwargs.get(source_key)

    return await submodule.set_state(
        callback=callback,
        state=state,
        **kwargs,
    )


def save_state(
    state: AsusState,
    library: dict[AsusData, AsusDataState],
    needed_time: int | None = None,
    last_id: int | None = None,
) -> None:
    """Record ``state`` in the ``library`` entry for its datatype.

    The entry is selected with ``get_datatype(state)``. If there is no
    datatype, or ``library`` has no entry for it, nothing happens. Otherwise
    the entry's ``update_state(state, last_id)`` is called, followed by
    ``offset_time(needed_time)``.
    """

    datatype = get_datatype(state)
    if datatype is not None and datatype in library:
        entry = library[datatype]
        entry.update_state(state, last_id)
        entry.offset_time(needed_time)


async def keep_state(
    callback: Callable[..., Awaitable[Any]],
    states: AsusState | list[AsusState] | None,
    **kwargs: Any,
) -> None:
    """Keep the state.

    For each state, the handling submodule is resolved with ``_get_module``.
    If that module exposes a callable ``keep_state``, the call
    ``submodule.keep_state(callback, state, **kwargs)`` is awaited. States
    are processed one after another, in order. States with no module, or
    whose module has no such function, are skipped.

    Args:
        callback: Async callable forwarded to each submodule.
        states: A single state, a list of states, or ``None`` (no-op).
        **kwargs: Extra keyword arguments forwarded to each submodule.
    """

    if states is None:
        return

    # Make sure the state is a list
    if not isinstance(states, list):
        states = [states]

    # Await each call as soon as it is created. Creating every coroutine up
    # front and awaiting them afterwards means that if one raises, the
    # remaining coroutines are never awaited. They never run, and Python
    # emits "coroutine ... was never awaited" warnings.
    for state in states:
        submodule = _get_module(state)
        if submodule and _has_method(submodule, "keep_state"):
            await submodule.keep_state(callback, state, **kwargs)