#include "clang/Tooling/Transformer/Parsing.h"
#include "clang/AST/Expr.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/Transformer/RangeSelector.h"
#include "clang/Tooling/Transformer/SourceCode.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include <string>
#include <utility>
#include <vector>
using namespace clang;
using namespace transformer;
namespace {
using llvm::Expected;

/// Pointer to a function that builds a RangeSelector from arguments of types
/// \p Ts; this is the shape of every combinator registered in the selector
/// lookup tables of this file.
template <typename... Ts> using RangeSelectorOp = RangeSelector (*)(Ts...);

/// Cursor into the text being parsed.
struct ParseState {
  /// The remaining, not-yet-consumed suffix of the input.
  StringRef Input;
  /// The complete input text, kept so that error positions and excerpts can be
  /// computed relative to its start.
  StringRef OriginalInput;
};

/// Result of a successful parse step: the state after the consumed text plus
/// the value that was parsed from it.
template <typename ResultType> struct ParseProgress {
  ParseState State;
  ResultType Value;
};

/// Either the progress made by a parse step or the error explaining why it
/// failed.
template <typename T> using ExpectedProgress = llvm::Expected<ParseProgress<T>>;

/// A parser for values of type \p T, starting from the given state.
template <typename T> using ParseFunction = ExpectedProgress<T> (*)(ParseState);

/// Error describing a failure to parse a range-selector expression.
class ParseError : public llvm::ErrorInfo<ParseError> {
public:
  // Identity tag used by the llvm::ErrorInfo machinery for this error type.
  static char ID;

  /// \param Pos offset into the original input at which parsing failed.
  /// \param ErrorMsg human-readable description of the failure.
  /// \param InputExcerpt text shown after the message when the error is
  ///        logged (callers in this file pass input starting at \p Pos).
  ParseError(size_t Pos, std::string ErrorMsg, std::string InputExcerpt)
      : Pos(Pos), ErrorMsg(std::move(ErrorMsg)),
        Excerpt(std::move(InputExcerpt)) {}

  /// Prints "parse error at position (<Pos>): <ErrorMsg>: <Excerpt>".
  void log(llvm::raw_ostream &OS) const override {
    // Stream each piece directly instead of first concatenating
    // `": " + Excerpt` into a temporary std::string.
    OS << "parse error at position (" << Pos << "): " << ErrorMsg << ": "
       << Excerpt;
  }

  std::error_code convertToErrorCode() const override {
    // Parse errors have no std::error_code equivalent.
    return llvm::inconvertibleErrorCode();
  }

