#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include <functional>
#define DEBUG_TYPE "sync-dependence"
namespace {
using namespace llvm;
/// Callback invoked with each basic block that a traversal reports.
using POCB = std::function<void(const BasicBlock &)>;
/// Set of blocks that a traversal has already finalized.
using VisitedSet = std::set<const BasicBlock *>;
/// Worklist of blocks; computeStackPO pushes to and pops from its back.
using BlockStack = std::vector<const BasicBlock *>;
/// Forward declaration: computeStackPO and computeLoopPO call each other.
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
VisitedSet &Finalized);
/// Drains \p Stack, reporting blocks to \p CallBack as they are finalized.
///
/// \param Stack      Worklist to process; empty when this function returns.
/// \param LI         Loop info used to find the loop of each block.
/// \param Loop       Loop the traversal is confined to, or nullptr for no
///                   confinement.
/// \param CallBack   Invoked once for every block finalized here.
/// \param Finalized  Blocks that were already reported; updated in place.
///
/// A block whose loop (per \p LI) differs from \p Loop is not expanded through
/// its own successors. Instead the unique exit blocks of that loop are
/// scheduled, and once none of them can be scheduled any more the block is
/// popped and the loop is handed to computeLoopPO.
static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
                           POCB CallBack, VisitedSet &Finalized) {
  const BasicBlock *const Header = Loop ? Loop->getHeader() : nullptr;

  // Schedules Candidate unless it is the header of the current loop, lies
  // outside the current loop, or was finalized before. Returns true iff the
  // candidate was pushed.
  auto TryPush = [&](const BasicBlock *Candidate) -> bool {
    if (Candidate == Header)
      return false;
    if (Loop && !Loop->contains(Candidate))
      return false;
    if (Finalized.count(Candidate) != 0)
      return false;
    Stack.push_back(Candidate);
    return true;
  };

  while (!Stack.empty()) {
    const BasicBlock *Top = Stack.back();
    auto *TopLoop = LI.getLoopFor(Top);

    if (TopLoop != Loop) {
      // The block belongs to a different loop: expand via that loop's exits.
      SmallVector<BasicBlock *, 3> Exits;
      TopLoop->getUniqueExitBlocks(Exits);

      bool ScheduledExit = false;
      for (const BasicBlock *ExitBB : Exits) {
        if (TryPush(ExitBB))
          ScheduledExit = true;
      }
      if (ScheduledExit)
        continue;

      // No exit left to schedule: retire the block and process its loop.
      Stack.pop_back();
      computeLoopPO(LI, *TopLoop, CallBack, Finalized);
      continue;
    }

    // Same loop as the traversal: expand via the block's successors.
    bool ScheduledSucc = false;
    for (const BasicBlock *Succ : successors(Top)) {
      if (TryPush(Succ))
        ScheduledSucc = true;
    }
    if (ScheduledSucc)
      continue;

    // A block may sit on the stack more than once; report it only the first
    // time it is retired.
    Stack.pop_back();
    if (Finalized.insert(Top).second)
      CallBack(*Top);
  }
}
/// Runs the traversal for the whole function \p F, starting at its entry
/// block and without confining it to any loop.
///
/// \param F         Function whose blocks are traversed.
/// \param LI        Loop info forwarded to computeStackPO.
/// \param CallBack  Invoked for each block that the traversal reports.
static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
  BlockStack Worklist;
  Worklist.reserve(24);
  Worklist.push_back(&F.getEntryBlock());

  VisitedSet Done;
  computeStackPO(Worklist, LI, nullptr, CallBack, Done);
}
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
VisitedSet &Finalized) {
std::vector<const BasicBlock *> Stack;
const auto *LoopHeader = Loop.getHeader();
Finalized.insert(LoopHeader);
CallBack(*LoopHeader);
for (const auto *BB : successors(LoopHeader)) {
if (!Loop.contains(BB))
continue;
if (BB == LoopHeader)
continue;
Stack.push_back(BB);
}
computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
}
}
namespace llvm {
/// Definition of the static descriptor that getJoinBlocks() returns for
/// terminators with at most one successor.
ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
/// Stores references to the given analyses and fills LoopPO by traversing the
/// function that contains the root block of \p DT; every block reported by
/// the traversal is appended to LoopPO.
SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
                                               const PostDominatorTree &PDT,
                                               const LoopInfo &LI)
    : DT(DT), PDT(PDT), LI(LI) {
  Function &F = *DT.getRoot()->getParent();
  auto RecordBlock = [this](const BasicBlock &Block) {
    LoopPO.appendBlock(Block);
  };
  computeTopLevelPO(F, LI, RecordBlock);
}
/// Out-of-line defaulted destructor.
SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
/// Propagates labels from the successors of a terminator block through the
/// control-flow graph and records the blocks where different labels meet.
///
/// Each block carries at most one label, stored in BlockLabels at the block's
/// index in LoopPOT. When a label is pushed to a block that already carries a
/// different one, the block is recorded in the resulting
/// ControlDivergenceDesc and becomes its own label.
struct DivergencePropagator {
  /// Block ordering that maps blocks to indices and indices back to blocks.
  const ModifiedPO &LoopPOT;
  /// Kept alongside the other analyses; not read by the methods of this
  /// struct.
  const DominatorTree &DT;
  /// Kept alongside the other analyses; not read by the methods of this
  /// struct.
  const PostDominatorTree &PDT;
  /// Loop info used to detect loop headers and loop exits.
  const LoopInfo &LI;
  /// Block whose successors seed the propagation.
  const BasicBlock &DivTermBlock;

