1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import textwrap
from contextlib import ExitStack
from types import MappingProxyType
from typing import (
Any,
cast,
Collection,
Iterable,
Mapping,
MutableMapping,
MutableSet,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
)
from libcst._batched_visitor import BatchableCSTVisitor, visit_batched, VisitorMethod
from libcst._exceptions import MetadataException
from libcst.metadata.base_provider import BatchableMetadataProvider
if TYPE_CHECKING:
from libcst._nodes.base import CSTNode # noqa: F401
from libcst._nodes.module import Module # noqa: F401
from libcst._visitors import CSTVisitorT # noqa: F401
from libcst.metadata.base_provider import ( # noqa: F401
BaseMetadataProvider,
ProviderT,
)
# Type of the per-node metadata value produced by a provider; it links the
# provider class passed to ``MetadataWrapper.resolve`` to the value type of the
# mapping that method returns.
_T = TypeVar("_T")
def _gen_batchable(
    wrapper: "MetadataWrapper",
    # pyre-fixme[2]: Parameter `providers` must have a type that does not contain `Any`
    providers: Iterable[BatchableMetadataProvider[Any]],
) -> Mapping["ProviderT", Mapping["CSTNode", object]]:
    """
    Returns map of metadata mappings from resolving ``providers`` on ``wrapper``.

    All providers are run together through ``wrapper.visit_batched``. Afterwards
    each provider's ``_computed`` results are copied into a read-only mapping
    keyed by the provider's class.

    :param wrapper: The wrapper whose ``visit_batched`` is used to run the providers.
    :param providers: Already-instantiated batchable providers to run. Any
        iterable is accepted, including one-shot iterators.
    :returns: A dict from provider class to a read-only snapshot of the metadata
        that provider computed.
    """
    # ``providers`` is walked twice (once by the batched visit, once to collect
    # the results), so materialize it first. A generator would otherwise be
    # exhausted by the visit and produce an empty result map.
    instances = list(providers)
    wrapper.visit_batched(instances)
    # Make immutable metadata mapping
    # pyre-ignore[7]
    return {type(p): MappingProxyType(dict(p._computed)) for p in instances}
def _gather_providers(
    providers: Collection["ProviderT"], gathered: MutableSet["ProviderT"]
) -> MutableSet["ProviderT"]:
    """
    Collects ``providers`` and, transitively, everything they list in
    ``METADATA_DEPENDENCIES`` into ``gathered``.

    ``gathered`` is updated in place and is also the return value. A provider that
    is already present in ``gathered`` is skipped and its dependencies are not
    walked again, which also keeps dependency cycles from looping forever.
    """
    # Depth-first walk with an explicit stack of iterators: the top of the stack
    # is the dependency list currently being scanned.
    pending = [iter(providers)]
    while pending:
        for provider in pending[-1]:
            if provider in gathered:
                continue
            gathered.add(provider)
            # Descend into this provider's dependencies before resuming the
            # current list.
            pending.append(iter(provider.METADATA_DEPENDENCIES))
            break
        else:
            # Current list is exhausted; go back to the parent list.
            pending.pop()
    return gathered
def _resolve_impl(
    wrapper: "MetadataWrapper", providers: Collection["ProviderT"]
) -> None:
    """
    Updates the _metadata map on wrapper with metadata from the given providers
    as well as their dependencies.

    Resolution proceeds in passes. In each pass every provider whose
    dependencies are all completed is either run immediately (non-batchable
    providers) or collected and run together in one batched traversal
    (subclasses of ``BatchableMetadataProvider``). Providers whose ``gen_cache``
    is truthy are constructed with their entry from ``wrapper._cache``.

    :param wrapper: The wrapper whose ``_metadata`` map is filled in.
    :param providers: The provider classes to resolve. Providers already present
        in ``wrapper._metadata`` are not recomputed.
    :raises MetadataException: If a pass cannot resolve any provider, which
        means the remaining providers depend on each other circularly.
    """
    completed = set(wrapper._metadata.keys())
    remaining = _gather_providers(set(providers), set()) - completed

    while remaining:
        batchable = set()
        # Tracks whether a non-batchable provider was resolved in this pass.
        resolved_directly = False

        for P in remaining:
            if not set(P.METADATA_DEPENDENCIES).issubset(completed):
                continue
            if issubclass(P, BatchableMetadataProvider):
                batchable.add(P)
            else:
                provider = P(wrapper._cache.get(P)) if P.gen_cache else P()
                wrapper._metadata[P] = provider._gen(wrapper)
                completed.add(P)
                resolved_directly = True

        if batchable:
            initialized_batchable = [
                p(wrapper._cache.get(p)) if p.gen_cache else p() for p in batchable
            ]
            metadata_batch = _gen_batchable(wrapper, initialized_batchable)
            wrapper._metadata.update(metadata_batch)
            completed |= batchable
        elif not resolved_directly:
            # No provider became resolvable in this pass, so the remaining ones
            # can only be waiting on each other. Checking per-pass progress
            # (rather than whether anything was ever completed) keeps this from
            # looping forever when a cycle exists alongside resolved providers.
            names = ", ".join(sorted(P.__name__ for P in remaining))
            raise MetadataException(f"Detected circular dependencies in {names}")

        remaining -= completed
