1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Dict, Sequence, Union
import libcst
import libcst.matchers as m
from libcst import CSTLogicError
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.helpers import insert_header_comments
class FixPyreDirectivesCommand(VisitorBasedCodemodCommand):
    """
    Relocate any ``# pyre-strict`` / ``# pyre-unsafe`` directive comment to the
    module header and repair the ``# pyre <tag>`` misspelling (a space where a
    hyphen belongs). Duplicate directives for the same tag are dropped.
    """

    DESCRIPTION: str = "Fixes common misspelling and location errors with pyre tags."
    PYRE_TAGS: Sequence[str] = ["strict", "unsafe"]

    def __init__(self, context: CodemodContext) -> None:
        """Set up per-tag bookkeeping for a single module traversal."""
        super().__init__(context)
        # Per tag: True once a directive was removed from outside the header
        # and therefore has to be re-inserted at the top by leave_Module.
        self.move_strict: Dict[str, bool] = dict.fromkeys(self.PYRE_TAGS, False)
        # Per tag: how many directives (either spelling) were seen in the header.
        self.module_header_tags: Dict[str, int] = dict.fromkeys(self.PYRE_TAGS, 0)
        # True only between visit_Module_header and leave_Module_header.
        self.in_module_header: bool = False

    def visit_Module_header(self, node: libcst.Module) -> None:
        """Record that traversal entered the module header.

        Raises:
            CSTLogicError: if the header is already marked as being visited.
        """
        if self.in_module_header:
            raise CSTLogicError("Logic error!")
        self.in_module_header = True

    def leave_Module_header(self, node: libcst.Module) -> None:
        """Record that traversal left the module header.

        Raises:
            CSTLogicError: if the header was not marked as being visited.
        """
        if not self.in_module_header:
            raise CSTLogicError("Logic error!")
        self.in_module_header = False

    def leave_EmptyLine(
        self, original_node: libcst.EmptyLine, updated_node: libcst.EmptyLine
    ) -> Union[libcst.EmptyLine, libcst.RemovalSentinel]:
        """Keep, rewrite, or remove a comment line that carries a pyre directive.

        Inside the module header, the first directive seen for a tag is kept
        (and re-spelled with a hyphen if needed); any later one for the same
        tag is removed as a duplicate. Outside the header, the directive is
        always removed, and it is scheduled for re-insertion only when the
        header did not already contain that tag. Lines that are not a pyre
        directive are returned untouched.
        """
        for tag in self.PYRE_TAGS:
            canonical = f"# pyre-{tag}"
            misspelled = f"# pyre {tag}"
            # Correctly spelled (possibly misplaced) form first, then the
            # misnamed form.
            for spelling in (canonical, misspelled):
                pattern = m.EmptyLine(comment=m.Comment(spelling))
                if not m.matches(updated_node, pattern):
                    continue

                if not self.in_module_header:
                    # Found outside the header: drop it here and only ask for
                    # a fresh header comment if the header lacks this tag.
                    if self.module_header_tags[tag] < 1:
                        self.move_strict[tag] = True
                    return libcst.RemoveFromParent()

                # Found inside the header: it never needs moving, but a second
                # occurrence of the same tag is a duplicate and gets removed.
                self.module_header_tags[tag] += 1
                if self.module_header_tags[tag] > 1:
                    return libcst.RemoveFromParent()
                if spelling == canonical:
                    return updated_node
                # Right place, wrong spelling: rewrite the comment in place.
                return updated_node.with_changes(comment=libcst.Comment(canonical))

        # An ordinary blank or comment line; nothing to do.
        return updated_node

    def leave_Module(
        self, original_node: libcst.Module, updated_node: libcst.Module
    ) -> libcst.Module:
        """Insert header comments for every tag that was flagged for moving."""
        pending = []
        for tag in self.PYRE_TAGS:
            if self.move_strict[tag]:
                pending.append(f"# pyre-{tag}")
        return insert_header_comments(updated_node, pending)
|