#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Statepoint.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#define DEBUG_TYPE "safepoint-placement"
// Pass statistics (LLVM STATISTIC counters, grouped under DEBUG_TYPE).
// Incremented in PlaceSafepoints::runOnFunction once per entry poll queued.
STATISTIC(NumEntrySafepoints, "Number of entry safepoints inserted");
// Incremented once per backedge poll queued; with -spp-split-backedge this is
// once per split edge rather than once per latch.
STATISTIC(NumBackedgeSafepoints, "Number of backedge safepoints inserted");
// Incremented in runOnLoop once per *latch* skipped because a call that needs
// a statepoint sits on the dominator chain from the latch to the header.
// Note: despite the description, this counts latches, not loops.
STATISTIC(CallInLoop,
"Number of loops without safepoints due to calls in loop");
// Incremented in runOnLoop once per *latch* skipped because the loop's trip
// count is provably bounded (see mustBeFiniteCountedLoop). Counts latches.
STATISTIC(FiniteExecution,
"Number of loops without safepoints finite execution");
using namespace llvm;
// Hidden knobs controlling backedge poll placement. Each now carries a
// cl::desc so it is self-describing under -help-hidden.

// When set, runOnLoop skips both exemptions (finite counted loop, and
// unconditional call in loop) and queues a poll on every latch.
static cl::opt<bool> AllBackedges(
    "spp-all-backedges", cl::Hidden, cl::init(false),
    cl::desc("Place a poll on every loop backedge, ignoring the finite-loop "
             "and call-in-loop exemptions"));

// Bit width used by mustBeFiniteCountedLoop: a loop whose (unsigned) maximum
// backedge-taken count, or exit count through the latch, fits in this many
// bits is treated as finite and gets no backedge poll.
static cl::opt<int> CountedLoopTripWidth(
    "spp-counted-loop-trip-width", cl::Hidden, cl::init(32),
    cl::desc("Treat a loop as finite (no backedge poll) when its trip count "
             "is provably bounded by this many bits"));

// When set, the backedge is split and the poll goes into the new block;
// otherwise the poll is inserted before the latch terminator.
static cl::opt<bool> SplitBackedge(
    "spp-split-backedge", cl::Hidden, cl::init(false),
    cl::desc("Insert backedge polls in a new block on the split backedge "
             "instead of before the latch terminator"));
namespace {
struct PlaceBackedgeSafepointsImpl : public FunctionPass {
static char ID;
std::vector<Instruction *> PollLocations;
bool CallSafepointsEnabled;
ScalarEvolution *SE = nullptr;
DominatorTree *DT = nullptr;
LoopInfo *LI = nullptr;
TargetLibraryInfo *TLI = nullptr;
PlaceBackedgeSafepointsImpl(bool CallSafepoints = false)
: FunctionPass(ID), CallSafepointsEnabled(CallSafepoints) {
initializePlaceBackedgeSafepointsImplPass(*PassRegistry::getPassRegistry());
}
bool runOnLoop(Loop *);
void runOnLoopAndSubLoops(Loop *L) {
for (Loop *I : *L)
runOnLoopAndSubLoops(I);
runOnLoop(L);
}
bool runOnFunction(Function &F) override {
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
for (Loop *I : *LI) {
runOnLoopAndSubLoops(I);
}
return false;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<ScalarEvolutionWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.setPreservesAll();
}
};
}
// Hidden switches that disable individual poll kinds. Each now carries a
// cl::desc so it is self-describing under -help-hidden.

// Read by enableEntrySafepoints: suppresses the function-entry poll.
static cl::opt<bool> NoEntry(
    "spp-no-entry", cl::Hidden, cl::init(false),
    cl::desc("Do not insert a safepoint poll at function entry"));

// Read by enableCallSafepoints: when set, calls inside a loop no longer
// exempt its latches from backedge polls.
static cl::opt<bool> NoCall(
    "spp-no-call", cl::Hidden, cl::init(false),
    cl::desc("Do not let calls inside a loop exempt its backedges from "
             "safepoint polls"));

// Read by enableBackedgeSafepoints: suppresses all backedge polls.
static cl::opt<bool> NoBackedge(
    "spp-no-backedge", cl::Hidden, cl::init(false),
    cl::desc("Do not insert safepoint polls on loop backedges"));
