1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
"""Multidispatch for functions and methods.
This code is a Python 3, slimmed down version of the
generic package by Andrey Popp.
Only the generic function code is left intact -- no generic methods.
The interface has been made in line with `functools.singledispatch`.
Note that this module does not support annotated functions.
"""
from __future__ import annotations
import functools
import inspect
import logging
from typing import Any, Callable, Generic, TypeVar, Union, cast
from generic.registry import Registry, TypeAxis
# Public API of this module.  ``__all__`` has to be a sequence of names: a
# bare string is iterated character by character by ``from ... import *``.
__all__ = ["multidispatch"]

# Type of a dispatch rule: any callable or a class.
T = TypeVar("T", bound=Union[Callable[..., Any], type])

# A single dispatch key: a type, or ``None``.
KeyType = Union[type, None]

logger = logging.getLogger(__name__)
def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
    """Declare function as multidispatch.

    This decorator takes ``argtypes`` argument types and replaces the
    decorated function with a :class:`.FunctionDispatcher` object, which
    is responsible for the multiple dispatch feature.

    When no ``argtypes`` are given, the decorated callable is registered
    for ``object`` on every positional parameter that has no default value.

    :raises TypeError: if the dispatcher cannot be created or the decorated
        callable cannot be registered for the resolved argument types.
    """

    def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
        argspec = inspect.getfullargspec(func)

        # Resolve the types into a local name instead of rebinding the
        # enclosing ``argtypes``: the returned decorator may be applied to
        # more than one callable, and types inferred for one callable must
        # not leak into the next application.
        rule_types: tuple[KeyType, ...] = argtypes
        if not rule_types:
            arity = _arity(argspec)
            if isinstance(func, type):
                # It's a class we deal with: one positional slot is dropped
                # (presumably the ``self`` of the constructor).
                arity -= 1
            rule_types = (object,) * arity

        dispatcher = cast(
            FunctionDispatcher[T],
            functools.update_wrapper(
                FunctionDispatcher(argspec, len(rule_types)), func
            ),
        )
        # The decorated callable itself is the first registered rule.
        dispatcher.register_rule(func, *rule_types)
        return dispatcher

    return _replace_with_dispatcher
class FunctionDispatcher(Generic[T]):
    """Multidispatcher for functions.

    This object dispatches calls to a function by its argument types.
    Usually it is produced by the :func:`.multidispatch` decorator.

    You should not manually create objects of this type.
    """

    registry: Registry[T]

    def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
        """Initialize the dispatcher.

        ``argspec`` is the :class:`inspect.FullArgSpec` that every rule has
        to conform to, ``params_arity`` is the number of leading positional
        arguments the dispatcher dispatches on.

        Raise TypeError if ``argspec`` has fewer positional arguments
        without a default than ``params_arity``.
        """
        # Check if we have enough positional arguments for number of type params
        if _arity(argspec) < params_arity:
            raise TypeError(
                "Not enough positional arguments "
                "for number of type parameters provided."
            )

        self.argspec = argspec
        self.params_arity = params_arity

        # One registry axis per dispatched argument.
        axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
        self.registry = Registry(*axis)

    @staticmethod
    def _spec_shape(argspec: inspect.FullArgSpec) -> tuple[int, int, int, int]:
        """Summarize ``argspec`` for conformity checks.

        Return ``(number of args, has *varargs, has **varkw, number of
        defaults)``.  The variadic entries are presence flags, so the
        *names* chosen for ``*args`` / ``**kwargs`` do not influence the
        comparison.
        """
        return (
            len(argspec.args or ()),
            int(argspec.varargs is not None),
            int(argspec.varkw is not None),
            len(argspec.defaults or ()),
        )

    def check_rule(self, rule: T, *argtypes: KeyType) -> None:
        """Check if the argument types match wrt number of arguments.

        Raise TypeError in case of failure: either the number of
        ``argtypes`` differs from the dispatcher's arity, or the signature
        of ``rule`` has a different shape than the dispatcher's argspec.
        """
        # Check if we have the right number of parametrized types
        if len(argtypes) != self.params_arity:
            raise TypeError(
                f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
            )

        # Check if we have the same argspec (by number of args)
        left_spec = self._spec_shape(inspect.getfullargspec(rule))
        right_spec = self._spec_shape(self.argspec)
        if left_spec != right_spec:
            raise TypeError(
                f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
            )

    def register_rule(self, rule: T, *argtypes: KeyType) -> None:
        """Register new ``rule`` for ``argtypes``.

        The rule is validated with :meth:`check_rule` first, so a TypeError
        is raised for non-conforming rules before the registry is touched.
        """
        self.check_rule(rule, *argtypes)
        self.registry.register(rule, *argtypes)

    def register(self, *argtypes: KeyType) -> Callable[[T], T]:
        """Decorator for registering a new case for multidispatch.

        The new case will be registered for the types identified by
        ``argtypes``.  The length of ``argtypes`` should be equal to the
        length of the ``argtypes`` that were passed to the corresponding
        :func:`.multidispatch` call, which also indicates the number of
        arguments multidispatch dispatches on.

        The decorated function is returned unchanged.
        """

        def _register(func: T) -> T:
            """Register ``func`` as a rule and hand it back."""
            self.register_rule(func, *argtypes)
            return func

        return _register

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Dispatch call to appropriate rule.

        Only the first ``params_arity`` positional arguments take part in
        the lookup; the selected rule is then called with all positional
        and keyword arguments.

        Raise TypeError if the registry yields no rule for the arguments.
        """
        trimmed_args = args[: self.params_arity]
        rule = self.registry.lookup(*trimmed_args)
        if not rule:
            # Log the registry's internal tree to help diagnose the miss.
            logger.debug(self.registry._tree)
            raise TypeError(f"No available rule found for {trimmed_args!r}")
        return rule(*args, **kwargs)
def _arity(argspec: inspect.FullArgSpec) -> int:
    """Determine the positional arity of ``argspec``.

    The result is the number of entries in ``argspec.args`` minus the
    number of entries in ``argspec.defaults``; a missing or empty value
    counts as zero.
    """
    positional = argspec.args
    defaulted = argspec.defaults
    total = len(positional) if positional else 0
    with_default = len(defaulted) if defaulted else 0
    return total - with_default
|