#include "clang/AST/ParentMapContext.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TemplateBase.h"
using namespace clang;
ParentMapContext::ParentMapContext(ASTContext &Ctx) : ASTCtx(Ctx) {}
ParentMapContext::~ParentMapContext() = default;
void ParentMapContext::clear() { Parents.reset(); }
/// Const overload of traverseIgnored(Expr *): strips const to reuse the
/// non-const implementation and hands the result back as a const pointer.
const Expr *ParentMapContext::traverseIgnored(const Expr *E) const {
  Expr *NonConstE = const_cast<Expr *>(E);
  return traverseIgnored(NonConstE);
}
/// Map E to the expression that the current traversal kind considers
/// visible.
///
/// TK_AsIs returns E unchanged. TK_IgnoreUnlessSpelledInSource returns
/// E->IgnoreUnlessSpelledInSource(). A null E yields nullptr.
Expr *ParentMapContext::traverseIgnored(Expr *E) const {
  if (E == nullptr)
    return nullptr;

  switch (Traversal) {
  case TK_IgnoreUnlessSpelledInSource:
    return E->IgnoreUnlessSpelledInSource();
  case TK_AsIs:
    return E;
  }
  llvm_unreachable("Invalid Traversal type!");
}
/// DynTypedNode overload: an Expr node is replaced by the node for the
/// expression traverseIgnored(const Expr *) selects. Every other node kind is
/// returned as-is.
DynTypedNode ParentMapContext::traverseIgnored(const DynTypedNode &N) const {
  const auto *AsExpr = N.get<Expr>();
  if (!AsExpr)
    return N;
  return DynTypedNode::create(*traverseIgnored(AsExpr));
}
// Forward declarations. ParentMap::getParents (defined inside the class
// below) calls matchParents, and ParentMap befriends MatchParents so the
// matcher can read its private parent storage. Both templates are defined
// after the ParentMap class.
template <typename T, typename... U>
std::tuple<bool, DynTypedNodeList, const T *, const U *...>
matchParents(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap);
// Recursive matcher behind matchParents; declared here so ParentMap can name
// it as a friend.
template <typename, typename...> struct MatchParents;
/// Storage for the parent relation of the nodes of an ASTContext. The
/// ParentMap constructor fills it by running ASTVisitor over the AST.
/// getParents() queries it.
///
/// Each map value is a PointerUnion:
/// - A single Decl or Stmt parent is stored inline.
/// - Any other single parent is an owned heap-allocated DynTypedNode.
/// - Multiple parents are an owned heap-allocated ParentVector.
///
/// The destructor frees the owned allocations, so the class is non-copyable:
/// a copy would free them twice.
class ParentMapContext::ParentMap {

  template <typename, typename...> friend struct ::MatchParents;

  /// Contains the parents of a node.
  using ParentVector = llvm::SmallVector<DynTypedNode, 2>;

  /// Parent map for nodes with pointer identity, keyed by the node's
  /// memoization pointer.
  using ParentMapPointers =
      llvm::DenseMap<const void *,
                     llvm::PointerUnion<const Decl *, const Stmt *,
                                        DynTypedNode *, ParentVector *>>;

  /// Parent map for nodes without pointer identity, keyed by a full
  /// DynTypedNode.
  using ParentMapOtherNodes =
      llvm::DenseMap<DynTypedNode,
                     llvm::PointerUnion<const Decl *, const Stmt *,
                                        DynTypedNode *, ParentVector *>>;

  ParentMapPointers PointerParents;
  ParentMapOtherNodes OtherParents;
  class ASTVisitor;

  /// Convert a map value that holds exactly one parent into a DynTypedNode.
  /// The final get<DynTypedNode *>() assumes the remaining case, so callers
  /// must not pass a value holding a ParentVector.
  static DynTypedNode
  getSingleDynTypedNodeFromParentMap(ParentMapPointers::mapped_type U) {
    if (const auto *D = U.dyn_cast<const Decl *>())
      return DynTypedNode::create(*D);
    if (const auto *S = U.dyn_cast<const Stmt *>())
      return DynTypedNode::create(*S);
    return *U.get<DynTypedNode *>();
  }