namespace {
struct PlaceSafepoints : public FunctionPass {
static char ID;
PlaceSafepoints() : FunctionPass(ID) {
initializePlaceSafepointsPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
}
// Forward declaration; the definition is at the end of this file. Inlines
// the module's gc.safepoint_poll body before InsertBefore and fills
// ParsePointsNeeded (which must be passed in empty) with the inlined calls
// that need a statepoint according to needsStatepoint.
static void
InsertSafepointPoll(Instruction *InsertBefore,
std::vector<CallBase *> &ParsePointsNeeded ,
const TargetLibraryInfo &TLI);
/// Returns true if \p Call must be turned into a parse point (statepoint).
/// Calls to GC leaf functions, plain-call inline asm, and the statepoint
/// machinery itself (gc.statepoint / gc.relocate / gc.result) are excluded.
static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) {
  if (callsGCLeafFunction(Call, TLI))
    return false;

  // Inline asm is only filtered for plain calls, not invokes/callbrs.
  const auto *AsPlainCall = dyn_cast<CallInst>(Call);
  if (AsPlainCall && AsPlainCall->isInlineAsm())
    return false;

  // Already-lowered statepoint pieces must not be wrapped again.
  const bool IsStatepointMachinery = isa<GCStatepointInst>(Call) ||
                                     isa<GCRelocateInst>(Call) ||
                                     isa<GCResultInst>(Call);
  return !IsStatepointMachinery;
}
/// Returns true if some block on the dominator-tree chain from \p Pred up to
/// and including \p Header contains a call that needs a statepoint. Such a
/// call executes on every path from the header to this latch, so it already
/// acts as a poll for this backedge. \p L is currently unused.
static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header,
                                               BasicBlock *Pred,
                                               DominatorTree &DT,
                                               const TargetLibraryInfo &TLI) {
  assert(DT.dominates(Header, Pred) && "loop latch not dominated by header?");

  auto HasSafepointCall = [&TLI](BasicBlock &Block) {
    return llvm::any_of(Block, [&TLI](Instruction &Inst) {
      auto *Call = dyn_cast<CallBase>(&Inst);
      return Call && needsStatepoint(Call, TLI);
    });
  };

  // Climb immediate dominators from the latch; the header terminates the
  // walk because it dominates Pred (asserted above).
  for (DomTreeNode *Node = DT.getNode(Pred);; Node = Node->getIDom()) {
    BasicBlock *Block = Node->getBlock();
    if (HasSafepointCall(*Block))
      return true;
    if (Block == Header)
      return false;
  }
}
/// Returns true if SCEV can bound how often this backedge is taken by a value
/// that fits in CountedLoopTripWidth bits (unsigned). It tries the loop-wide
/// constant max backedge-taken count first. If \p Pred is an exiting block,
/// it then tries the exit count through \p Pred.
static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE,
                                    BasicBlock *Pred) {
  // A count qualifies when SCEV computed it and its unsigned upper bound is
  // representable in the configured width.
  auto FitsTripWidth = [SE](const SCEV *Count) {
    if (isa<SCEVCouldNotCompute>(Count))
      return false;
    return SE->getUnsignedRange(Count).getUnsignedMax().isIntN(
        CountedLoopTripWidth);
  };

  if (FitsTripWidth(SE->getConstantMaxBackedgeTakenCount(L)))
    return true;

  // The per-exit count is only queried when Pred actually leaves the loop.
  return L->isLoopExiting(Pred) && FitsTripWidth(SE->getExitCount(L, Pred));
}
/// Scans Start's block from Start up to (not including) End or the block's
/// end, whichever comes first. Every CallInst found is appended to \p Calls.
/// If the terminator is reached (i.e. End was not hit first), each not-yet-
/// seen successor is recorded in \p Seen and queued on \p Worklist.
static void scanOneBB(Instruction *Start, Instruction *End,
                      std::vector<CallInst *> &Calls,
                      DenseSet<BasicBlock *> &Seen,
                      std::vector<BasicBlock *> &Worklist) {
  BasicBlock *Block = Start->getParent();
  const BasicBlock::iterator BlockEnd = Block->end();
  const BasicBlock::iterator Stop = End->getIterator();

  for (BasicBlock::iterator It = Start->getIterator();
       It != BlockEnd && It != Stop; ++It) {
    Instruction &Inst = *It;

    if (auto *CI = dyn_cast<CallInst>(&Inst))
      Calls.push_back(CI);

    assert(!isa<InvokeInst>(&Inst) &&
           "support for invokes in poll code needed");

    if (!Inst.isTerminator())
      continue;

    for (BasicBlock *Succ : successors(Block))
      if (Seen.insert(Succ).second)
        Worklist.push_back(Succ);
  }
}
/// Collects into \p Calls (cleared first) every CallInst reachable from
/// \p Start without passing \p End, exploring blocks depth-first. \p Seen
/// receives Start's block and every block reached.
static void scanInlinedCode(Instruction *Start, Instruction *End,
                            std::vector<CallInst *> &Calls,
                            DenseSet<BasicBlock *> &Seen) {
  Calls.clear();
  Seen.insert(Start->getParent());

  std::vector<BasicBlock *> Pending;
  scanOneBB(Start, End, Calls, Seen, Pending);

  // LIFO processing: most recently discovered block first.
  while (!Pending.empty()) {
    BasicBlock *Next = Pending.back();
    Pending.pop_back();
    scanOneBB(&Next->front(), End, Calls, Seen, Pending);
  }
}
/// Decides, for every latch of \p L, whether the backedge needs a poll and,
/// if so, records the latch terminator in PollLocations. Unless
/// -spp-all-backedges is set, latches of provably finite loops and latches
/// already covered by an unconditional call safepoint are skipped. Never
/// modifies the IR; always returns false.
bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) {
  // A loop may have several latches; each backedge is judged on its own.
  SmallVector<BasicBlock *, 16> Latches;
  L->getLoopLatches(Latches);
  BasicBlock *const Header = L->getHeader();

  for (BasicBlock *Latch : Latches) {
    assert(L->contains(Latch));

    if (!AllBackedges) {
      // Bounded trip count: this backedge cannot run forever.
      if (mustBeFiniteCountedLoop(L, SE, Latch)) {
        LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
        ++FiniteExecution;
        continue;
      }

      // A statepoint-requiring call on the dominator chain from the latch to
      // the header already provides a poll on every iteration.
      const bool CoveredByCall =
          CallSafepointsEnabled &&
          containsUnconditionalCallSafepoint(L, Header, Latch, *DT, *TLI);
      if (CoveredByCall) {
        LLVM_DEBUG(
            dbgs()
            << "skipping safepoint placement due to unconditional call\n");
        ++CallInLoop;
        continue;
      }
    }

    Instruction *LatchTerm = Latch->getTerminator();
    LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *LatchTerm);
    PollLocations.push_back(LatchTerm);
  }

  return false;
}
/// Returns true if the entry poll may be placed after \p Call. Only
/// intrinsics qualify, and among them gc.statepoint and the two patchpoint
/// intrinsics are excluded (presumably because they can wrap a real call).
/// Every non-intrinsic call requires the entry poll before it.
static bool doesNotRequireEntrySafepointBefore(CallBase *Call) {
  auto *Intr = dyn_cast<IntrinsicInst>(Call);
  if (!Intr)
    return false;

  const Intrinsic::ID IID = Intr->getIntrinsicID();
  const bool NeedsPollFirst = IID == Intrinsic::experimental_gc_statepoint ||
                              IID == Intrinsic::experimental_patchpoint_void ||
                              IID == Intrinsic::experimental_patchpoint_i64;
  return !NeedsPollFirst;
}
/// Picks the instruction before which the function-entry poll is inserted.
/// Starting at the entry block's first instruction, it follows straight-line
/// control flow, crossing into a block only when that block is the unique
/// successor and has a unique predecessor. It stops at the first call that
/// requires a poll before it, or at the terminator where straight-line flow
/// ends. The result is never null. \p DT is currently unused.
static Instruction *findLocationForEntrySafepoint(Function &F,
                                                  DominatorTree &DT) {
  // True unless I is a terminator whose block does not fall straight
  // through into a single-predecessor unique successor.
  auto HasNextInstruction = [](Instruction *I) -> bool {
    if (I->isTerminator()) {
      BasicBlock *Succ = I->getParent()->getUniqueSuccessor();
      return Succ != nullptr && Succ->getUniquePredecessor() != nullptr;
    }
    return true;
  };

  // The straight-line successor of I; only valid if HasNextInstruction(I).
  auto NextInstruction = [&](Instruction *I) -> Instruction * {
    assert(HasNextInstruction(I) &&
           "first check if there is a next instruction!");
    if (!I->isTerminator())
      return &*std::next(I->getIterator());
    return &I->getParent()->getUniqueSuccessor()->front();
  };

  Instruction *Cursor = &F.getEntryBlock().front();
  while (HasNextInstruction(Cursor)) {
    // Stop at the first call that must be preceded by the entry poll.
    auto *Call = dyn_cast<CallBase>(Cursor);
    if (Call && !doesNotRequireEntrySafepointBefore(Call))
      break;
    Cursor = NextInstruction(Cursor);
  }

  assert((HasNextInstruction(Cursor) || Cursor->isTerminator()) &&
         "either we stopped because of a call, or because of terminator");
  return Cursor;
}
// Name of the module-provided function whose body is inlined at every poll.
const char GCSafepointPollName[] = "gc.safepoint_poll";

