#include "clang/Tooling/Transformer/RangeSelector.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TypeLoc.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/Transformer/SourceCode.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include <string>
#include <utility>
#include <vector>
using namespace clang;
using namespace transformer;
using ast_matchers::MatchFinder;
using llvm::Error;
using llvm::StringError;
using MatchResult = MatchFinder::MatchResult;
/// Builds an llvm::StringError tagged with errc::invalid_argument whose text
/// is \p Message.
static Error invalidArgumentError(Twine Message) {
  const auto Code = llvm::errc::invalid_argument;
  return llvm::make_error<StringError>(Code, Message);
}
/// Reports that the node bound to \p ID has an unexpected kind \p Kind.
static Error typeError(StringRef ID, const ASTNodeKind &Kind) {
  const StringRef KindName = Kind.asStringRef();
  // The Twine must be built and consumed within a single full-expression.
  return invalidArgumentError("mismatched type (node id=" + ID + " kind=" +
                              KindName + ")");
}
/// Reports that the node bound to \p ID has kind \p Kind, while one of the
/// kinds listed in \p ExpectedType was required.
static Error typeError(StringRef ID, const ASTNodeKind &Kind,
                       Twine ExpectedType) {
  const StringRef KindName = Kind.asStringRef();
  // The Twine must be built and consumed within a single full-expression.
  return invalidArgumentError("mismatched type: expected one of " +
                              ExpectedType + " (node id=" + ID + " kind=" +
                              KindName + ")");
}
/// Reports that the operation named by \p Description needs the bound node
/// \p ID to have property \p Property, which it lacks.
static Error missingPropertyError(StringRef ID, Twine Description,
                                  StringRef Property) {
  return invalidArgumentError(Description + " requires property '" +
                              Property + "' (node id=" + ID + ")");
}
/// Looks up the node bound under \p ID in \p Nodes.
/// \returns the bound node, or an invalid-argument error if \p ID is unbound.
static Expected<DynTypedNode> getNode(const ast_matchers::BoundNodes &Nodes,
                                      StringRef ID) {
  const auto &Map = Nodes.getMap();
  const auto Found = Map.find(ID);
  if (Found != Map.end())
    return Found->second;
  return invalidArgumentError("ID not bound: " + ID);
}
/// Returns the start of the token containing the character just before
/// \p Start. Returns an invalid location if \p Start or that preceding
/// character is invalid or lies in a macro expansion.
static SourceLocation findPreviousTokenStart(SourceLocation Start,
                                             const SourceManager &SM,
                                             const LangOptions &LangOpts) {
  if (!Start.isValid() || Start.isMacroID())
    return SourceLocation();
  const SourceLocation OneBack = Start.getLocWithOffset(-1);
  const bool Usable = OneBack.isValid() && !OneBack.isMacroID();
  return Usable ? Lexer::GetBeginningOfToken(OneBack, SM, LangOpts)
                : SourceLocation();
}
/// Walks backwards token by token from \p Start and returns the location of
/// the first token of kind \p TK. Returns an invalid location if the walk hits
/// an invalid or macro location, or if raw lexing fails.
static SourceLocation findPreviousTokenKind(SourceLocation Start,
                                            const SourceManager &SM,
                                            const LangOptions &LangOpts,
                                            tok::TokenKind TK) {
  for (SourceLocation Cursor = Start;;) {
    const SourceLocation Prev = findPreviousTokenStart(Cursor, SM, LangOpts);
    if (Prev.isInvalid() || Prev.isMacroID())
      return SourceLocation();
    Token Tok;
    // getRawToken returns true on failure.
    const bool Failed =
        Lexer::getRawToken(Prev, Tok, SM, LangOpts, /*IgnoreWhiteSpace=*/true);
    if (Failed)
      return SourceLocation();
    if (Tok.is(TK))
      return Tok.getLocation();
    Cursor = Prev;
  }
}
/// Locates the '(' that opens the argument list of call \p E. The backwards
/// search starts at the first argument, or at the ')' when there are no
/// arguments. Returns an invalid location if the paren cannot be found.
static SourceLocation findOpenParen(const CallExpr &E, const SourceManager &SM,
                                    const LangOptions &LangOpts) {
  SourceLocation SearchFrom = E.getRParenLoc();
  if (E.getNumArgs() != 0)
    SearchFrom = E.getArg(0)->getBeginLoc();
  return findPreviousTokenKind(SearchFrom, SM, LangOpts, tok::l_paren);
}
/// Selects the empty char range located at the beginning of whatever
/// \p Selector selects. Errors from \p Selector are propagated.
RangeSelector transformer::before(RangeSelector Selector) {
  return [Selector](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<CharSourceRange> Range = Selector(Result);
    if (Range)
      return CharSourceRange::getCharRange(Range->getBegin());
    return Range.takeError();
  };
}
/// Selects the empty char range located at the end of whatever \p Selector
/// selects. For a token range, the end token is first mapped to a file char
/// range so that the result lies past that token.
RangeSelector transformer::after(RangeSelector Selector) {
  return [Selector](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<CharSourceRange> Selected = Selector(Result);
    if (!Selected)
      return Selected.takeError();
    // A char range's end location is used as-is.
    if (!Selected->isTokenRange())
      return CharSourceRange::getCharRange(Selected->getEnd());
    const CharSourceRange LastToken = Lexer::makeFileCharRange(
        CharSourceRange::getTokenRange(Selected->getEnd()),
        *Result.SourceManager, Result.Context->getLangOpts());
    if (LastToken.isInvalid())
      return invalidArgumentError(
          "after: can't resolve sub-range to valid source range");
    return CharSourceRange::getCharRange(LastToken.getEnd());
  };
}
/// Selects the source range of the node bound to \p ID. Declarations and
/// non-expression statements are passed through tooling::getExtendedRange
/// with `;` as the token, so a trailing semicolon can be included.
/// Expressions and all other nodes use their plain token range.
RangeSelector transformer::node(std::string ID) {
  return [ID](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<DynTypedNode> Node = getNode(Result.Nodes, ID);
    if (!Node)
      return Node.takeError();
    const bool IsDecl = Node->get<Decl>() != nullptr;
    const bool IsNonExprStmt =
        Node->get<Stmt>() != nullptr && Node->get<Expr>() == nullptr;
    if (IsDecl || IsNonExprStmt)
      return tooling::getExtendedRange(*Node, tok::TokenKind::semi,
                                       *Result.Context);
    return CharSourceRange::getTokenRange(Node->getSourceRange());
  };
}
/// Selects the node bound to \p ID, always passed through
/// tooling::getExtendedRange with `;` as the token, regardless of node kind.
RangeSelector transformer::statement(std::string ID) {
  return [ID](const MatchResult &Result) -> Expected<CharSourceRange> {
    if (Expected<DynTypedNode> Node = getNode(Result.Nodes, ID))
      return tooling::getExtendedRange(*Node, tok::TokenKind::semi,
                                       *Result.Context);
    else
      return Node.takeError();
  };
}
/// Selects from the beginning of \p Begin's selection to the end of \p End's
/// selection. Fails if the end precedes the beginning in the translation unit.
/// \p Begin is evaluated before \p End.
RangeSelector transformer::enclose(RangeSelector Begin, RangeSelector End) {
  return [Begin, End](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<CharSourceRange> First = Begin(Result);
    if (!First)
      return First.takeError();
    Expected<CharSourceRange> Last = End(Result);
    if (!Last)
      return Last.takeError();
    const SourceLocation From = First->getBegin();
    const SourceLocation To = Last->getEnd();
    if (Result.SourceManager->isBeforeInTranslationUnit(To, From))
      return invalidArgumentError("Bad range: out of order");
    // Token-ness follows the end selector, since it determines how `To` is
    // interpreted.
    return CharSourceRange(SourceRange(From, To), Last->isTokenRange());
  };
}
/// Convenience form of enclose() that uses node() selectors for the nodes
/// bound to \p BeginID and \p EndID.
RangeSelector transformer::encloseNodes(std::string BeginID,
                                        std::string EndID) {
  RangeSelector BeginSel = node(std::move(BeginID));
  RangeSelector EndSel = node(std::move(EndID));
  return transformer::enclose(std::move(BeginSel), std::move(EndSel));
}
/// Selects the member-name portion of the MemberExpr bound to \p ID. Fails
/// with a type error if the bound node is not a MemberExpr.
RangeSelector transformer::member(std::string ID) {
  return [ID](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<DynTypedNode> Node = getNode(Result.Nodes, ID);
    if (!Node)
      return Node.takeError();
    const auto *ME = Node->get<clang::MemberExpr>();
    if (ME == nullptr)
      return typeError(ID, Node->getNodeKind(), "MemberExpr");
    return CharSourceRange::getTokenRange(
        ME->getMemberNameInfo().getSourceRange());
  };
}
/// Selects the name token of the node bound to \p ID.
///
/// Supported node kinds:
///  - NamedDecl: the declaration's location, provided the name is a plain
///    identifier. If the token text there differs from the declared name, an
///    empty (invalid) range is returned instead of an error.
///  - DeclRefExpr: the reference's location, provided the name is an
///    identifier.
///  - CXXCtorInitializer: the member location. Only initializers that
///    initialize a field and are written in the source are accepted.
///  - TypeLoc: the full type range, looking through an ElaboratedTypeLoc to
///    the named type.
/// Any other node kind yields a type error.
RangeSelector transformer::name(std::string ID) {
  return [ID](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<DynTypedNode> N = getNode(Result.Nodes, ID);
    if (!N)
      return N.takeError();
    auto &Node = *N;
    if (const auto *D = Node.get<NamedDecl>()) {
      if (!D->getDeclName().isIdentifier())
        return missingPropertyError(ID, "name", "identifier");
      SourceLocation L = D->getLocation();
      auto R = CharSourceRange::getTokenRange(L, L);
      // Accept the range only if it covers exactly the declared name.
      // Otherwise, return an empty range rather than an error.
      if (tooling::getText(R, *Result.Context) != D->getName())
        return CharSourceRange();
      return R;
    }
    if (const auto *E = Node.get<DeclRefExpr>()) {
      if (!E->getNameInfo().getName().isIdentifier())
        return missingPropertyError(ID, "name", "identifier");
      SourceLocation L = E->getLocation();
      return CharSourceRange::getTokenRange(L, L);
    }
    if (const auto *I = Node.get<CXXCtorInitializer>()) {
      // A member name token exists only for an initializer that names a field
      // and is spelled in the source. Reject everything else rather than
      // return a range built from a location that is not a written member
      // name. The previous check (`!member && written`) let implicit
      // initializers slip through.
      if (!I->isMemberInitializer() || !I->isWritten())
        return missingPropertyError(ID, "name", "explicit member initializer");
      SourceLocation L = I->getMemberLocation();
      return CharSourceRange::getTokenRange(L, L);
    }
    if (const auto *T = Node.get<TypeLoc>()) {
      TypeLoc Loc = *T;
      // Look through an elaborated type (e.g. a qualified name) to the named
      // type it wraps.
      auto ET = Loc.getAs<ElaboratedTypeLoc>();
      if (!ET.isNull())
        Loc = ET.getNamedTypeLoc();
      return CharSourceRange::getTokenRange(Loc.getSourceRange());
    }
    return typeError(ID, Node.getNodeKind(),
                     "DeclRefExpr, NamedDecl, CXXCtorInitializer, TypeLoc");
  };
}
namespace {
/// Adapts a function over a specific AST node type \p T into a callable usable
/// as a RangeSelector. When invoked, it looks up the node bound to the stored
/// ID, checks that the node is a \p T, and forwards it to \p Func. It returns
/// an error if the ID is unbound or the node has a different kind.
template <typename T, CharSourceRange (*Func)(const MatchResult &, const T &)>
class RelativeSelector {
  std::string ID;

public:
  // explicit: a bare std::string should not silently convert to a selector.
  explicit RelativeSelector(std::string ID) : ID(std::move(ID)) {}

