#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/SmallString.h"
#include "gtest/gtest.h"
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace clang {
/// Optional hook that lets a test tweak the PrintingPolicy before a node is
/// printed. This is an llvm::function_ref, so it does not own the callable it
/// refers to: the callable must stay alive for as long as the adjuster is used.
using PrintingPolicyAdjuster =
    llvm::function_ref<void(PrintingPolicy &)>;

/// Renders a node into an output stream. Besides the stream and the node, it
/// receives the ASTContext of the match and the (possibly null) policy
/// adjuster supplied by the test.
template <typename NodeType>
using NodePrinter = std::function<void(
    llvm::raw_ostream & /*Out*/, const ASTContext * /*Context*/,
    const NodeType * /*Node*/, PrintingPolicyAdjuster /*PolicyAdjuster*/)>;

/// Predicate that decides whether a matched node takes part in the check.
template <typename NodeType>
using NodeFilter = std::function<bool(const NodeType * /*Node*/)>;
template <typename NodeType>
/// MatchFinder callback that prints the first node bound to "id" (and accepted
/// by the filter) into an internal buffer. It also counts how many such nodes
/// were seen in total, so callers can detect matchers that are too broad.
///
/// PolicyAdjuster is an llvm::function_ref, i.e. a non-owning reference: the
/// callable it refers to must outlive this PrintMatch object.
template <typename NodeType>
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
  using PrinterT = NodePrinter<NodeType>;
  using FilterT = NodeFilter<NodeType>;

  /// Output of Printer for the first accepted node.
  SmallString<1024> Printed;
  /// Number of match results whose "id" node passed the filter.
  unsigned NumFoundNodes = 0;
  PrinterT Printer;
  /// Optional predicate; an empty FilterT accepts every node.
  FilterT Filter;
  PrintingPolicyAdjuster PolicyAdjuster;

public:
  PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster,
             FilterT Filter)
      : Printer(std::move(Printer)), Filter(std::move(Filter)),
        PolicyAdjuster(PolicyAdjuster) {}

  void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
    const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id");
    // An empty filter means "no filtering". Invoking an empty std::function
    // would throw std::bad_function_call (or abort under -fno-exceptions).
    if (!N || (Filter && !Filter(N)))
      return;
    // Count every accepted node so ambiguous matchers can be reported, but
    // only print the first one.
    if (++NumFoundNodes > 1)
      return;
    llvm::raw_svector_ostream Out(Printed);
    Printer(Out, Result.Context, N, PolicyAdjuster);
  }

  /// Text printed for the first accepted node (empty if none was found).
  StringRef getPrinted() const { return Printed; }
  /// Number of match results whose "id" node passed the filter.
  unsigned getNumFoundNodes() const { return NumFoundNodes; }
};
/// Default node filter for PrintedNodeMatches: accepts every node.
/// The node argument is intentionally ignored.
template <typename NodeType>
bool NoNodeFilter(const NodeType * /*Node*/) {
  constexpr bool AcceptAll = true;
  return AcceptAll;
}
/// Parses \p Code with \p Args, runs \p NodeMatch over the resulting AST and
/// checks two things: exactly one node bound to "id" (and accepted by
/// \p Filter) is found, and \p Printer renders that node as
/// \p ExpectedPrinted.
///
/// \param FileName virtual file name to parse \p Code as; when empty, the
///        default file name of tooling::runToolOnCodeWithArgs is used.
/// \param PolicyAdjuster optional hook forwarded to \p Printer. It is
///        non-owning and only needs to stay alive for the duration of this
///        call.
/// \param AllowError if true, a failed tool run is not reported as a failure
///        and the match results are still checked.
/// \param Filter predicate selecting which matched nodes count; defaults to
///        accepting every node.
/// \returns AssertionSuccess on an exact match, otherwise an AssertionFailure
///          describing the first check that failed.
template <typename NodeType, typename Matcher>
::testing::AssertionResult
PrintedNodeMatches(StringRef Code, const std::vector<std::string> &Args,
                   const Matcher &NodeMatch, StringRef ExpectedPrinted,
                   StringRef FileName, NodePrinter<NodeType> Printer,
                   PrintingPolicyAdjuster PolicyAdjuster = nullptr,
                   bool AllowError = false,
                   NodeFilter<NodeType> Filter = &NoNodeFilter<NodeType>) {
  // Printer and Filter are by-value sinks used only here: move them rather
  // than copying the underlying std::function objects.
  PrintMatch<NodeType> Callback(std::move(Printer), PolicyAdjuster,
                                std::move(Filter));
  ast_matchers::MatchFinder Finder;
  Finder.addMatcher(NodeMatch, &Callback);
  std::unique_ptr<tooling::FrontendActionFactory> Factory(
      tooling::newFrontendActionFactory(&Finder));

  // Only pass FileName when one was given, so the tool's default applies
  // otherwise.
  const bool ToolResult =
      FileName.empty()
          ? tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args)
          : tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args,
                                           FileName);
  if (!ToolResult && !AllowError)
    return ::testing::AssertionFailure()
           << "Parsing error in \"" << Code.str() << "\"";

  const unsigned NumFound = Callback.getNumFoundNodes();
  if (NumFound == 0)
    return ::testing::AssertionFailure() << "Matcher didn't find any nodes";
  if (NumFound > 1)
    return ::testing::AssertionFailure()
           << "Matcher should match only one node (found " << NumFound << ")";

  if (Callback.getPrinted() != ExpectedPrinted)
    return ::testing::AssertionFailure()
           << "Expected \"" << ExpectedPrinted.str() << "\", got \""
           << Callback.getPrinted().str() << "\"";
  return ::testing::AssertionSuccess();
}
}