/// Returns true if \p F is the poll routine itself; runOnFunction skips it
/// rather than placing polls inside it.
static bool isGCSafepointPoll(Function &F) {
  // StringRef::operator== instead of the deprecated StringRef::equals.
  return F.getName() == GCSafepointPollName;
}
/// Returns true if \p F has a GC attribute naming one of the strategies this
/// pass supports: "statepoint-example" or "coreclr".
static bool shouldRewriteFunction(Function &F) {
  if (!F.hasGC())
    return false;

  const StringRef GCName = F.getGC();
  return GCName == StringRef("statepoint-example") ||
         GCName == StringRef("coreclr");
}
// Per-function policy hooks. F is currently ignored: each hook simply reports
// whether the matching -spp-no-* switch is unset.

/// True unless -spp-no-entry was given.
static bool enableEntrySafepoints(Function &F) {
  const bool Disabled = NoEntry.getValue();
  return !Disabled;
}

/// True unless -spp-no-backedge was given.
static bool enableBackedgeSafepoints(Function &F) {
  const bool Disabled = NoBackedge.getValue();
  return !Disabled;
}

/// True unless -spp-no-call was given.
static bool enableCallSafepoints(Function &F) {
  const bool Disabled = NoCall.getValue();
  return !Disabled;
}
/// Inserts safepoint polls into \p F: on loop backedges selected by
/// PlaceBackedgeSafepointsImpl and at the entry location chosen by
/// findLocationForEntrySafepoint. Declarations, the poll routine itself, and
/// functions with an unsupported GC are left alone.
/// \returns true if the function was modified.
bool PlaceSafepoints::runOnFunction(Function &F) {
  // Nothing to do for declarations; also avoids building a dominator tree
  // over an empty CFG.
  if (F.isDeclaration() || F.empty())
    return false;

  // The poll routine is what gets inlined at each poll; never poll inside it.
  if (isGCSafepointPoll(F))
    return false;

  if (!shouldRewriteFunction(F))
    return false;

  const TargetLibraryInfo &TLI =
      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  bool Modified = false;

  // Dominance and reachability queries below are meaningless for blocks that
  // are unreachable from entry, so drop those first.
  Modified |= removeUnreachableBlocks(F);

  DominatorTree DT;
  DT.recalculate(F);

  // Instructions before which a poll body will be inlined.
  SmallVector<Instruction *, 16> PollsNeeded;

  if (enableBackedgeSafepoints(F)) {
    // Run the backedge analysis through a private legacy pass manager so its
    // SCEV / LoopInfo / DominatorTree requirements are scheduled for it.
    // FPM takes ownership of PBS.
    legacy::FunctionPassManager FPM(F.getParent());
    bool CanAssumeCallSafepoints = enableCallSafepoints(F);
    auto *PBS = new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints);
    FPM.add(PBS);
    FPM.run(F);

    DT.recalculate(F);

    auto &PollLocations = PBS->PollLocations;

    // The same latch terminator can be reported more than once. Remove every
    // duplicate while keeping first-occurrence order. Sorting by block name
    // and then calling std::unique is not enough: blocks with equal names
    // (e.g. all unnamed blocks) need not end up adjacent, so duplicates
    // could survive and receive two polls.
    {
      SetVector<Instruction *> Unique(PollLocations.begin(),
                                      PollLocations.end());
      PollLocations.assign(Unique.begin(), Unique.end());
    }

    // Order by block name so the names produced by edge splitting are stable.
    // stable_sort keeps discovery order among equally named blocks, so the
    // result is deterministic even for unnamed blocks.
    llvm::stable_sort(PollLocations, [](Instruction *A, Instruction *B) {
      return A->getParent()->getName() < B->getParent()->getName();
    });

    for (Instruction *Term : PollLocations) {
      Modified = true;

      if (SplitBackedge) {
        // Term ends a latch, so at least one successor dominates its block.
        // There may be several distinct headers (or duplicate edges to one);
        // each distinct header gets its own split block and poll.
        SetVector<BasicBlock *> Headers;
        for (BasicBlock *Succ : successors(Term->getParent()))
          if (DT.dominates(Succ, Term->getParent()))
            Headers.insert(Succ);
        assert(!Headers.empty() && "poll location is not a loop latch?");

        for (BasicBlock *Header : Headers) {
          // SplitEdge keeps DT up to date.
          BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT);
          PollsNeeded.push_back(NewBB->getTerminator());
          NumBackedgeSafepoints++;
        }
      } else {
        // Poll right before the latch terminator.
        PollsNeeded.push_back(Term);
        NumBackedgeSafepoints++;
      }
    }
  }

  if (enableEntrySafepoints(F)) {
    if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) {
      PollsNeeded.push_back(Location);
      Modified = true;
      NumEntrySafepoints++;
    }
  }

  // Inline a poll body at every collected location. InsertSafepointPoll
  // requires an empty output vector on each call. The parse points it reports
  // are not consumed here. NOTE(review): presumably a later statepoint-
  // rewriting step handles them — confirm.
  for (Instruction *PollLocation : PollsNeeded) {
    std::vector<CallBase *> RuntimeCalls;
    InsertSafepointPoll(PollLocation, RuntimeCalls, TLI);
  }

  return Modified;
}
// Legacy-PM pass identifiers; only their addresses matter.
char PlaceSafepoints::ID = 0;
char PlaceBackedgeSafepointsImpl::ID = 0;