  using BlockLabelVec = std::vector<const BasicBlock *>;
  /// BlockLabels[I] is the label currently attached to the block at LoopPOT
  /// index I, or nullptr if no label has reached it.
  BlockLabelVec BlockLabels;
  /// Result under construction; handed out by computeJoinPoints().
  std::unique_ptr<ControlDivergenceDesc> DivDesc;

  /// Sets up an unlabeled propagator (one nullptr label per index of
  /// \p LoopPOT) with an empty result descriptor.
  DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
                       const PostDominatorTree &PDT, const LoopInfo &LI,
                       const BasicBlock &DivTermBlock)
      : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
        BlockLabels(LoopPOT.size(), nullptr),
        DivDesc(new ControlDivergenceDesc) {}

  /// Prints the label of every block to \p Out, from the highest index down
  /// to index 0.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    // The bound is inclusive so that the block at index 0 is printed as well.
    for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx >= 0;
         --BlockIdx) {
      const auto *Label = BlockLabels[BlockIdx];
      Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
          << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Label->getName() << "\n";
      }
    }
    Out << "}\n";
  }

  /// Pushes \p PushedLabel to \p SuccBlock.
  ///
  /// \returns false if \p SuccBlock had no label or already had
  /// \p PushedLabel (its label is then \p PushedLabel); true if it carried a
  /// different label, in which case \p SuccBlock becomes its own label.
  bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
    auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);

    // Unset or identical label: nothing joins here.
    const auto *OldLabel = BlockLabels[SuccIdx];
    if (!OldLabel || (OldLabel == &PushedLabel)) {
      BlockLabels[SuccIdx] = &PushedLabel;
      return false;
    }

    // Two different labels meet at SuccBlock.
    BlockLabels[SuccIdx] = &SuccBlock;
    return true;
  }

  /// Pushes \p DefBlock to the loop exit \p ExitBlock.
  ///
  /// If \p FromParentLoop is false the edge is handled exactly like an
  /// ordinary edge (see visitEdge). Otherwise a join at \p ExitBlock is
  /// recorded in LoopDivBlocks.
  ///
  /// \returns true iff the push caused a join.
  bool visitLoopExitEdge(const BasicBlock &ExitBlock,
                         const BasicBlock &DefBlock, bool FromParentLoop) {
    if (!FromParentLoop)
      return visitEdge(ExitBlock, DefBlock);

    if (!computeJoin(ExitBlock, DefBlock))
      return false;

    DivDesc->LoopDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
                      << "\n");
    return true;
  }

  /// Pushes \p DefBlock to \p SuccBlock and records \p SuccBlock in
  /// JoinDivBlocks if that caused a join.
  ///
  /// \returns true iff the push caused a join.
  bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
    if (!computeJoin(SuccBlock, DefBlock))
      return false;

    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    // Terminate the line like the other debug messages of this struct do.
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName() << "\n");
    return true;
  }

  /// Runs the propagation and returns the resulting descriptor.
  ///
  /// Ownership of the descriptor is transferred to the caller, so DivDesc is
  /// null afterwards; the leading assert therefore rejects a second call.
  std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
                      << "\n");

    const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);

    // Lowest index the main loop below still has to visit.
    int FloorIdx = LoopPOT.size() - 1;
    const BasicBlock *FloorLabel = nullptr;

    // Seed: every successor of the terminator block is labeled with itself.
    int BlockIdx = 0;
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
      BlockLabels[SuccIdx] = SuccBlock;

      // Start at the highest successor index, stop at the lowest one.
      BlockIdx = std::max<int>(BlockIdx, SuccIdx);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);

      // A successor that is not inside the terminator's loop is recorded as
      // a divergent loop exit right away.
      if (!DivBlockLoop)
        continue;

      const auto *BlockLoop = LI.getLoopFor(SuccBlock);
      if (BlockLoop && DivBlockLoop->contains(BlockLoop))
        continue;
      DivDesc->LoopDivBlocks.insert(SuccBlock);
      LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
                        << SuccBlock->getName() << "\n");
    }

    // Visit labeled blocks by decreasing index; FloorIdx may be lowered while
    // the loop runs.
    for (; BlockIdx >= FloorIdx; --BlockIdx) {
      LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));

      // Blocks without a label have nothing to propagate.
      const auto *Label = BlockLabels[BlockIdx];
      if (!Label)
        continue;

      const auto *Block = LoopPOT.getBlockAt(BlockIdx);
      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");

      auto *BlockLoop = LI.getLoopFor(Block);
      bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;
      if (IsLoopHeader) {
        // A loop header pushes its label directly to the exit blocks of its
        // loop instead of to its immediate successors.
        SmallVector<BasicBlock *, 4> BlockLoopExits;
        BlockLoop->getExitBlocks(BlockLoopExits);

        bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
        for (const auto *BlockLoopExit : BlockLoopExits) {
          CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
                                          LoopPOT.getIndexOf(*BlockLoopExit));
        }
      } else {
        // Any other block pushes its label to its immediate successors.
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
        }
      }

      // Lower the floor if a join happened, or if the label pushed here
      // differs from the one remembered in FloorLabel.
      if (CausedJoin) {
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));

    // DivDesc is a member, so the explicit move is required to hand it out.
    return std::move(DivDesc);
  }
};
#ifndef NDEBUG
/// Writes the names of all blocks in \p Blocks to \p Out, enclosed in square
/// brackets with a ListSeparator between entries. Only compiled when NDEBUG
/// is not defined.
static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
  ListSeparator Sep;
  Out << "[";
  for (const auto *Block : Blocks) {
    Out << Sep;
    Out << Block->getName();
  }
  Out << "]";
}
#endif
/// Returns the divergence descriptor for the terminator \p Term.
///
/// Terminators with at most one successor share EmptyDivergenceDesc. For all
/// others the descriptor is computed by a DivergencePropagator on first
/// request and stored in CachedControlDivDescs, from which later requests
/// for the same terminator are served.
const ControlDivergenceDesc &
SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
  // Fewer than two successors: the shared empty descriptor is returned.
  if (Term.getNumSuccessors() < 2)
    return EmptyDivergenceDesc;

  // Serve repeated queries from the cache.
  const auto Cached = CachedControlDivDescs.find(&Term);
  if (Cached != CachedControlDivDescs.end())
    return *Cached->second;

  // First query for this terminator: run the propagation from its block.
  const BasicBlock &TermBlock = *Term.getParent();
  DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
  std::unique_ptr<ControlDivergenceDesc> Desc = Propagator.computeJoinPoints();

  LLVM_DEBUG(dbgs() << "Result (" << TermBlock.getName() << "):\n";
             dbgs() << "JoinDivBlocks: ";
             printBlockSet(Desc->JoinDivBlocks, dbgs());
             dbgs() << "\nLoopDivBlocks: ";
             printBlockSet(Desc->LoopDivBlocks, dbgs()); dbgs() << "\n";);

  // The cache lookup above missed, so the insertion is expected to succeed.
  auto Inserted = CachedControlDivDescs.emplace(&Term, std::move(Desc));
  assert(Inserted.second);
  return *Inserted.first->second;
}
}