  /// Look up Node in Map and return its parents. Returns an empty list when
  /// Node is not a key.
  template <typename NodeTy, typename MapTy>
  static DynTypedNodeList getDynNodeFromMap(const NodeTy &Node,
                                            const MapTy &Map) {
    auto I = Map.find(Node);
    if (I == Map.end()) {
      return llvm::ArrayRef<DynTypedNode>();
    }
    if (const auto *V = I->second.template dyn_cast<ParentVector *>()) {
      return llvm::makeArrayRef(*V);
    }
    return getSingleDynTypedNodeFromParentMap(I->second);
  }

  /// Free the heap allocation owned by a map value, if any. Decl and Stmt
  /// parents are stored inline and are not owned. Both map types share this
  /// mapped_type.
  static void destroyMappedValue(ParentMapPointers::mapped_type U) {
    if (auto *Node = U.dyn_cast<DynTypedNode *>())
      delete Node;
    else if (auto *Vector = U.dyn_cast<ParentVector *>())
      delete Vector;
  }

public:
  explicit ParentMap(ASTContext &Ctx);
  ~ParentMap() {
    for (const auto &Entry : PointerParents)
      destroyMappedValue(Entry.second);
    for (const auto &Entry : OtherParents)
      destroyMappedValue(Entry.second);
  }

  // The maps own raw allocations released by the destructor. The implicit
  // copy operations would duplicate those pointers and cause a double free.
  ParentMap(const ParentMap &) = delete;
  ParentMap &operator=(const ParentMap &) = delete;

