1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
"""Test registry module."""
from collections.abc import Callable, Generator
import threading
from typing import Any
import pytest
from asusrouter.registry import ARCallableRegistry as ARCallReg
@pytest.fixture(autouse=True)
def clear_registry() -> Generator[None, None, None]:
    """Wipe the shared callable registry before and after every test.

    Registrations are made through the ``ARCallReg`` class itself, so state
    left behind by one test would otherwise be visible to the next.
    """
    ARCallReg.clear()
    try:
        yield
    finally:
        # Teardown: drop whatever the test registered.
        ARCallReg.clear()
def test_register_and_get_callable_by_instance_and_class() -> None:
    """A callable registered on a base class is found via subclass instances and the subclass."""

    class Base:
        pass

    class Child(Base):
        pass

    def base_get(s: Any) -> str:
        return "base"

    # Attach the callable to the parent class using the keyword-argument API.
    ARCallReg.register(Base, get_state=base_get)

    # Resolving from an instance of a subclass should walk up to Base.
    resolved = ARCallReg.get_callable(Child(), "get_state")
    assert resolved is base_get
    assert resolved is not None  # narrows the Optional for type checkers
    assert resolved(Child()) == "base"

    # Resolving from the subclass object itself (not an instance) works too.
    assert ARCallReg.get_callable(Child, "get_state") is base_get
def test_get_all_for_merges_mro_correctly() -> None:
    """``get_all_for`` combines entries along the MRO, nearest class winning."""

    class A:
        pass

    class B(A):
        pass

    def a_get(s: Any) -> str:
        return "a"

    def b_get(s: Any) -> str:
        return "b"

    def a_set(s: Any, v: Any) -> tuple[str, Any]:
        return ("a-set", v)

    ARCallReg.register(A, get_state=a_get, set_state=a_set)
    ARCallReg.register(B, get_state=b_get)

    combined = ARCallReg.get_all_for(B())

    # "get_state" is overridden on B; "set_state" exists only on A and must
    # still be reachable from a B instance.
    expected = {"get_state": b_get, "set_state": a_set}
    for name, func in expected.items():
        assert combined[name] is func
def test_unregister_and_clear() -> None:
    """Both ``unregister`` and ``clear`` remove previously registered callables."""

    class S:
        pass

    def fn(s: Any) -> str:
        return "ok"

    def lookup() -> Any:
        # Fresh instance on every lookup, matching how callers resolve.
        return ARCallReg.get_callable(S(), "get_state")

    ARCallReg.register(S, get_state=fn)
    assert lookup() is fn

    # Dropping the class entry makes the callable unreachable.
    ARCallReg.unregister(S)
    assert lookup() is None

    # A new registration is likewise wiped out by a full clear().
    ARCallReg.register(S, get_state=fn)
    ARCallReg.clear()
    assert lookup() is None
def test_get_callable_returns_none_when_missing() -> None:
    """Lookups for a class with no registrations yield ``None`` / an empty mapping."""

    class X:
        pass

    # Nothing is registered for X: a single-name lookup finds nothing ...
    missing = ARCallReg.get_callable(X(), "nope")
    assert missing is None

    # ... and the merged view for an instance is empty.
    everything = ARCallReg.get_all_for(X())
    assert everything == {}
def test_concurrent_registers_are_thread_safe() -> None:
    """Test that concurrent registrations are thread-safe.

    Several threads register distinct callables on the same class at the
    same moment; every registration must survive and each name must map to
    the function that was registered under it.
    """
    num_threads = 8

    class Root:
        pass

    # Release all workers together so their register() calls really overlap.
    # Without a gate, each short-lived thread tends to finish before the next
    # one starts, and the test never exercises any contention.
    start_gate = threading.Barrier(num_threads)

    def make_fn(i: int) -> Callable[[Any], int]:
        """Create a function that returns the given integer."""

        def f(s: Any) -> int:
            """Return the given integer."""
            return i

        return f

    def worker(i: int) -> None:
        """Wait for all workers, then register ``fn{i}`` on ``Root``."""
        # The timeout keeps a stuck thread from hanging the suite: the barrier
        # breaks, the registration is skipped and the assertions below fail.
        start_gate.wait(timeout=5)
        ARCallReg.register(Root, **{f"fn{i}": make_fn(i)})

    threads = [
        threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        # Bounded join so a deadlock inside register() fails the test
        # instead of blocking the whole run forever.
        t.join(timeout=10)
        assert not t.is_alive()

    all_map = ARCallReg.get_all_for(Root())
    # Every name must be present and bound to its own function, not to a
    # callable registered by another thread.
    for i in range(num_threads):
        key = f"fn{i}"
        assert key in all_map
        assert callable(all_map[key])
        assert all_map[key](None) == i
|