1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
import inspect
from typing import Callable, Optional, Type
from mypy import nodes
from mypy.plugin import DynamicClassDefContext, Plugin
from mypy.plugins import dataclasses
import marshmallow_dataclass
_NEW_TYPE_SIG = inspect.signature(marshmallow_dataclass.NewType)
def plugin(version: str) -> Type[Plugin]:
return MarshmallowDataclassPlugin
class MarshmallowDataclassPlugin(Plugin):
def get_dynamic_class_hook(
self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
if fullname == "marshmallow_dataclass.NewType":
return new_type_hook
return None
def get_class_decorator_hook(self, fullname: str):
if fullname == "marshmallow_dataclass.dataclass":
return dataclasses.dataclass_class_maker_callback
return None
def new_type_hook(ctx: DynamicClassDefContext) -> None:
"""
Dynamic class hook for :func:`marshmallow_dataclass.NewType`.
Uses the type of the ``typ`` argument.
"""
typ = _get_arg_by_name(ctx.call, "typ", _NEW_TYPE_SIG)
if not isinstance(typ, nodes.RefExpr):
return
info = typ.node
if not isinstance(info, nodes.TypeInfo):
return
ctx.api.add_symbol_table_node(ctx.name, nodes.SymbolTableNode(nodes.GDEF, info))
def _get_arg_by_name(
call: nodes.CallExpr, name: str, sig: inspect.Signature
) -> Optional[nodes.Expression]:
"""
Get value of argument from a call.
:return: The argument value, or ``None`` if it cannot be found.
.. warning::
This probably doesn't yet work for calls with ``*args`` and/or ``*kwargs``.
"""
args = []
kwargs = {}
for arg_name, arg_value in zip(call.arg_names, call.args):
if arg_name is None:
args.append(arg_value)
else:
kwargs[arg_name] = arg_value
try:
bound_args = sig.bind(*args, **kwargs)
except TypeError:
return None
try:
return bound_args.arguments[name]
except KeyError:
return None
|