//===--- SymbolUSRFinder.cpp - Clang refactoring library ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief Implements methods that find the set of USRs that correspond to
/// a symbol that's required for a refactoring operation.
///
//===----------------------------------------------------------------------===//

#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Tooling/Refactor/RefactoringActionFinder.h"
#include "clang/Tooling/Refactor/USRFinder.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include <vector>

using namespace clang;
using namespace clang::tooling::rename;

namespace {

/// \brief \c findSymbolsUSRSet delegates finding the USRs of a found
/// declaration to \c AdditionalUSRFinder. If the found declaration refers to a
/// class or a class template, \c AdditionalUSRFinder adds the USRs of its
/// constructors and destructor. If it refers to a virtual C++ method or an
/// Objective-C method, it adds the USRs of all the methods it (transitively)
/// overrides and of all the methods that override any of them.
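///
/// For example, given the input
/// \code
///   class Foo {
///   public:
///     Foo();
///     Foo(int);
///     ~Foo();
///   };
/// \endcode
/// renaming \c Foo must also rename both constructors and the destructor, so
/// the returned set includes their USRs together with the USR of \c Foo.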
class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
public:
  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
      : FoundDecl(FoundDecl), Context(Context) {}

  llvm::StringSet<> Find() {
    llvm::StringSet<> USRSet;

    // Traverse the whole translation unit first to fill the
    // OverriddenMethods, PartialSpecs and PotentialObjCMethodOverridders
    // storages.
    TraverseDecl(Context.getTranslationUnitDecl());

    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
      addUSRsOfOverridenFunctions(MethodDecl, USRSet);
      // FIXME: Use a more efficient/optimal algorithm to find the related
      // methods.
      for (const auto &OverriddenMethod : OverriddenMethods) {
        if (checkIfOverriddenFunctionAscends(OverriddenMethod, USRSet))
          USRSet.insert(getUSRForDecl(OverriddenMethod));
      }
    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
      handleCXXRecordDecl(RecordDecl, USRSet);
    } else if (const auto *TemplateDecl =
                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
      handleClassTemplateDecl(TemplateDecl, USRSet);
    } else if (const auto *MethodDecl = dyn_cast<ObjCMethodDecl>(FoundDecl)) {
      addUSRsOfOverriddenObjCMethods(MethodDecl, USRSet);
      for (const auto &PotentialOverrider : PotentialObjCMethodOverridders)
        if (checkIfPotentialObjCMethodOverriddes(PotentialOverrider, USRSet))
          USRSet.insert(getUSRForDecl(PotentialOverrider));
    } else {
      USRSet.insert(getUSRForDecl(FoundDecl));
    }
    return USRSet;
  }

  // Record every virtual method; Find() later keeps the ones whose override
  // chain reaches the found method.
  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
    if (MethodDecl->isVirtual())
      OverriddenMethods.push_back(MethodDecl);
    return true;
  }

  // Record the overriding methods that share the found Objective-C method's
  // selector; Find() later checks which of them actually override it.
  bool VisitObjCMethodDecl(const ObjCMethodDecl *MethodDecl) {
    if (const auto *FoundMethodDecl = dyn_cast<ObjCMethodDecl>(FoundDecl))
      if (DeclarationName::compare(MethodDecl->getDeclName(),
                                   FoundMethodDecl->getDeclName()) == 0 &&
          MethodDecl->isOverriding())
        PotentialObjCMethodOverridders.push_back(MethodDecl);
    return true;
  }

  // Partial specializations are only needed when renaming a class or a
  // class template.
  bool VisitClassTemplatePartialSpecializationDecl(
      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
    if (!isa<ClassTemplateDecl>(FoundDecl) && !isa<CXXRecordDecl>(FoundDecl))
      return true;
    PartialSpecs.push_back(PartialSpec);
    return true;
  }

private:
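  /// Handles a found class. If the class is a class template specialization,
  /// the primary template (with its constructors and destructor) is renamed
  /// as well, because a specialization shares the template's name. For
  /// example, when the found declaration is the explicit specialization
  /// \c Vec<bool> in the input
  /// \code
  ///   template <typename T> class Vec { public: Vec(); };
  ///   template <> class Vec<bool> { public: Vec(); };
  /// \endcode
  /// the USRs of \c Vec<bool>, of the primary template \c Vec and of both
  /// constructors are collected.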
  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl,
                           llvm::StringSet<> &USRSet) {
    const auto *RD = RecordDecl->getDefinition();
    if (!RD) {
      USRSet.insert(getUSRForDecl(RecordDecl));
      return;
    }
    if (const auto *ClassTemplateSpecDecl =
            dyn_cast<ClassTemplateSpecializationDecl>(RD))
      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate(),
                              USRSet);
    addUSRsOfCtorDtors(RD, USRSet);
  }

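  /// Adds the USRs of a class template, of the specializations stored in the
  /// template, and of the partial specializations collected during the
  /// traversal, each together with its constructors and destructor. For
  /// example, given the input
  /// \code
  ///   template <typename T> class Vec { public: Vec(); };
  ///   template <typename T> class Vec<T *> { public: Vec(); };
  /// \endcode
  /// renaming \c Vec also collects the partial specialization \c Vec<T *>
  /// and its constructor. Partial specializations are not returned by
  /// \c ClassTemplateDecl::specializations(), which is why they are gathered
  /// separately in \c VisitClassTemplatePartialSpecializationDecl.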
  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl,
                               llvm::StringSet<> &USRSet) {
    for (const auto *Specialization : TemplateDecl->specializations())
      addUSRsOfCtorDtors(Specialization, USRSet);
    for (const auto *PartialSpec : PartialSpecs) {
      if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
        addUSRsOfCtorDtors(PartialSpec, USRSet);
    }
    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl(), USRSet);
  }

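  void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl,
                          llvm::StringSet<> &USRSet) {
    const CXXRecordDecl *Definition = RecordDecl->getDefinition();
    // Without a definition there are no constructors or destructor to
    // collect, so only the record itself is added.
    if (!Definition) {
      USRSet.insert(getUSRForDecl(RecordDecl));
      return;
    }
    for (const auto *CtorDecl : Definition->ctors()) {
      auto USR = getUSRForDecl(CtorDecl);
      if (!USR.empty())
        USRSet.insert(USR);
    }
    // getDestructor() may return null; this relies on getUSRForDecl()
    // returning an empty USR in that case, which is then skipped.
    auto USR = getUSRForDecl(Definition->getDestructor());
    if (!USR.empty())
      USRSet.insert(USR);
    USRSet.insert(getUSRForDecl(Definition));
  }
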
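  /// Adds the USR of \p MethodDecl and, recursively, the USRs of all the
  /// methods it overrides. For example, given the input
  /// \code
  ///   struct A { virtual void f(); };
  ///   struct B : A { void f() override; };
  /// \endcode
  /// starting from \c B::f adds the USRs of both \c B::f and \c A::f.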
  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl,
                                   llvm::StringSet<> &USRSet) {
    USRSet.insert(getUSRForDecl(MethodDecl));
    // Recursively visit each overridden method.
    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
      addUSRsOfOverridenFunctions(OverriddenMethod, USRSet);
  }

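  /// Returns true if any method that \p MethodDecl overrides, directly or
  /// indirectly, is already in \p USRSet. For example, given the input
  /// \code
  ///   struct A { virtual void f(); };
  ///   struct B : A { void f() override; };
  ///   struct C : B { void f() override; };
  ///   struct D : A { void f() override; };
  /// \endcode
  /// renaming \c B::f first collects \c B::f and \c A::f. \c C::f and
  /// \c D::f are then added because their override chains reach \c B::f and
  /// \c A::f respectively.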
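  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl,
                                        const llvm::StringSet<> &USRSet) {
    // Check every overridden method rather than only the first one: with
    // multiple inheritance a method can override several base methods, and
    // only some of those chains may reach a method in USRSet.
    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
        return true;
      if (checkIfOverriddenFunctionAscends(OverriddenMethod, USRSet))
        return true;
    }
    return false;
  }
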
  /// \brief Recursively visits all the methods that the given method
  /// declaration overrides and adds their USRs to the USR set.
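  ///
  /// For example, given the input
  /// \code
  ///   @interface Base
  ///   - (void)foo;
  ///   @end
  ///   @interface Derived : Base
  ///   - (void)foo;
  ///   @end
  /// \endcode
  /// starting from the \c -foo declared in \c Derived adds the USRs of both
  /// \c -foo declarations. \c getOverriddenMethods() also considers
  /// categories and protocols, so the same method can be reachable through
  /// several paths; the early exit on an already-inserted USR keeps the
  /// recursion from visiting it twice.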
  void addUSRsOfOverriddenObjCMethods(const ObjCMethodDecl *MethodDecl,
                                      llvm::StringSet<> &USRSet) {
    // Exit early if this method was already visited.
    if (!USRSet.insert(getUSRForDecl(MethodDecl)).second)
      return;
    SmallVector<const ObjCMethodDecl *, 8> Overrides;
    MethodDecl->getOverriddenMethods(Overrides);
    for (const auto &OverriddenMethod : Overrides)
      addUSRsOfOverriddenObjCMethods(OverriddenMethod, USRSet);
  }

  /// \brief Returns true if the given Objective-C method overrides, directly
  /// or indirectly, a method whose USR is already in \p USRSet (the found
  /// Objective-C method or one of the methods it overrides).
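  ///
  /// For example, given the input
  /// \code
  ///   @interface Base
  ///   - (void)foo;
  ///   @end
  ///   @interface Derived : Base
  ///   - (void)foo;
  ///   @end
  /// \endcode
  /// when renaming the \c -foo declared in \c Base, the \c -foo in \c Derived
  /// is a candidate recorded by \c VisitObjCMethodDecl if it has the same
  /// selector and \c isOverriding() is true. This check then accepts it
  /// because its overridden methods include the \c -foo in \c Base, whose USR
  /// is in \p USRSet.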
  bool checkIfPotentialObjCMethodOverriddes(const ObjCMethodDecl *MethodDecl,
                                            const llvm::StringSet<> &USRSet) {
    SmallVector<const ObjCMethodDecl *, 8> Overrides;
    MethodDecl->getOverriddenMethods(Overrides);
    for (const auto &OverriddenMethod : Overrides) {
      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
        return true;
      if (checkIfPotentialObjCMethodOverriddes(OverriddenMethod, USRSet))
        return true;
    }
    return false;
  }

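  /// The declaration whose related USRs are being collected.
  const Decl *FoundDecl;
  ASTContext &Context;
  /// All virtual C++ methods in the translation unit; Find() keeps the ones
  /// whose override chain reaches the found method.
  std::vector<const CXXMethodDecl *> OverriddenMethods;
  /// Class template partial specializations in the translation unit,
  /// collected only when the found declaration is a class or class template.
  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
  /// \brief An array of Objective-C methods that potentially override the
  /// found Objective-C method declaration \c FoundDecl.
  std::vector<const ObjCMethodDecl *> PotentialObjCMethodOverridders;
};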

} // end anonymous namespace

namespace clang {
namespace tooling {

llvm::StringSet<> findSymbolsUSRSet(const NamedDecl *FoundDecl,
                                    ASTContext &Context) {
  return AdditionalUSRFinder(FoundDecl, Context).Find();
}

} // end namespace tooling
} // end namespace clang