  /// Return the parents of Node.
  ///
  /// Nodes with pointer identity are looked up in PointerParents, all other
  /// nodes in OtherParents. With TK_IgnoreUnlessSpelledInSource and a
  /// non-empty pointer-keyed result, implicit intermediate nodes are skipped
  /// as described inline. Otherwise the stored list is returned unchanged.
  DynTypedNodeList getParents(TraversalKind TK, const DynTypedNode &Node) {
    if (Node.getNodeKind().hasPointerIdentity()) {
      auto ParentList =
          getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
      if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {

        const auto *ChildExpr = Node.get<Expr>();

        {
          // Walk up at most 4 single-parent Stmt levels, looking for a
          // CXXRewrittenBinaryOperator whose LHS or RHS, seen through
          // IgnoreUnlessSpelledInSource(), is this expression. If one is
          // found, report it as the parent.
          auto RewrittenBinOpParentsList = ParentList;
          int I = 0;
          while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
                 I++ < 4) {
            const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
            if (!S)
              break;

            const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
            if (!RWBO) {
              RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
              continue;
            }
            if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
                RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
              break;
            return DynTypedNode::create(*RWBO);
          }
        }

        // An expression whose first parent is also an expression: ascend
        // past implicit wrapper expressions.
        const auto *ParentExpr = ParentList[0].get<Expr>();
        if (ParentExpr && ChildExpr)
          return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);

        {
          // The parent is the loop-variable DeclStmt of a CXXForRangeStmt:
          // report the CXXForRangeStmt instead.
          auto AncestorNodes =
              matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
          if (std::get<bool>(AncestorNodes) &&
              std::get<const CXXForRangeStmt *>(AncestorNodes)
                      ->getLoopVarStmt() ==
                  std::get<const DeclStmt *>(AncestorNodes))
            return std::get<DynTypedNodeList>(AncestorNodes);
        }
        {
          // The parent is a VarDecl in the range DeclStmt of a
          // CXXForRangeStmt: report the CXXForRangeStmt instead.
          auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
              ParentList, this);
          if (std::get<bool>(AncestorNodes) &&
              std::get<const CXXForRangeStmt *>(AncestorNodes)
                      ->getRangeStmt() ==
                  std::get<const DeclStmt *>(AncestorNodes))
            return std::get<DynTypedNodeList>(AncestorNodes);
        }
        {
          // The parent is a CXXMethodDecl inside a CXXRecordDecl whose parent
          // is a LambdaExpr: report the LambdaExpr instead.
          auto AncestorNodes =
              matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
                                                                     this);
          if (std::get<bool>(AncestorNodes))
            return std::get<DynTypedNodeList>(AncestorNodes);
        }
        {
          // Same as above, with a FunctionTemplateDecl in place of the method
          // (presumably a generic lambda's call operator template).
          auto AncestorNodes =
              matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
                  ParentList, this);
          if (std::get<bool>(AncestorNodes))
            return std::get<DynTypedNodeList>(AncestorNodes);
        }
      }
      return ParentList;
    }
    return getDynNodeFromMap(Node, OtherParents);
  }

  /// Start from E, the parent of Child, and ascend past expressions that
  /// ShouldSkip classifies as implicit or redundant relative to their child.
  ///
  /// Return value:
  /// - The first ancestor expression that is kept.
  /// - If a skipped expression's parent is a non-Expr Stmt, that Stmt.
  /// - If the parent is stored as anything other than a Stmt, that stored
  ///   parent entry.
  /// - If a skipped expression has no recorded parent, that expression.
  DynTypedNodeList AscendIgnoreUnlessSpelledInSource(const Expr *E,
                                                     const Expr *Child) {

    // Wrapper kinds skipped unconditionally; the remaining kinds are skipped
    // only when they cover exactly the child's source range (or, for
    // CXXConstructExpr, when elidable).
    auto ShouldSkip = [](const Expr *E, const Expr *Child) {
      if (isa<ImplicitCastExpr>(E))
        return true;

      if (isa<FullExpr>(E))
        return true;

      if (isa<MaterializeTemporaryExpr>(E))
        return true;

      if (isa<CXXBindTemporaryExpr>(E))
        return true;

      if (isa<ParenExpr>(E))
        return true;

      if (isa<ExprWithCleanups>(E))
        return true;

      auto SR = Child->getSourceRange();

      if (const auto *C = dyn_cast<CXXFunctionalCastExpr>(E)) {
        if (C->getSourceRange() == SR)
          return true;
      }

      if (const auto *C = dyn_cast<CXXConstructExpr>(E)) {
        if (C->getSourceRange() == SR || C->isElidable())
          return true;
      }

      if (const auto *C = dyn_cast<CXXMemberCallExpr>(E)) {
        if (C->getSourceRange() == SR)
          return true;
      }

      if (const auto *C = dyn_cast<MemberExpr>(E)) {
        if (C->getSourceRange() == SR)
          return true;
      }
      return false;
    };

    while (ShouldSkip(E, Child)) {
      auto It = PointerParents.find(E);
      if (It == PointerParents.end())
        break;
      const auto *S = It->second.dyn_cast<const Stmt *>();
      if (!S) {
        if (auto *Vec = It->second.dyn_cast<ParentVector *>())
          return llvm::makeArrayRef(*Vec);
        return getSingleDynTypedNodeFromParentMap(It->second);
      }
      const auto *P = dyn_cast<Expr>(S);
      if (!P)
        return DynTypedNode::create(*S);
      Child = E;
      E = P;
    }
    return DynTypedNode::create(*E);
  }
};
/// Build a new tuple from the elements of Src at positions Idx + 1 for each
/// Idx in the given index sequence. std::make_tuple decays element types.
template <typename Tuple, std::size_t... Idx>
auto tuple_pop_front_impl(const Tuple &Src, std::index_sequence<Idx...>) {
  return std::make_tuple(std::get<Idx + 1>(Src)...);
}