/// Public factory for the safepoint placement pass; the caller (normally a
/// legacy pass manager) takes ownership of the returned object.
FunctionPass *llvm::createPlaceSafepointsPass() {
  auto *Pass = new PlaceSafepoints();
  return Pass;
}
// Pass registration. Each pass lists every analysis it addRequired<>s in its
// getAnalysisUsage, so those passes are initialized in the registry first.
// TargetLibraryInfoWrapperPass was previously missing from both lists.
INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsImpl,
                      "place-backedge-safepoints-impl",
                      "Place Backedge Safepoints", false, false)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(PlaceBackedgeSafepointsImpl,
                    "place-backedge-safepoints-impl",
                    "Place Backedge Safepoints", false, false)

INITIALIZE_PASS_BEGIN(PlaceSafepoints, "place-safepoints", "Place Safepoints",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints",
                    false, false)
/// Inserts a call to gc.safepoint_poll before \p InsertBefore and inlines it.
/// Afterwards it collects, into \p ParsePointsNeeded (which must be empty on
/// entry), the calls in the inlined code that need a statepoint according to
/// needsStatepoint.
///
/// A missing or bodiless gc.safepoint_poll is malformed input and is reported
/// with report_fatal_error. Previously this was only asserted, so release
/// builds dereferenced null or silently left an uninlinable call behind.
static void
InsertSafepointPoll(Instruction *InsertBefore,
                    std::vector<CallBase *> &ParsePointsNeeded,
                    const TargetLibraryInfo &TLI) {
  BasicBlock *OrigBB = InsertBefore->getParent();
  Module *M = InsertBefore->getModule();
  assert(M && "must be part of a module");

  // The poll implementation is supplied by the module being compiled.
  auto *F = M->getFunction(GCSafepointPollName);
  if (!F)
    report_fatal_error("gc.safepoint_poll function is missing");
  assert(F->getValueType() ==
             FunctionType::get(Type::getVoidTy(M->getContext()), false) &&
         "gc.safepoint_poll declared with wrong type");
  if (F->empty())
    report_fatal_error("gc.safepoint_poll must be a non-empty function");

  CallInst *PollCall = CallInst::Create(F, "", InsertBefore);

  // Remember the neighbours of the call so the inlined region can be found
  // again. Before is the preceding instruction (unless the call is first in
  // its block); After is the instruction following the call.
  BasicBlock::iterator Before(PollCall), After(PollCall);
  bool IsBegin = false;
  if (Before == OrigBB->begin())
    IsBegin = true;
  else
    --Before;
  ++After;
  assert(After != OrigBB->end() && "must have successor");

  InlineFunctionInfo IFI;
  bool InlineStatus = InlineFunction(*PollCall, IFI).isSuccess();
  assert(InlineStatus && "inline must succeed");
  (void)InlineStatus; // Only checked in asserts builds.

  assert(IFI.StaticAllocas.empty() && "can't have allocs");

  std::vector<CallInst *> Calls; // Calls found in the inlined code.
  DenseSet<BasicBlock *> BBs;    // Blocks visited while scanning.

  // First inlined instruction. If the call was first in its block, Before was
  // invalidated by inlining, so restart from the block's beginning.
  BasicBlock::iterator Start = IsBegin ? OrigBB->begin() : std::next(Before);

  // A poll body ending in 'unreachable' would never reach After.
  assert(isPotentiallyReachable(&*Start, &*After) &&
         "malformed poll function");

  scanInlinedCode(&*Start, &*After, Calls, BBs);
  assert(!Calls.empty() && "slow path not found for safepoint poll");

  assert(ParsePointsNeeded.empty());
  for (CallInst *CI : Calls)
    if (needsStatepoint(CI, TLI))
      ParsePointsNeeded.push_back(CI);
  assert(ParsePointsNeeded.size() <= Calls.size());
}