#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
#define DEBUG_TYPE "divergence"
/// Construct the analysis state for function \p F.
///
/// The constructor only stores its arguments; no analysis is performed here.
/// \p RegionLoop may be null, in which case the region is the whole function
/// (see inRegion). \p IsLCSSAForm is consulted when handling loop exits.
DivergenceAnalysisImpl::DivergenceAnalysisImpl(const Function &F,
                                               const Loop *RegionLoop,
                                               const DominatorTree &DT,
                                               const LoopInfo &LI,
                                               SyncDependenceAnalysis &SDA,
                                               bool IsLCSSAForm)
    : F(F),
      RegionLoop(RegionLoop),
      DT(DT),
      LI(LI),
      SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {
}
/// Try to record \p DivVal in the set of divergent values.
///
/// \returns true only when \p DivVal was newly added; false when it is
/// registered as a uniform override or was already in the set.
bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) {
  // Uniform overrides win: such values are never entered into the set.
  const bool PinnedUniform = isAlwaysUniform(DivVal);
  if (PinnedUniform)
    return false;

  // Only instructions and function arguments are expected here.
  assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
  assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");

  const auto InsertResult = DivergentValues.insert(&DivVal);
  return InsertResult.second;
}
void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) {
UniformOverrides.insert(&UniVal);
}
/// Check whether \p Val, when observed from \p ObservingBlock, has left a
/// loop that was recorded as divergent.
///
/// Starting at the innermost loop around the block that defines \p Val, the
/// loop nest is walked outwards. The walk stops at RegionLoop, at the first
/// loop that still contains \p ObservingBlock, or when there is no further
/// enclosing loop. If any loop visited before stopping is in DivergentLoops
/// the result is true.
///
/// \returns false for non-instruction values and when no such loop is found.
bool DivergenceAnalysisImpl::isTemporalDivergent(
    const BasicBlock &ObservingBlock, const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;

  // The explicit null check matters when RegionLoop is non-null: the defining
  // block may have no enclosing loop, or the walk may run past the outermost
  // loop without ever meeting RegionLoop. Without it Loop->contains() would be
  // called on a null pointer. With RegionLoop == nullptr the check is
  // redundant with `Loop != RegionLoop`, so behaviour is unchanged there.
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop && Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.contains(Loop))
      return true;
  }

  return false;
}
bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const {
return I.getParent() && inRegion(*I.getParent());
}
/// A block is in the region when RegionLoop contains it; if no RegionLoop is
/// set, when the block belongs to the analysed function.
bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const {
  if (RegionLoop)
    return RegionLoop->contains(&BB);
  return BB.getParent() == &F;
}
void DivergenceAnalysisImpl::pushUsers(const Value &V) {
const auto *I = dyn_cast<const Instruction>(&V);
if (I && I->isTerminator()) {
analyzeControlDivergence(*I);
return;
}
for (const auto *User : V.users()) {
const auto *UserInst = dyn_cast<const Instruction>(User);
if (!UserInst)
continue;
if (!inRegion(*UserInst))
continue;
if (markDivergent(*UserInst))
Worklist.push_back(UserInst);
}
}
/// \returns the instruction referenced by use \p U if it is an instruction
/// contained in \p DivLoop; nullptr otherwise.
static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  const auto *Def = dyn_cast<const Instruction>(&U);
  return (Def && DivLoop.contains(Def)) ? Def : nullptr;
}
void DivergenceAnalysisImpl::analyzeTemporalDivergence(
const Instruction &I, const Loop &OuterDivLoop) {
if (isAlwaysUniform(I))
return;
if (isDivergent(I))
return;
LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
assert((isa<PHINode>(I) || !IsLCSSAForm) &&
"In LCSSA form all users of loop-exiting defs are Phi nodes.");
for (const Use &Op : I.operands()) {
const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
if (!OpInst)
continue;
if (markDivergent(I))
pushUsers(I);
return;
}
}
void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
const BasicBlock &DivExit, const Loop &OuterDivLoop) {
if (IsLCSSAForm) {
for (const auto &Phi : DivExit.phis()) {
analyzeTemporalDivergence(Phi, OuterDivLoop);
}
return;
}
const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
SmallVector<const BasicBlock *, 8> TaintStack;
TaintStack.push_back(&DivExit);
DenseSet<const BasicBlock *> Visited;
Visited.insert(&DivExit);
do {
auto *UserBlock = TaintStack.pop_back_val();
if (!inRegion(*UserBlock))
continue;
assert(!OuterDivLoop.contains(UserBlock) &&
"irreducible control flow detected");
if (!DT.dominates(&LoopHeader, UserBlock)) {
for (const auto &Phi : UserBlock->phis()) {
analyzeTemporalDivergence(Phi, OuterDivLoop);
}
continue;
}
for (const auto &I : *UserBlock) {
analyzeTemporalDivergence(I, OuterDivLoop);
}
for (const auto *SuccBlock : successors(UserBlock)) {
if (!Visited.insert(SuccBlock).second) {
continue;
}
TaintStack.push_back(SuccBlock);
}
} while (!TaintStack.empty());
}
void DivergenceAnalysisImpl::propagateLoopExitDivergence(
const BasicBlock &DivExit, const Loop &InnerDivLoop) {
LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");
const Loop *DivLoop = &InnerDivLoop;
const Loop *OuterDivLoop = DivLoop;
const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
const unsigned LoopExitDepth =
ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
DivergentLoops.insert(DivLoop); OuterDivLoop = DivLoop;
DivLoop = DivLoop->getParentLoop();
}
LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
<< "\n");
analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
}
/// Mark the phi nodes of \p JoinBlock divergent and queue the newly marked
/// ones. Does nothing when \p JoinBlock is outside the region. Phi nodes that
/// are already divergent, or for which hasConstantOrUndefValue() holds, are
/// left alone.
void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
                    << "\n");

  if (!inRegion(JoinBlock))
    return;

  for (const auto &Phi : JoinBlock.phis()) {
    const bool Skip = isDivergent(Phi) || Phi.hasConstantOrUndefValue();
    if (!Skip && markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}
void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
<< "\n");
if (!DT.isReachableFromEntry(Term.getParent()))
return;
const auto *BranchLoop = LI.getLoopFor(Term.getParent());
const auto &DivDesc = SDA.getJoinBlocks(Term);
for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
taintAndPushPhiNodes(*JoinBlock);
}
assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
}
}
void DivergenceAnalysisImpl::compute() {
auto DivValuesCopy = DivergentValues;
for (const auto *DivVal : DivValuesCopy) {
assert(isDivergent(*DivVal) && "Worklist invariant violated!");
pushUsers(*DivVal);
}
while (!Worklist.empty()) {
const Instruction &I = *Worklist.back();
Worklist.pop_back();
assert(isDivergent(I) && "Worklist invariant violated!");
pushUsers(I);
}
}
bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const {
return UniformOverrides.contains(&V);
}
bool DivergenceAnalysisImpl::isDivergent(const Value &V) const {
return DivergentValues.contains(&V);
}
bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const {
Value &V = *U.get();
Instruction &I = *cast<Instruction>(U.getUser());
return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}
/// Build divergence information for \p F.
///
/// Unless \p KnownReducible is set, the CFG is first checked for irreducible
/// control flow; if any is found, ContainsIrreducible is set and no analysis
/// is created. Otherwise the sync-dependence and divergence analyses are
/// created for the whole function, seeded from \p TTI (divergence sources and
/// always-uniform instructions, divergent arguments), and computed.
DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT,
                               const PostDominatorTree &PDT, const LoopInfo &LI,
                               const TargetTransformInfo &TTI,
                               bool KnownReducible)
    : F(F) {
  if (!KnownReducible) {
    using RPOTraversal = ReversePostOrderTraversal<const Function *>;
    RPOTraversal FuncRPOT(&F);
    const bool Irreducible =
        containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
                               const LoopInfo>(FuncRPOT, LI);
    if (Irreducible) {
      ContainsIrreducible = true;
      return;
    }
  }

  SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
  // No region loop (whole function) and IsLCSSAForm = false.
  DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA,
                                                false);

  // Seed from instructions: a divergence source takes precedence over the
  // always-uniform query, which is only asked for non-sources.
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA->markDivergent(I);
      continue;
    }
    if (TTI.isAlwaysUniform(&I))
      DA->addUniformOverride(I);
  }

  // Seed from divergent function arguments.
  for (auto &Arg : F.args())
    if (TTI.isSourceOfDivergence(&Arg))
      DA->markDivergent(Arg);

  DA->compute();
}
// Out-of-line definition of the static AnalysisKey member of
// DivergenceAnalysis.
// NOTE(review): presumably the key the analysis manager uses to identify this
// analysis -- confirm against the class declaration.
AnalysisKey DivergenceAnalysis::Key;
/// Fetch the dominator tree, post-dominator tree, loop info and target
/// transform info for \p F from \p AM and build a DivergenceInfo from them.
/// Reducibility is not assumed, so the constructor checks the CFG itself.
DivergenceAnalysis::Result
DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  const DominatorTree &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
  const PostDominatorTree &PostDomTree =
      AM.getResult<PostDominatorTreeAnalysis>(F);
  const LoopInfo &Loops = AM.getResult<LoopAnalysis>(F);
  const TargetTransformInfo &TargetInfo = AM.getResult<TargetIRAnalysis>(F);

  return DivergenceInfo(F, DomTree, PostDomTree, Loops, TargetInfo,
                        /*KnownReducible=*/false);
}
/// Print the divergence analysis result for \p F to OS.
///
/// A header line is always printed. If the result reports divergence, each
/// argument and each non-debug instruction is printed, prefixed with
/// "DIVERGENT: " when divergent. All analyses are preserved.
PreservedAnalyses
DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
  auto &DI = FAM.getResult<DivergenceAnalysis>(F);
  OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";

  if (!DI.hasDivergence())
    return PreservedAnalyses::all();

  // Prefix printed in front of every argument / instruction line.
  auto Prefix = [&DI](const Value &V) -> const char * {
    return DI.isDivergent(V) ? "DIVERGENT: " : " ";
  };

  for (auto &Arg : F.args()) {
    OS << Prefix(Arg);
    OS << Arg << "\n";
  }

  for (const BasicBlock &BB : F) {
    OS << "\n " << BB.getName() << ":\n";
    for (const auto &I : BB.instructionsWithoutDebug()) {
      OS << Prefix(I);
      OS << I << "\n";
    }
  }

  return PreservedAnalyses::all();
}