/// Return a copy of Src without its first element. Src must have at least
/// one element.
template <typename Tuple> auto tuple_pop_front(const Tuple &Src) {
  constexpr std::size_t TailSize = std::tuple_size<Tuple>::value - 1;
  using TailIndices = std::make_index_sequence<TailSize>;
  return tuple_pop_front_impl(Src, TailIndices{});
}
/// Match a chain of single-parent ancestors against the types T, U...
///
/// NodeList[0] must be a T with exactly one parent. That parent list is then
/// matched recursively against U...
///
/// On success the result holds:
/// - true;
/// - the parent list whose first element matched the last type of the chain;
/// - a pointer to each matched node, in chain order.
///
/// On failure it holds false, the unchanged NodeList, and null pointers.
template <typename T, typename... U> struct MatchParents {
  static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
  match(const DynTypedNodeList &NodeList,
        ParentMapContext::ParentMap *ParentMap) {
    const auto Mismatch = [&NodeList] {
      return std::tuple_cat(std::make_tuple(false, NodeList),
                            std::tuple<const T *, const U *...>());
    };

    const auto *Head = NodeList[0].get<T>();
    if (!Head)
      return Mismatch();

    auto HeadParents =
        ParentMap->getDynNodeFromMap(Head, ParentMap->PointerParents);
    if (HeadParents.size() != 1)
      return Mismatch();

    auto Rest = MatchParents<U...>::match(HeadParents, ParentMap);
    if (!std::get<bool>(Rest))
      return Mismatch();

    // Prepend Head to the tail's matched pointers. The tail's own bool and
    // node list (its first two elements) are dropped.
    return std::tuple_cat(
        std::make_tuple(true, std::get<DynTypedNodeList>(Rest), Head),
        tuple_pop_front(tuple_pop_front(Rest)));
  }
};
/// Terminal case of MatchParents: NodeList[0] must be a T that has exactly
/// one parent. On success returns (true, NodeList, matched node); otherwise
/// returns (false, NodeList, nullptr).
template <typename T> struct MatchParents<T> {
  static std::tuple<bool, DynTypedNodeList, const T *>
  match(const DynTypedNodeList &NodeList,
        ParentMapContext::ParentMap *ParentMap) {
    const auto *Last = NodeList[0].get<T>();
    if (Last &&
        ParentMap->getDynNodeFromMap(Last, ParentMap->PointerParents).size() ==
            1)
      return std::make_tuple(true, NodeList, Last);
    return std::make_tuple(false, NodeList, nullptr);
  }
};
template <typename T, typename... U>
std::tuple<bool, DynTypedNodeList, const T *, const U *...>
matchParents(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap) {
return MatchParents<T, U...>::match(NodeList, ParentMap);
}
/// Wrap a traversal node in a DynTypedNode. The primary template handles
/// pointer nodes by dereferencing them. The explicit specializations below
/// wrap value-type location nodes directly.
template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {
  const auto &Pointee = *Node;
  return DynTypedNode::create(Pointee);
}
template <>
DynTypedNode createDynTypedNode<TypeLoc>(const TypeLoc &Node) {
  return DynTypedNode::create(Node);
}
template <>
DynTypedNode
createDynTypedNode<NestedNameSpecifierLoc>(const NestedNameSpecifierLoc &Node) {
  return DynTypedNode::create(Node);
}
template <>
DynTypedNode createDynTypedNode<ObjCProtocolLoc>(const ObjCProtocolLoc &Node) {
  return DynTypedNode::create(Node);
}
/// RecursiveASTVisitor that walks the AST and, for each visited node,
/// records the node on top of ParentStack (its parent).
///
/// - Decls, Stmts and Attrs are keyed by pointer in Map.PointerParents.
/// - TypeLocs, NestedNameSpecifierLocs and ObjCProtocolLocs are keyed by a
///   DynTypedNode in Map.OtherParents.
class ParentMapContext::ParentMap::ASTVisitor
    : public RecursiveASTVisitor<ASTVisitor> {
public:
  explicit ASTVisitor(ParentMap &Map) : Map(Map) {}

private:
  friend class RecursiveASTVisitor<ASTVisitor>;

  using VisitorBase = RecursiveASTVisitor<ASTVisitor>;

  // Parents are also recorded inside template instantiations and implicit
  // code.
  bool shouldVisitTemplateInstantiations() const { return true; }

  bool shouldVisitImplicitCode() const { return true; }

  /// Record the node on top of ParentStack as a parent of MapNode in
  /// Parents. Does nothing when the stack is empty.
  ///
  /// Storage scheme:
  /// - A first Decl or Stmt parent is stored inline.
  /// - Any other first parent is heap-allocated as a DynTypedNode.
  /// - When a second parent arrives, the entry is promoted to a
  ///   heap-allocated ParentVector.
  template <typename MapNodeTy, typename MapTy>
  void addParent(MapNodeTy MapNode, MapTy *Parents) {
    if (ParentStack.empty())
      return;

    auto &NodeOrVector = (*Parents)[MapNode];
    if (NodeOrVector.isNull()) {
      if (const auto *D = ParentStack.back().get<Decl>())
        NodeOrVector = D;
      else if (const auto *S = ParentStack.back().get<Stmt>())
        NodeOrVector = S;
      else
        NodeOrVector = new DynTypedNode(ParentStack.back());
    } else {
      if (!NodeOrVector.template is<ParentVector *>()) {
        // Promote the single stored parent to a vector. A heap-allocated
        // DynTypedNode is freed once it has been copied into the vector.
        auto *Vector = new ParentVector(
            1, getSingleDynTypedNodeFromParentMap(NodeOrVector));
        delete NodeOrVector.template dyn_cast<DynTypedNode *>();
        NodeOrVector = Vector;
      }

      auto *Vector = NodeOrVector.template get<ParentVector *>();
      // Duplicates are only filtered for parents that have memoization data.
      // Other parents are appended unconditionally.
      // NOTE(review): is_contained is a linear scan per insertion.
      bool Found = ParentStack.back().getMemoizationData() &&
                   llvm::is_contained(*Vector, ParentStack.back());
      if (!Found)
        Vector->push_back(ParentStack.back());
    }
  }

  template <typename T> static bool isNull(T Node) { return !Node; }
  // ObjCProtocolLoc values are never treated as null.
  static bool isNull(ObjCProtocolLoc Node) { return false; }

  /// Shared traversal step used for every non-Stmt node kind:
  /// 1. Skip null nodes (returns true).
  /// 2. Record Node's parent under MapNode in Parents.
  /// 3. Keep Node on ParentStack while BaseTraverse() visits its children.
  /// Returns BaseTraverse()'s result.
  template <typename T, typename MapNodeTy, typename BaseTraverseFn,
            typename MapTy>
  bool TraverseNode(T Node, MapNodeTy MapNode, BaseTraverseFn BaseTraverse,
                    MapTy *Parents) {
    if (isNull(Node))
      return true;
    addParent(MapNode, Parents);
    ParentStack.push_back(createDynTypedNode(Node));
    bool Result = BaseTraverse();
    ParentStack.pop_back();
    return Result;
  }

  bool TraverseDecl(Decl *DeclNode) {
    return TraverseNode(
        DeclNode, DeclNode, [&] { return VisitorBase::TraverseDecl(DeclNode); },
        &Map.PointerParents);
  }
  bool TraverseTypeLoc(TypeLoc TypeLocNode) {
    return TraverseNode(
        TypeLocNode, DynTypedNode::create(TypeLocNode),
        [&] { return VisitorBase::TraverseTypeLoc(TypeLocNode); },
        &Map.OtherParents);
  }
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNSLocNode) {
    return TraverseNode(
        NNSLocNode, DynTypedNode::create(NNSLocNode),
        [&] { return VisitorBase::TraverseNestedNameSpecifierLoc(NNSLocNode); },
        &Map.OtherParents);
  }
  bool TraverseAttr(Attr *AttrNode) {
    return TraverseNode(
        AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); },
        &Map.PointerParents);
  }
  bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) {
    return TraverseNode(
        ProtocolLocNode, DynTypedNode::create(ProtocolLocNode),
        [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); },
        &Map.OtherParents);
  }

  // Stmts are recorded through the dataTraverseStmtPre/Post hooks instead of
  // a TraverseStmt override. Presumably this lets RecursiveASTVisitor keep
  // its data-recursive Stmt traversal.
  bool dataTraverseStmtPre(Stmt *StmtNode) {
    addParent(StmtNode, &Map.PointerParents);
    ParentStack.push_back(DynTypedNode::create(*StmtNode));
    return true;
  }

  bool dataTraverseStmtPost(Stmt *StmtNode) {
    ParentStack.pop_back();
    return true;
  }

  // Parent storage being filled by this traversal.
  ParentMap &Map;
  // Chain of nodes currently being traversed; back() is the current parent.
  llvm::SmallVector<DynTypedNode, 16> ParentStack;
};
/// Build the full parent map by walking every node of Ctx once.
ParentMapContext::ParentMap::ParentMap(ASTContext &Ctx) {
  ASTVisitor Builder(*this);
  Builder.TraverseAST(Ctx);
}
/// Return the parents of Node under the current traversal kind. The parent
/// map is built on the first call and reused until clear().
DynTypedNodeList ParentMapContext::getParents(const DynTypedNode &Node) {
  if (Parents == nullptr)
    Parents = std::make_unique<ParentMap>(ASTCtx);
  const TraversalKind Kind = getTraversalKind();
  return Parents->getParents(Kind, Node);
}