1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import ast
import pprint
from typing import Any
import griffe
from griffe import Inspector, ObjectNode, Visitor, Attribute
from amqtt.contexts import default_listeners, default_broker_plugins, default_client_plugins
from amqtt.contrib.auth_db.plugin import default_hash_scheme
default_factory_map = {
'default_listeners': default_listeners(),
'default_broker_plugins': default_broker_plugins(),
'default_client_plugins': default_client_plugins(),
'default_hash_scheme': default_hash_scheme()
}
def get_qualified_name(node: ast.AST) -> str | None:
"""Recursively build the qualified name from an AST node."""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
parent = get_qualified_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
elif isinstance(node, ast.Call):
# e.g., uuid.uuid4()
return get_qualified_name(node.func)
return None
def get_fully_qualified_name(call_node):
"""
Extracts the fully qualified name from an ast.Call node.
"""
if isinstance(call_node.func, ast.Name):
# Direct function call (e.g., "my_function(arg)")
return call_node.func.id
elif isinstance(call_node.func, ast.Attribute):
# Method call or qualified name (e.g., "obj.method(arg)" or "module.submodule.function(arg)")
parts = []
current = call_node.func
while isinstance(current, ast.Attribute):
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
return ".".join(reversed(parts))
else:
# Handle other potential cases (e.g., ast.Subscript) if necessary
return None
def get_callable_name(node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return f"{get_callable_name(node.value)}.{node.attr}"
return None
def evaluate_callable_node(node):
try:
# Wrap the node in an Expression so it can be compiled
expr = ast.Expression(body=node)
compiled = compile(expr, filename="<ast>", mode="eval")
return eval(compiled, {"__builtins__": __builtins__, "list": list, "dict": dict})
except Exception as e:
return f"<unresolvable: {e}>"
class DataclassDefaultFactoryExtension(griffe.Extension):
"""Renders the output of a dataclasses field which uses a default factory.
def other_field_defaults():
return {'item1': 'value1', 'item2': 'value2'}
@dataclass
class MyDataClass:
my_field: dict[str, Any] = field(default_factory=dict)
my_other_field: dict[str, Any] = field(default_factory=other_field_defaults)
instead of documentation rendering this as:
```
class MyDataClass:
my_field: dict[str, Any] = dict()
my_other_field: dict[str, Any] = other_field_defaults()
```
it will be displayed with the output of factory functions for more clarity:
```
class MyDataClass:
my_field: dict[str, Any] = {}
my_other_field: dict[str, Any] = {'item1': 'value1', 'item2': 'value2'}
```
_note_ : for any custom default factory function, it must be added to the `default_factory_map`
in this file as `griffe` doesn't provide a straightforward mechanism with its AST to dynamically
import/call the function.
"""
def on_attribute_instance(
self,
*,
node: ast.AST | ObjectNode,
attr: Attribute,
agent: Visitor | Inspector,
**kwargs: Any,
) -> None:
"""Called for every `node` and/or `attr` on a file's AST."""
if not hasattr(node, "value"):
return
if isinstance(node.value, ast.Call):
# Search for all of the `default_factory` fields.
default_factory_value: str | None = None
for kw in node.value.keywords:
if kw.arg == "default_factory":
# based on the node type, return the proper function name
match get_callable_name(kw.value):
# `dict` and `list` are common default factory functions
case 'dict':
default_factory_value = "{}"
case 'list':
default_factory_value = "[]"
case _:
# otherwise, see the nodes is in our map for the custom default factory function
callable_name = get_callable_name(kw.value)
if callable_name in default_factory_map:
default_factory_value = pprint.pformat(default_factory_map[callable_name], indent=4, width=80, sort_dicts=False)
else:
# if not, display as the default
default_factory_value = f"{callable_name}()"
# store the information in the griffe attribute, which is what is passed to the template for rendering
if "dataclass_ext" not in attr.extra:
attr.extra["dataclass_ext"] = {}
attr.extra["dataclass_ext"]["has_default_factory"] = False
if default_factory_value is not None:
attr.extra["dataclass_ext"]["has_default_factory"] = True
attr.extra["dataclass_ext"]["default_factory"] = default_factory_value
|