  size_t Pos;
  std::string ErrorMsg;
  std::string Excerpt;
};

char ParseError::ID;
} // namespace
/// Returns the table of selectors that take a single string argument (parsed
/// as a quoted string), keyed by the name used in the selector syntax.
static const llvm::StringMap<RangeSelectorOp<std::string>> &
getUnaryStringSelectors() {
  using SelectorTable = llvm::StringMap<RangeSelectorOp<std::string>>;
  // Function-local static: constructed on the first call, then reused.
  static const SelectorTable Table = {
      {"name", name},
      {"node", node},
      {"statement", statement},
      {"statements", statements},
      {"member", member},
      {"callArgs", callArgs},
      {"elseBranch", elseBranch},
      {"initListElements", initListElements},
  };
  return Table;
}
/// Returns the table of selectors that take a single nested selector as their
/// argument, keyed by the name used in the selector syntax.
static const llvm::StringMap<RangeSelectorOp<RangeSelector>> &
getUnaryRangeSelectors() {
  using SelectorTable = llvm::StringMap<RangeSelectorOp<RangeSelector>>;
  // Function-local static: constructed on the first call, then reused.
  static const SelectorTable Table = {
      {"before", before},
      {"after", after},
      {"expansion", expansion},
  };
  return Table;
}
/// Returns the table of selectors that take two string arguments (each parsed
/// as a quoted string), keyed by the name used in the selector syntax.
static const llvm::StringMap<RangeSelectorOp<std::string, std::string>> &
getBinaryStringSelectors() {
  using SelectorTable =
      llvm::StringMap<RangeSelectorOp<std::string, std::string>>;
  // Function-local static: constructed on the first call, then reused.
  static const SelectorTable Table = {
      {"encloseNodes", encloseNodes},
  };
  return Table;
}
/// Returns the table of selectors that take two nested selectors as arguments,
/// keyed by the name used in the selector syntax.
static const llvm::StringMap<RangeSelectorOp<RangeSelector, RangeSelector>> &
getBinaryRangeSelectors() {
  using SelectorTable =
      llvm::StringMap<RangeSelectorOp<RangeSelector, RangeSelector>>;
  // Function-local static: constructed on the first call, then reused.
  static const SelectorTable Table = {
      {"enclose", enclose},
      {"between", between},
  };
  return Table;
}
/// Looks up \p Key in \p Map.
///
/// \returns a copy of the mapped value, or llvm::None if \p Key is absent.
///
/// Declared `static` to give this generically named file-local helper
/// internal linkage. Its instantiations involve only non-local types, so with
/// external linkage they could collide with an unrelated `findOptional`
/// template defined in another translation unit.
template <typename Element>
static llvm::Optional<Element>
findOptional(const llvm::StringMap<Element> &Map, llvm::StringRef Key) {
  auto It = Map.find(Key);
  if (It == Map.end())
    return llvm::None;
  return It->second;
}
/// Pairs the post-parse \p State with the parsed \p Result.
template <typename ResultType>
ParseProgress<ResultType> makeParseProgress(ParseState State,
                                            ResultType Result) {
  // ParseProgress is an aggregate, so it can be brace-initialized directly
  // from the return statement.
  return {State, std::move(Result)};
}
/// Creates a ParseError located at the current position of \p S.
///
/// The error carries \p ErrorMsg and an excerpt of at most 20 characters of
/// the original input, starting at that position.
static llvm::Error makeParseError(const ParseState &S, std::string ErrorMsg) {
  // Input only ever shrinks from the front, so it is always a suffix of
  // OriginalInput; the length difference is the offset of the cursor.
  const size_t Offset = S.OriginalInput.size() - S.Input.size();
  std::string Excerpt = S.OriginalInput.substr(Offset, 20).str();
  return llvm::make_error<ParseError>(Offset, std::move(ErrorMsg),
                                      std::move(Excerpt));
}
/// Returns a copy of \p S whose remaining input has its first \p N characters
/// removed; OriginalInput is left untouched.
static ParseState advance(ParseState S, size_t N) {
  ParseState Next = S;
  Next.Input = S.Input.drop_front(N);
  return Next;
}
/// Returns \p S with any leading ASCII whitespace removed.
static StringRef consumeWhitespace(StringRef S) {
  auto IsAsciiSpace = [](char Ch) { return isASCII(Ch) && isWhitespace(Ch); };
  return S.drop_while(IsAsciiSpace);
}
/// Skips leading whitespace, then consumes exactly the character \p c.
///
/// \returns the state just past \p c. If the input is exhausted or the next
/// non-whitespace character is not \p c, returns a ParseError positioned after
/// the skipped whitespace.
static ExpectedProgress<llvm::NoneType> parseChar(char c, ParseState State) {
  State.Input = consumeWhitespace(State.Input);
  const bool Found = !State.Input.empty() && State.Input.front() == c;
  if (Found)
    return makeParseProgress(advance(State, 1), llvm::None);
  return makeParseError(State,
                        ("expected char not found: " + llvm::Twine(c)).str());
}
/// Skips leading whitespace, then consumes the longest run of ASCII
/// characters accepted by isAsciiIdentifierContinue.
///
/// \returns the consumed text, or a ParseError if the run is empty.
static ExpectedProgress<std::string> parseId(ParseState State) {
  State.Input = consumeWhitespace(State.Input);
  auto IsIdChar = [](char Ch) {
    return isASCII(Ch) && isAsciiIdentifierContinue(Ch);
  };
  StringRef Name = State.Input.take_while(IsIdChar);
  if (!Name.empty())
    return makeParseProgress(advance(State, Name.size()), Name.str());
  return makeParseError(State, "failed to parse name");
}
/// Skips leading whitespace, then parses a double-quoted string.
///
/// \returns the text between the quotes (no escape processing is done; the
/// first '"' after the opening quote ends the string), or a ParseError if the
/// input is empty, does not start with '"', or has no closing '"'.
static ExpectedProgress<std::string> parseStringId(ParseState State) {
  State.Input = consumeWhitespace(State.Input);
  if (State.Input.empty())
    return makeParseError(State, "unexpected end of input");
  if (!State.Input.consume_front("\""))
    return makeParseError(
        State,
        "expecting string, but encountered other character or end of input");

  // From here on, State.Input starts just after the opening quote.
  const size_t CloseQuote = State.Input.find('"');
  if (CloseQuote == StringRef::npos)
    return makeParseError(State, "unterminated string");

  std::string Contents = State.Input.substr(0, CloseQuote).str();
  // Step over the contents and the closing quote.
  return makeParseProgress(advance(State, CloseQuote + 1), std::move(Contents));
}
/// Parses `( <element> )`, where the element is read by \p ParseElement, and
/// returns the selector produced by applying \p Op to that element.
template <typename T>
ExpectedProgress<RangeSelector> parseSingle(ParseFunction<T> ParseElement,
                                            RangeSelectorOp<T> Op,
                                            ParseState State) {
  auto Open = parseChar('(', State);
  if (!Open)
    return Open.takeError();

  auto Arg = ParseElement(Open->State);
  if (!Arg)
    return Arg.takeError();

  auto Close = parseChar(')', Arg->State);
  if (!Close)
    return Close.takeError();

  return makeParseProgress(Close->State, Op(std::move(Arg->Value)));
}
/// Parses `( <element> , <element> )`, reading both elements with
/// \p ParseElement, and returns the selector produced by applying \p Op to
/// them in order.
template <typename T>
ExpectedProgress<RangeSelector> parsePair(ParseFunction<T> ParseElement,
                                          RangeSelectorOp<T, T> Op,
                                          ParseState State) {
  auto Open = parseChar('(', State);
  if (!Open)
    return Open.takeError();

  auto First = ParseElement(Open->State);
  if (!First)
    return First.takeError();

  auto Comma = parseChar(',', First->State);
  if (!Comma)
    return Comma.takeError();

  auto Second = ParseElement(Comma->State);
  if (!Second)
    return Second.takeError();

  auto Close = parseChar(')', Second->State);
  if (!Close)
    return Close.takeError();

  RangeSelector Combined =
      Op(std::move(First->Value), std::move(Second->Value));
  return makeParseProgress(Close->State, std::move(Combined));
}
/// Parses one selector expression: a selector name followed by its
/// parenthesized arguments.
///
/// The name is looked up, in order, among the unary-string, unary-range,
/// binary-string and binary-range selector tables; that table determines how
/// the arguments are parsed. Range arguments recurse into this function.
/// Returns a ParseError (positioned at \p State) if the name is not in any
/// table.
static ExpectedProgress<RangeSelector>
parseRangeSelectorImpl(ParseState State) {
  auto Name = parseId(State);
  if (!Name)
    return Name.takeError();

  const std::string OpName = std::move(Name->Value);
  const ParseState AfterName = Name->State;

  if (auto UnaryStr = findOptional(getUnaryStringSelectors(), OpName))
    return parseSingle(parseStringId, *UnaryStr, AfterName);
  if (auto UnaryRange = findOptional(getUnaryRangeSelectors(), OpName))
    return parseSingle(parseRangeSelectorImpl, *UnaryRange, AfterName);
  if (auto BinaryStr = findOptional(getBinaryStringSelectors(), OpName))
    return parsePair(parseStringId, *BinaryStr, AfterName);
  if (auto BinaryRange = findOptional(getBinaryRangeSelectors(), OpName))
    return parsePair(parseRangeSelectorImpl, *BinaryRange, AfterName);

  return makeParseError(State, "unknown selector name: " + OpName);
}
/// Parses \p Input as a single range-selector expression.
///
/// Whitespace is allowed around tokens and after the expression. Anything
/// else remaining after the expression is reported as a ParseError.
Expected<RangeSelector> transformer::parseRangeSelector(llvm::StringRef Input) {
  ExpectedProgress<RangeSelector> Parsed =
      parseRangeSelectorImpl(ParseState{Input, Input});
  if (!Parsed)
    return Parsed.takeError();

  ParseState Rest = Parsed->State;
  Rest.Input = consumeWhitespace(Rest.Input);
  if (!Rest.Input.empty())
    return makeParseError(Rest, "unexpected input after selector");
  return Parsed->Value;
}