  // const: invoking the selector never modifies its stored ID.
  Expected<CharSourceRange> operator()(const MatchResult &Result) const {
    Expected<DynTypedNode> N = getNode(Result.Nodes, ID);
    if (!N)
      return N.takeError();
    if (const auto *Arg = N->get<T>())
      return Func(Result, *Arg);
    return typeError(ID, N->getNodeKind());
  }
};
} // namespace
namespace {
// Returns the char range that starts one character past the compound
// statement's '{' and ends at its '}'.
CharSourceRange getStatementsRange(const MatchResult &,
                                   const CompoundStmt &CS) {
  const SourceLocation AfterLBrace = CS.getLBracLoc().getLocWithOffset(1);
  const SourceLocation RBrace = CS.getRBracLoc();
  return CharSourceRange::getCharRange(AfterLBrace, RBrace);
}
} // namespace
/// Selects the statements inside the braces of the CompoundStmt bound to
/// \p ID.
RangeSelector transformer::statements(std::string ID) {
  using Selector = RelativeSelector<CompoundStmt, getStatementsRange>;
  return Selector(std::move(ID));
}
namespace {
// Returns the char range between a call's '(' (exclusive) and its ')'.
// If the opening paren cannot be located (findOpenParen gives up on invalid
// or macro locations), this returns an invalid range instead of a bogus one.
CharSourceRange getCallArgumentsRange(const MatchResult &Result,
                                      const CallExpr &CE) {
  const SourceLocation LParen =
      findOpenParen(CE, *Result.SourceManager, Result.Context->getLangOpts());
  // Offsetting an invalid SourceLocation produces a non-zero, valid-looking
  // location that points at unrelated text, so bail out first.
  if (LParen.isInvalid())
    return CharSourceRange();
  return CharSourceRange::getCharRange(LParen.getLocWithOffset(1),
                                       CE.getRParenLoc());
}
} // namespace
/// Selects the text between the parentheses of the CallExpr bound to \p ID.
RangeSelector transformer::callArgs(std::string ID) {
  using Selector = RelativeSelector<CallExpr, getCallArgumentsRange>;
  return Selector(std::move(ID));
}
namespace {
// Returns the char range between an initializer list's '{' (exclusive) and
// its '}'. If the list has no written '{' (e.g. an implicit list created by
// brace elision), this returns an invalid range instead of a bogus one.
CharSourceRange getElementsRange(const MatchResult &,
                                 const InitListExpr &E) {
  const SourceLocation LBrace = E.getLBraceLoc();
  // Offsetting an invalid SourceLocation produces a non-zero, valid-looking
  // location that points at unrelated text, so bail out first.
  if (LBrace.isInvalid())
    return CharSourceRange();
  return CharSourceRange::getCharRange(LBrace.getLocWithOffset(1),
                                       E.getRBraceLoc());
}
} // namespace
/// Selects the elements between the braces of the InitListExpr bound to
/// \p ID.
RangeSelector transformer::initListElements(std::string ID) {
  using Selector = RelativeSelector<InitListExpr, getElementsRange>;
  return Selector(std::move(ID));
}
namespace {
// Returns the token range from the `else` keyword through the end of the if
// statement, passed through tooling::maybeExtendRange with `;` as the token
// (presumably so a trailing semicolon can be included; see that helper).
CharSourceRange getElseRange(const MatchResult &Result, const IfStmt &S) {
  const CharSourceRange ElseToEnd =
      CharSourceRange::getTokenRange(S.getElseLoc(), S.getEndLoc());
  return tooling::maybeExtendRange(ElseToEnd, tok::TokenKind::semi,
                                   *Result.Context);
}
} // namespace
/// Selects the `else` branch (keyword included) of the IfStmt bound to \p ID.
RangeSelector transformer::elseBranch(std::string ID) {
  using Selector = RelativeSelector<IfStmt, getElseRange>;
  return Selector(std::move(ID));
}
/// Maps whatever \p S selects to its expansion range via
/// SourceManager::getExpansionRange. Errors from \p S are propagated.
RangeSelector transformer::expansion(RangeSelector S) {
  return [S](const MatchResult &Result) -> Expected<CharSourceRange> {
    Expected<CharSourceRange> Inner = S(Result);
    if (Inner)
      return Result.SourceManager->getExpansionRange(*Inner);
    return Inner.takeError();
  };
}