class MetadataWrapper:
    """
    A wrapper around a :class:`~libcst.Module` that stores associated metadata
    for that module.

    When a :class:`MetadataWrapper` is constructed over a module, the wrapper will
    store a deep copy of the original module. This means
    ``MetadataWrapper(module).module == module`` is ``False``.

    This copying operation ensures that a node will never appear twice (by identity) in
    the same tree. This allows us to uniquely look up metadata for a node based on a
    node's identity.
    """

    __slots__ = ["__module", "_metadata", "_cache"]

    # The wrapped module; exposed read-only through the ``module`` property.
    __module: "Module"
    # Metadata resolved so far, keyed by provider class.
    _metadata: MutableMapping["ProviderT", Mapping["CSTNode", object]]
    # Per-provider cache objects handed to providers whose ``gen_cache`` is set.
    _cache: Mapping["ProviderT", object]

    def __init__(
        self,
        module: "Module",
        unsafe_skip_copy: bool = False,
        cache: Optional[Mapping["ProviderT", object]] = None,
    ) -> None:
        """
        :param module: The module to wrap. This is deeply copied by default.
        :param unsafe_skip_copy: When true, this skips the deep cloning of the module.
            This can provide a small performance benefit, but you should only use this
            if you know that there are no duplicate nodes in your tree (e.g. this
            module came from the parser).
        :param cache: Pass the needed cache to wrapper to be used when resolving metadata.
            When omitted, an empty cache is used.
        """
        # Ensure that module is safe to use by copying the module to remove
        # any duplicate nodes.
        if not unsafe_skip_copy:
            module = module.deep_clone()
        self.__module = module
        self._metadata = {}
        # Build the empty default per instance instead of sharing one mutable
        # default dict between every wrapper.
        self._cache = {} if cache is None else cache

    def __repr__(self) -> str:
        return f"MetadataWrapper(\n{textwrap.indent(repr(self.module), ' ' * 4)},\n)"

    @property
    def module(self) -> "Module":
        """
        The module that's wrapped by this MetadataWrapper. By default, this is a deep
        copy of the passed in module.

        ::

            mw = MetadataWrapper(module)
            # Because `mw.module is not module`, you probably want to do visit and do
            # your analysis on `mw.module`, not `module`.
            mw.module.visit(DoSomeAnalysisVisitor)
        """
        # use a property getter to enforce that this is a read-only variable
        return self.__module

    def resolve(
        self, provider: Type["BaseMetadataProvider[_T]"]
    ) -> Mapping["CSTNode", _T]:
        """
        Returns the metadata mapping computed by ``provider``, resolving it (and
        its dependencies) first if it has not been computed on this wrapper yet.
        """
        if provider in self._metadata:
            metadata = self._metadata[provider]
        else:
            metadata = self.resolve_many([provider])[provider]

        return cast(Mapping["CSTNode", _T], metadata)

    def resolve_many(
        self, providers: Collection["ProviderT"]
    ) -> Mapping["ProviderT", Mapping["CSTNode", object]]:
        """
        Returns a new map from each provider in ``providers`` to the metadata
        mapping computed by that provider.

        The returned map does not contain any metadata from undeclared metadata
        dependencies that ``providers`` has.
        """
        _resolve_impl(self, providers)

        # Only return what was declared in providers
        return {k: self._metadata[k] for k in providers}

    def visit(self, visitor: "CSTVisitorT") -> "Module":
        """
        Convenience method to resolve metadata before performing a traversal over
        ``self.module`` with ``visitor``. See :func:`~libcst.Module.visit`.
        """
        with visitor.resolve(self):
            return self.module.visit(visitor)

    def visit_batched(
        self,
        visitors: Iterable[BatchableCSTVisitor],
        before_visit: Optional[VisitorMethod] = None,
        after_leave: Optional[VisitorMethod] = None,
    ) -> "CSTNode":
        """
        Convenience method to resolve metadata before performing a traversal over
        ``self.module`` with ``visitors``. See :func:`~libcst.visit_batched`.

        ``visitors`` may be any iterable, including a one-shot iterator.
        """
        # ``visitors`` is needed twice (to resolve metadata, then to traverse),
        # so materialize it once; a generator would otherwise be exhausted
        # before the traversal starts.
        batch = list(visitors)
        with ExitStack() as stack:
            # Resolve dependencies of visitors
            for v in batch:
                stack.enter_context(v.resolve(self))

            return visit_batched(self.module, batch, before_visit, after_leave)
|