#include "llvm/Transforms/Scalar/DFAJumpThreading.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <algorithm>
#include <deque>
#ifdef EXPENSIVE_CHECKS
#include "llvm/IR/Verifier.h"
#endif
using namespace llvm;
#define DEBUG_TYPE "dfa-jump-threading"
STATISTIC(NumTransforms, "Number of transformations done");
STATISTIC(NumCloned, "Number of blocks cloned");
STATISTIC(NumPaths, "Number of individual paths threaded");
// Hidden debugging flag: when set, run() opens a view of the function's CFG
// before doing any other work.
static cl::opt<bool>
ClViewCfgBefore("dfa-jump-view-cfg-before",
cl::desc("View the CFG before DFA Jump Threading"),
cl::Hidden, cl::init(false));
// Upper bound on the recursion depth (number of blocks) explored by
// AllSwitchPaths::paths() while looking for a path back to the switch block.
static cl::opt<unsigned> MaxPathLength(
"dfa-max-path-length",
cl::desc("Max number of blocks searched to find a threading path"),
cl::Hidden, cl::init(20));
// Once a paths() invocation has accumulated this many paths while extending
// successor paths, it stops enumerating and returns what it has.
static cl::opt<unsigned> MaxNumPaths(
"dfa-max-num-paths",
cl::desc("Max number of paths enumerated around a switch"),
cl::Hidden, cl::init(200));
// The transformation is rejected when the estimated duplication cost computed
// in TransformDFA::isLegalAndProfitableToTransform() exceeds this value.
static cl::opt<unsigned>
CostThreshold("dfa-cost-threshold",
cl::desc("Maximum cost accepted for the transformation"),
cl::Hidden, cl::init(50));
namespace {
/// One unit of work for select unfolding: a select instruction together with
/// the phi node that uses it.
class SelectInstToUnfold {
  SelectInst *SI;
  PHINode *SIUse;

public:
  SelectInstToUnfold(SelectInst *Select, PHINode *User)
      : SI(Select), SIUse(User) {}

  /// The select instruction to unfold.
  SelectInst *getInst() { return SI; }

  /// The phi node that consumes the select.
  PHINode *getUse() { return SIUse; }

  /// True only when both the select and its phi user are non-null.
  explicit operator bool() const {
    return SI != nullptr && SIUse != nullptr;
  }
};
// Forward declaration; the definition appears later in this file.
// DFAJumpThreading::unfoldSelectInstrs() calls it for every pending select and
// re-queues whatever it appends to NewSIsToUnfold.
void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
std::vector<SelectInstToUnfold> *NewSIsToUnfold,
std::vector<BasicBlock *> *NewBBs);
class DFAJumpThreading {
public:
DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT,
TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE)
: AC(AC), DT(DT), TTI(TTI), ORE(ORE) {}
bool run(Function &F);
private:
void
unfoldSelectInstrs(DominatorTree *DT,
const SmallVector<SelectInstToUnfold, 4> &SelectInsts) {
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
SmallVector<SelectInstToUnfold, 4> Stack;
for (SelectInstToUnfold SIToUnfold : SelectInsts)
Stack.push_back(SIToUnfold);
while (!Stack.empty()) {
SelectInstToUnfold SIToUnfold = Stack.pop_back_val();
std::vector<SelectInstToUnfold> NewSIsToUnfold;
std::vector<BasicBlock *> NewBBs;
unfold(&DTU, SIToUnfold, &NewSIsToUnfold, &NewBBs);
for (const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold)
Stack.push_back(NewSIToUnfold);
}
}
AssumptionCache *AC;
DominatorTree *DT;
TargetTransformInfo *TTI;
OptimizationRemarkEmitter *ORE;
};
class DFAJumpThreadingLegacyPass : public FunctionPass {
public:
static char ID; DFAJumpThreadingLegacyPass() : FunctionPass(ID) {}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
}
bool runOnFunction(Function &F) override {
if (skipFunction(F))
return false;
AssumptionCache *AC =
&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
TargetTransformInfo *TTI =
&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
OptimizationRemarkEmitter *ORE =
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
return DFAJumpThreading(AC, DT, TTI, ORE).run(F);
}
};
}
// Definition of the static pass identifier declared in
// DFAJumpThreadingLegacyPass.
char DFAJumpThreadingLegacyPass::ID = 0;
// Legacy pass registration under the name "dfa-jump-threading". The listed
// dependencies are the same analyses that getAnalysisUsage() marks as required.
INITIALIZE_PASS_BEGIN(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
"DFA Jump Threading", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(DFAJumpThreadingLegacyPass, "dfa-jump-threading",
"DFA Jump Threading", false, false)
// Factory for the legacy pass. The returned object is heap-allocated with
// `new`; this function does not release it.
FunctionPass *llvm::createDFAJumpThreadingPass() {
return new DFAJumpThreadingLegacyPass();
}
namespace {
/// Creates a block named \p NewBBName in front of \p EndBlock, terminates it
/// with an unconditional branch to \p EndBlock and moves \p SIToSink into it.
///
/// The new block and its branch are returned through \p NewBlock and
/// \p NewBranch. The block is also appended to \p NewBBs, and the sunk select
/// (paired with \p SIUse) is appended to \p NewSIsToUnfold so the caller can
/// unfold it later. \p SI is only used to obtain the LLVMContext.
void createBasicBlockAndSinkSelectInst(
    DomTreeUpdater *DTU, SelectInst *SI, PHINode *SIUse, SelectInst *SIToSink,
    BasicBlock *EndBlock, StringRef NewBBName, BasicBlock **NewBlock,
    BranchInst **NewBranch, std::vector<SelectInstToUnfold> *NewSIsToUnfold,
    std::vector<BasicBlock *> *NewBBs) {
  assert(SIToSink->hasOneUse());
  assert(NewBlock);
  assert(NewBranch);

  BasicBlock *SinkBB = BasicBlock::Create(SI->getContext(), NewBBName,
                                          EndBlock->getParent(), EndBlock);
  *NewBlock = SinkBB;
  NewBBs->push_back(SinkBB);

  BranchInst *JumpToEnd = BranchInst::Create(EndBlock, SinkBB);
  *NewBranch = JumpToEnd;

  // The sunk select ends up directly in front of the new block's terminator.
  SIToSink->moveBefore(JumpToEnd);
  NewSIsToUnfold->push_back(SelectInstToUnfold(SIToSink, SIUse));

  DTU->applyUpdates({{DominatorTree::Insert, SinkBB, EndBlock}});
}
/// Turns the select in \p SIToUnfold into explicit control flow.
///
/// Preconditions (asserted): the select's block (StartBlock) ends in an
/// unconditional branch and the select has exactly one use. That use is the
/// phi recorded in \p SIToUnfold, which lives in EndBlock.
///
/// The unconditional branch of StartBlock is replaced by a conditional branch
/// on the select's condition. Select operands that are themselves selects are
/// sunk into fresh blocks and queued in \p NewSIsToUnfold; if neither operand
/// is a select, a single empty "si.unfold.false" block is created. All blocks
/// created here are appended to \p NewBBs, the phis of EndBlock are rewired to
/// the new predecessors, \p DTU is updated and the select is erased.
void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold,
            std::vector<SelectInstToUnfold> *NewSIsToUnfold,
            std::vector<BasicBlock *> *NewBBs) {
  SelectInst *SI = SIToUnfold.getInst();
  PHINode *SIUse = SIToUnfold.getUse();
  BasicBlock *StartBlock = SI->getParent();
  BasicBlock *EndBlock = SIUse->getParent();
  BranchInst *StartBlockTerm =
      dyn_cast<BranchInst>(StartBlock->getTerminator());

  assert(StartBlockTerm && StartBlockTerm->isUnconditional());
  assert(SI->hasOneUse());

  BasicBlock *TrueBlock = nullptr;
  BasicBlock *FalseBlock = nullptr;
  BranchInst *TrueBranch = nullptr;
  BranchInst *FalseBranch = nullptr;

  // A select operand that is itself a select is sunk into its own new block so
  // that it can be unfolded in a later step.
  if (SelectInst *SIOp = dyn_cast<SelectInst>(SI->getTrueValue())) {
    createBasicBlockAndSinkSelectInst(DTU, SI, SIUse, SIOp, EndBlock,
                                      "si.unfold.true", &TrueBlock, &TrueBranch,
                                      NewSIsToUnfold, NewBBs);
  }
  if (SelectInst *SIOp = dyn_cast<SelectInst>(SI->getFalseValue())) {
    createBasicBlockAndSinkSelectInst(DTU, SI, SIUse, SIOp, EndBlock,
                                      "si.unfold.false", &FalseBlock,
                                      &FalseBranch, NewSIsToUnfold, NewBBs);
  }

  // Neither operand is a select: create one empty block for the false edge so
  // the two outcomes reach EndBlock through distinct predecessors.
  if (!TrueBlock && !FalseBlock) {
    FalseBlock = BasicBlock::Create(SI->getContext(), "si.unfold.false",
                                    EndBlock->getParent(), EndBlock);
    NewBBs->push_back(FalseBlock);
    BranchInst::Create(EndBlock, FalseBlock);
    DTU->applyUpdates({{DominatorTree::Insert, FalseBlock, EndBlock}});
  }

  // Targets of the conditional branch that will terminate StartBlock. A side
  // without its own block branches straight to EndBlock.
  BasicBlock *TT = EndBlock;
  BasicBlock *FT = EndBlock;

  if (TrueBlock && FalseBlock) {
    // Diamond: StartBlock no longer branches to EndBlock directly.
    TT = TrueBlock;
    FT = FalseBlock;

    SIUse->removeIncomingValue(StartBlock, /*DeletePHIIfEmpty=*/false);
    SIUse->addIncoming(SI->getTrueValue(), TrueBlock);
    SIUse->addIncoming(SI->getFalseValue(), FalseBlock);

    // Every other phi in EndBlock receives the value it used to get from
    // StartBlock through both new predecessors. Its StartBlock entry must be
    // dropped as well, because StartBlock is not a predecessor of EndBlock
    // anymore; leaving it behind would produce a malformed phi.
    for (PHINode &Phi : EndBlock->phis()) {
      if (&Phi == SIUse)
        continue;
      Value *OrigValue = Phi.getIncomingValueForBlock(StartBlock);
      Phi.addIncoming(OrigValue, TrueBlock);
      Phi.addIncoming(OrigValue, FalseBlock);
      Phi.removeIncomingValue(StartBlock, /*DeletePHIIfEmpty=*/false);
    }
  } else {
    // Triangle: one side still flows StartBlock -> EndBlock directly, the
    // other goes through NewBlock.
    BasicBlock *NewBlock = nullptr;
    Value *SIOp1 = SI->getTrueValue();  // Value arriving from StartBlock.
    Value *SIOp2 = SI->getFalseValue(); // Value arriving from NewBlock.

    if (!TrueBlock) {
      NewBlock = FalseBlock;
      FT = FalseBlock;
    } else {
      NewBlock = TrueBlock;
      TT = TrueBlock;
      std::swap(SIOp1, SIOp2);
    }

    // The phi keeps its StartBlock entry, but now with the operand selected on
    // the direct edge instead of the select itself.
    for (unsigned Idx = 0; Idx < SIUse->getNumIncomingValues(); ++Idx) {
      if (SIUse->getIncomingBlock(Idx) == StartBlock)
        SIUse->setIncomingValue(Idx, SIOp1);
    }
    SIUse->addIncoming(SIOp2, NewBlock);

    // StartBlock stays a predecessor here, so the other phis only gain an
    // entry for NewBlock carrying the same value.
    for (PHINode &Phi : EndBlock->phis()) {
      if (&Phi != SIUse)
        Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), NewBlock);
    }
  }

  StartBlockTerm->eraseFromParent();
  BranchInst::Create(TT, FT, SI->getCondition(), StartBlock);
  DTU->applyUpdates({{DominatorTree::Insert, StartBlock, TT},
                     {DominatorTree::Insert, StartBlock, FT}});

  // All uses of the select were rewritten above.
  SI->eraseFromParent();
}
struct ClonedBlock {
BasicBlock *BB;
uint64_t State; };
typedef std::deque<BasicBlock *> PathType;
typedef std::vector<PathType> PathsType;
typedef SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
typedef std::vector<ClonedBlock> CloneList;
typedef DenseMap<BasicBlock *, CloneList> DuplicateBlockMap;
typedef MapVector<Instruction *, std::vector<Instruction *>> DefMap;
/// Prints \p Path as "< bb1 bb2 ... >". Blocks without a name are printed as
/// their pointer value.
inline raw_ostream &operator<<(raw_ostream &OS, const PathType &Path) {
  OS << "< ";
  for (const BasicBlock *Block : Path) {
    if (Block->hasName())
      OS << Block->getName();
    else
      OS << Block;
    OS << " ";
  }
  OS << ">";
  return OS;
}
/// A path of blocks together with the constant state value selected along it
/// (the exit value) and the block in which that value was determined (the
/// determinator).
struct ThreadingPath {
  /// The exit value; only meaningful once isExitValueSet() returns true.
  uint64_t getExitValue() const { return ExitVal; }

  /// Records the zero-extended value of \p V as the exit value.
  void setExitValue(const ConstantInt *V) {
    ExitVal = V->getZExtValue();
    IsExitValSet = true;
  }
  bool isExitValueSet() const { return IsExitValSet; }

  /// The block whose phi supplied the exit value; null until set.
  const BasicBlock *getDeterminatorBB() const { return DBB; }
  void setDeterminator(const BasicBlock *BB) { DBB = BB; }

  const PathType &getPath() const { return Path; }
  void setPath(const PathType &NewPath) { Path = NewPath; }

  /// Prints "<path> [ exit-value, determinator-name ]". The determinator must
  /// have been set, since it is dereferenced unconditionally.
  void print(raw_ostream &OS) const {
    OS << Path << " [ " << ExitVal << ", " << DBB->getName() << " ]";
  }

private:
  PathType Path;
  // Zero-initialized so that copying or printing a path whose exit value was
  // never set does not read an indeterminate value.
  uint64_t ExitVal = 0;
  const BasicBlock *DBB = nullptr;
  bool IsExitValSet = false;
};

#ifndef NDEBUG
/// Stream operator forwarding to ThreadingPath::print(); only available in
/// builds without NDEBUG.
inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) {
  TPath.print(OS);
  return OS;
}
#endif
/// Decides whether a switch is a candidate for the transformation and, if so,
/// collects the select instructions that must be unfolded first.
///
/// getInstr() returns the switch when it was accepted and nullptr otherwise; a
/// missed-optimization remark is emitted for rejected switches.
struct MainSwitch {
  MainSwitch(SwitchInst *SI, OptimizationRemarkEmitter *ORE) {
    if (isCandidate(SI)) {
      Instr = SI;
    } else {
      ORE->emit([&]() {
        return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable", SI)
               << "Switch instruction is not predictable.";
      });
    }
  }

  virtual ~MainSwitch() = default;

  SwitchInst *getInstr() const { return Instr; }

  /// The selects found while walking the switch condition, each paired with
  /// the phi that uses it.
  const SmallVector<SelectInstToUnfold, 4> getSelectInsts() {
    return SelectInsts;
  }

private:
  /// Walks the values feeding the condition of \p SI breadth-first through phi
  /// nodes and select instructions.
  ///
  /// Returns false when the condition is not a phi or when a select is
  /// encountered that isValidSelectInst() rejects. Selects whose user is a phi
  /// are recorded in SelectInsts. Constants and all other values end the walk
  /// along that edge without rejecting the switch.
  bool isCandidate(const SwitchInst *SI) {
    std::deque<Value *> Q;
    SmallSet<Value *, 16> SeenValues;
    SelectInsts.clear();

    Value *SICond = SI->getCondition();
    LLVM_DEBUG(dbgs() << "\tSICond: " << *SICond << "\n");
    if (!isa<PHINode>(SICond))
      return false;

    addToQueue(SICond, Q, SeenValues);

    while (!Q.empty()) {
      Value *Current = Q.front();
      Q.pop_front();

      if (auto *Phi = dyn_cast<PHINode>(Current)) {
        for (Value *Incoming : Phi->incoming_values()) {
          addToQueue(Incoming, Q, SeenValues);
        }
        LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n");
      } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
        if (!isValidSelectInst(SelI))
          return false;
        addToQueue(SelI->getTrueValue(), Q, SeenValues);
        addToQueue(SelI->getFalseValue(), Q, SeenValues);
        LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n");
        // Only selects consumed by a phi are unfolded directly; a select used
        // by another select is handled when its user is unfolded.
        if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
          SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
      } else if (isa<Constant>(Current)) {
        LLVM_DEBUG(dbgs() << "\tconst: " << *Current << "\n");
        continue;
      } else {
        LLVM_DEBUG(dbgs() << "\tother: " << *Current << "\n");
        continue;
      }
    }

    return true;
  }

  /// Enqueues \p Val unless it has been enqueued before.
  void addToQueue(Value *Val, std::deque<Value *> &Q,
                  SmallSet<Value *, 16> &SeenValues) {
    if (SeenValues.contains(Val))
      return;
    Q.push_back(Val);
    SeenValues.insert(Val);
  }

  /// Returns true when \p SI is in a shape that unfold() can handle.
  bool isValidSelectInst(SelectInst *SI) {
    if (!SI->hasOneUse())
      return false;

    // The single user must be a phi or another select. The null test has to be
    // combined with `||`: with `&&` the isa<> calls would run on a null
    // pointer, and a user of any other kind would never be rejected.
    Instruction *SIUse = dyn_cast<Instruction>(SI->user_back());
    if (!SIUse || !(isa<PHINode>(SIUse) || isa<SelectInst>(SIUse)))
      return false;

    // The select's block must end in an unconditional branch.
    BasicBlock *SIBB = SI->getParent();
    BranchInst *SITerm = dyn_cast<BranchInst>(SIBB->getTerminator());
    if (!SITerm || !SITerm->isUnconditional())
      return false;

    // A phi user must live in the single successor of the select's block.
    if (isa<PHINode>(SIUse) &&
        SIBB->getSingleSuccessor() != SIUse->getParent())
      return false;

    // Reject a select that shares its block with an already recorded select
    // unless it is an operand of that select.
    for (SelectInstToUnfold SIToUnfold : SelectInsts) {
      SelectInst *PrevSI = SIToUnfold.getInst();
      if (PrevSI->getTrueValue() != SI && PrevSI->getFalseValue() != SI &&
          PrevSI->getParent() == SI->getParent())
        return false;
    }

    return true;
  }

  SwitchInst *Instr = nullptr;
  SmallVector<SelectInstToUnfold, 4> SelectInsts;
};
struct AllSwitchPaths {
AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE)
: Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()),
ORE(ORE) {}
std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
unsigned getNumThreadingPaths() { return TPaths.size(); }
SwitchInst *getSwitchInst() { return Switch; }
BasicBlock *getSwitchBlock() { return SwitchBlock; }
void run() {
VisitedBlocks Visited;
PathsType LoopPaths = paths(SwitchBlock, Visited, 1);
StateDefMap StateDef = getStateDefMap(LoopPaths);
if (StateDef.empty()) {
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable",
Switch)
<< "Switch instruction is not predictable.";
});
return;
}
for (PathType Path : LoopPaths) {
ThreadingPath TPath;
const BasicBlock *PrevBB = Path.back();
for (const BasicBlock *BB : Path) {
if (StateDef.count(BB) != 0) {
const PHINode *Phi = dyn_cast<PHINode>(StateDef[BB]);
assert(Phi && "Expected a state-defining instr to be a phi node.");
const Value *V = Phi->getIncomingValueForBlock(PrevBB);
if (const ConstantInt *C = dyn_cast<const ConstantInt>(V)) {
TPath.setExitValue(C);
TPath.setDeterminator(BB);
TPath.setPath(Path);
}
}
if (TPath.isExitValueSet() && BB == Path.front())
break;
PrevBB = BB;
}
if (TPath.isExitValueSet() && isSupported(TPath))
TPaths.push_back(TPath);
}
}
private:
typedef DenseMap<const BasicBlock *, const PHINode *> StateDefMap;
PathsType paths(BasicBlock *BB, VisitedBlocks &Visited,
unsigned PathDepth) const {
PathsType Res;
if (PathDepth > MaxPathLength) {
ORE->emit([&]() {
return OptimizationRemarkAnalysis(DEBUG_TYPE, "MaxPathLengthReached",
Switch)
<< "Exploration stopped after visiting MaxPathLength="
<< ore::NV("MaxPathLength", MaxPathLength) << " blocks.";
});
return Res;
}
Visited.insert(BB);
SmallSet<BasicBlock *, 4> Successors;
for (BasicBlock *Succ : successors(BB)) {
if (!Successors.insert(Succ).second)
continue;
if (Succ == SwitchBlock) {
Res.push_back({BB});
continue;
}
if (Visited.contains(Succ))
continue;
PathsType SuccPaths = paths(Succ, Visited, PathDepth + 1);
for (PathType Path : SuccPaths) {
PathType NewPath(Path);
NewPath.push_front(BB);
Res.push_back(NewPath);
if (Res.size() >= MaxNumPaths) {
return Res;
}
}
}
Visited.erase(BB);
return Res;
}
StateDefMap getStateDefMap(const PathsType &LoopPaths) const {
StateDefMap Res;
SmallPtrSet<BasicBlock *, 16> LoopBBs;
for (const PathType &Path : LoopPaths) {
for (BasicBlock *BB : Path)
LoopBBs.insert(BB);
}
Value *FirstDef = Switch->getOperand(0);
assert(isa<PHINode>(FirstDef) && "The first definition must be a phi.");
SmallVector<PHINode *, 8> Stack;
Stack.push_back(dyn_cast<PHINode>(FirstDef));
SmallSet<Value *, 16> SeenValues;
while (!Stack.empty()) {
PHINode *CurPhi = Stack.pop_back_val();
Res[CurPhi->getParent()] = CurPhi;
SeenValues.insert(CurPhi);
for (BasicBlock *IncomingBB : CurPhi->blocks()) {
Value *Incoming = CurPhi->getIncomingValueForBlock(IncomingBB);
bool IsOutsideLoops = LoopBBs.count(IncomingBB) == 0;
if (Incoming == FirstDef || isa<ConstantInt>(Incoming) ||
SeenValues.contains(Incoming) || IsOutsideLoops) {
continue;
}
if (!isa<PHINode>(Incoming))
return StateDefMap();
Stack.push_back(cast<PHINode>(Incoming));
}
}
return Res;
}
bool isSupported(const ThreadingPath &TPath) {
Instruction *SwitchCondI = dyn_cast<Instruction>(Switch->getCondition());
assert(SwitchCondI);
if (!SwitchCondI)
return false;
const BasicBlock *SwitchCondDefBB = SwitchCondI->getParent();
const BasicBlock *SwitchCondUseBB = Switch->getParent();
const BasicBlock *DeterminatorBB = TPath.getDeterminatorBB();
assert(
SwitchCondUseBB == TPath.getPath().front() &&
"The first BB in a threading path should have the switch instruction");
if (SwitchCondUseBB != TPath.getPath().front())
return false;
PathType Path = TPath.getPath();
auto ItDet = std::find(Path.begin(), Path.end(), DeterminatorBB);
std::rotate(Path.begin(), ItDet, Path.end());
bool IsDetBBSeen = false;
bool IsDefBBSeen = false;
bool IsUseBBSeen = false;
for (BasicBlock *BB : Path) {
if (BB == DeterminatorBB)
IsDetBBSeen = true;
if (BB == SwitchCondDefBB)
IsDefBBSeen = true;
if (BB == SwitchCondUseBB)
IsUseBBSeen = true;
if (IsDetBBSeen && IsUseBBSeen && !IsDefBBSeen)
return false;
}
return true;
}
SwitchInst *Switch;
BasicBlock *SwitchBlock;
OptimizationRemarkEmitter *ORE;
std::vector<ThreadingPath> TPaths;
};
struct TransformDFA {
TransformDFA(AllSwitchPaths *SwitchPaths, DominatorTree *DT,
AssumptionCache *AC, TargetTransformInfo *TTI,
OptimizationRemarkEmitter *ORE,
SmallPtrSet<const Value *, 32> EphValues)
: SwitchPaths(SwitchPaths), DT(DT), AC(AC), TTI(TTI), ORE(ORE),
EphValues(EphValues) {}
void run() {
if (isLegalAndProfitableToTransform()) {
createAllExitPaths();
NumTransforms++;
}
}
private:
bool isLegalAndProfitableToTransform() {
CodeMetrics Metrics;
SwitchInst *Switch = SwitchPaths->getSwitchInst();
DuplicateBlockMap DuplicateMap;
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
PathType PathBBs = TPath.getPath();
uint64_t NextState = TPath.getExitValue();
const BasicBlock *Determinator = TPath.getDeterminatorBB();
BasicBlock *BB = SwitchPaths->getSwitchBlock();
BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
if (!VisitedBB) {
Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
DuplicateMap[BB].push_back({BB, NextState});
}
if (PathBBs.front() == Determinator)
continue;
auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator);
for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
BB = *BBIt;
VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
if (VisitedBB)
continue;
Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
DuplicateMap[BB].push_back({BB, NextState});
}
if (Metrics.notDuplicatable) {
LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
<< "non-duplicatable instructions.\n");
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "NonDuplicatableInst",
Switch)
<< "Contains non-duplicatable instructions.";
});
return false;
}
if (Metrics.convergent) {
LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
<< "convergent instructions.\n");
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
<< "Contains convergent instructions.";
});
return false;
}
if (!Metrics.NumInsts.isValid()) {
LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
<< "instructions with invalid cost.\n");
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch)
<< "Contains instructions with invalid cost.";
});
return false;
}
}
unsigned DuplicationCost = 0;
unsigned JumpTableSize = 0;
TTI->getEstimatedNumberOfCaseClusters(*Switch, JumpTableSize, nullptr,
nullptr);
if (JumpTableSize == 0) {
unsigned CondBranches =
APInt(32, Switch->getNumSuccessors()).ceilLogBase2();
DuplicationCost = *Metrics.NumInsts.getValue() / CondBranches;
} else {
DuplicationCost = *Metrics.NumInsts.getValue() / JumpTableSize;
}
LLVM_DEBUG(dbgs() << "\nDFA Jump Threading: Cost to jump thread block "
<< SwitchPaths->getSwitchBlock()->getName()
<< " is: " << DuplicationCost << "\n\n");
if (DuplicationCost > CostThreshold) {
LLVM_DEBUG(dbgs() << "Not jump threading, duplication cost exceeds the "
<< "cost threshold.\n");
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch)
<< "Duplication cost exceeds the cost threshold (cost="
<< ore::NV("Cost", DuplicationCost)
<< ", threshold=" << ore::NV("Threshold", CostThreshold) << ").";
});
return false;
}
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "JumpThreaded", Switch)
<< "Switch statement jump-threaded.";
});
return true;
}
void createAllExitPaths() {
DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
LLVM_DEBUG(dbgs() << TPath << "\n");
PathType NewPath(TPath.getPath());
NewPath.push_back(SwitchBlock);
TPath.setPath(NewPath);
}
DuplicateBlockMap DuplicateMap;
DefMap NewDefs;
SmallSet<BasicBlock *, 16> BlocksToClean;
for (BasicBlock *BB : successors(SwitchBlock))
BlocksToClean.insert(BB);
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
NumPaths++;
}
for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
updateLastSuccessor(TPath, DuplicateMap, &DTU);
updateSSA(NewDefs);
for (BasicBlock *BB : BlocksToClean)
cleanPhiNodes(BB);
}
void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
DuplicateBlockMap &DuplicateMap,
SmallSet<BasicBlock *, 16> &BlocksToClean,
DomTreeUpdater *DTU) {
uint64_t NextState = Path.getExitValue();
const BasicBlock *Determinator = Path.getDeterminatorBB();
PathType PathBBs = Path.getPath();
if (PathBBs.front() == Determinator)
PathBBs.pop_front();
auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator);
auto Prev = std::prev(DetIt);
BasicBlock *PrevBB = *Prev;
for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
BasicBlock *BB = *BBIt;
BlocksToClean.insert(BB);
BasicBlock *NextBB = getClonedBB(BB, NextState, DuplicateMap);
if (NextBB) {
updatePredecessor(PrevBB, BB, NextBB, DTU);
PrevBB = NextBB;
continue;
}
BasicBlock *NewBB = cloneBlockAndUpdatePredecessor(
BB, PrevBB, NextState, DuplicateMap, NewDefs, DTU);
DuplicateMap[BB].push_back({NewBB, NextState});
BlocksToClean.insert(NewBB);
PrevBB = NewBB;
}
}
void updateSSA(DefMap &NewDefs) {
SSAUpdaterBulk SSAUpdate;
SmallVector<Use *, 16> UsesToRename;
for (auto KV : NewDefs) {
Instruction *I = KV.first;
BasicBlock *BB = I->getParent();
std::vector<Instruction *> Cloned = KV.second;
for (Use &U : I->uses()) {
Instruction *User = cast<Instruction>(U.getUser());
if (PHINode *UserPN = dyn_cast<PHINode>(User)) {
if (UserPN->getIncomingBlock(U) == BB)
continue;
} else if (User->getParent() == BB) {
continue;
}
UsesToRename.push_back(&U);
}
if (UsesToRename.empty())
continue;
LLVM_DEBUG(dbgs() << "DFA-JT: Renaming non-local uses of: " << *I
<< "\n");
unsigned VarNum = SSAUpdate.AddVariable(I->getName(), I->getType());
SSAUpdate.AddAvailableValue(VarNum, BB, I);
for (Instruction *New : Cloned)
SSAUpdate.AddAvailableValue(VarNum, New->getParent(), New);
while (!UsesToRename.empty())
SSAUpdate.AddUse(VarNum, UsesToRename.pop_back_val());
LLVM_DEBUG(dbgs() << "\n");
}
SSAUpdate.RewriteAllUses(DT);
}
BasicBlock *cloneBlockAndUpdatePredecessor(BasicBlock *BB, BasicBlock *PrevBB,
uint64_t NextState,
DuplicateBlockMap &DuplicateMap,
DefMap &NewDefs,
DomTreeUpdater *DTU) {
ValueToValueMapTy VMap;
BasicBlock *NewBB = CloneBasicBlock(
BB, VMap, ".jt" + std::to_string(NextState), BB->getParent());
NewBB->moveAfter(BB);
NumCloned++;
for (Instruction &I : *NewBB) {
if (isa<PHINode>(&I))
continue;
RemapInstruction(&I, VMap,
RF_IgnoreMissingLocals | RF_NoModuleLevelChanges);
if (AssumeInst *II = dyn_cast<AssumeInst>(&I))
AC->registerAssumption(II);
}
updateSuccessorPhis(BB, NewBB, NextState, VMap, DuplicateMap);
updatePredecessor(PrevBB, BB, NewBB, DTU);
updateDefMap(NewDefs, VMap);
SmallPtrSet<BasicBlock *, 4> SuccSet;
for (auto *SuccBB : successors(NewBB)) {
if (SuccSet.insert(SuccBB).second)
DTU->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}});
}
SuccSet.clear();
return NewBB;
}
void updateSuccessorPhis(BasicBlock *BB, BasicBlock *ClonedBB,
uint64_t NextState, ValueToValueMapTy &VMap,
DuplicateBlockMap &DuplicateMap) {
std::vector<BasicBlock *> BlocksToUpdate;
if (BB == SwitchPaths->getSwitchBlock()) {
SwitchInst *Switch = SwitchPaths->getSwitchInst();
BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
BlocksToUpdate.push_back(NextCase);
BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
if (ClonedSucc)
BlocksToUpdate.push_back(ClonedSucc);
}
else {
for (BasicBlock *Succ : successors(BB)) {
BlocksToUpdate.push_back(Succ);
BasicBlock *ClonedSucc = getClonedBB(Succ, NextState, DuplicateMap);
if (ClonedSucc)
BlocksToUpdate.push_back(ClonedSucc);
}
}
for (BasicBlock *Succ : BlocksToUpdate) {
for (auto II = Succ->begin(); PHINode *Phi = dyn_cast<PHINode>(II);
++II) {
Value *Incoming = Phi->getIncomingValueForBlock(BB);
if (Incoming) {
if (isa<Constant>(Incoming)) {
Phi->addIncoming(Incoming, ClonedBB);
continue;
}
Value *ClonedVal = VMap[Incoming];
if (ClonedVal)
Phi->addIncoming(ClonedVal, ClonedBB);
else
Phi->addIncoming(Incoming, ClonedBB);
}
}
}
}
void updatePredecessor(BasicBlock *PrevBB, BasicBlock *OldBB,
BasicBlock *NewBB, DomTreeUpdater *DTU) {
if (!isPredecessor(OldBB, PrevBB))
return;
Instruction *PrevTerm = PrevBB->getTerminator();
for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
if (PrevTerm->getSuccessor(Idx) == OldBB) {
OldBB->removePredecessor(PrevBB, true);
PrevTerm->setSuccessor(Idx, NewBB);
}
}
DTU->applyUpdates({{DominatorTree::Delete, PrevBB, OldBB},
{DominatorTree::Insert, PrevBB, NewBB}});
}
void updateDefMap(DefMap &NewDefs, ValueToValueMapTy &VMap) {
SmallVector<std::pair<Instruction *, Instruction *>> NewDefsVector;
NewDefsVector.reserve(VMap.size());
for (auto Entry : VMap) {
Instruction *Inst =
dyn_cast<Instruction>(const_cast<Value *>(Entry.first));
if (!Inst || !Entry.second || isa<BranchInst>(Inst) ||
isa<SwitchInst>(Inst)) {
continue;
}
Instruction *Cloned = dyn_cast<Instruction>(Entry.second);
if (!Cloned)
continue;
NewDefsVector.push_back({Inst, Cloned});
}
sort(NewDefsVector, [](const auto &LHS, const auto &RHS) {
if (LHS.first == RHS.first)
return LHS.second->comesBefore(RHS.second);
return LHS.first->comesBefore(RHS.first);
});
for (const auto &KV : NewDefsVector)
NewDefs[KV.first].push_back(KV.second);
}
void updateLastSuccessor(ThreadingPath &TPath,
DuplicateBlockMap &DuplicateMap,
DomTreeUpdater *DTU) {
uint64_t NextState = TPath.getExitValue();
BasicBlock *BB = TPath.getPath().back();
BasicBlock *LastBlock = getClonedBB(BB, NextState, DuplicateMap);
if (!isa<SwitchInst>(LastBlock->getTerminator()))
return;
SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
std::vector<DominatorTree::UpdateType> DTUpdates;
SmallPtrSet<BasicBlock *, 4> SuccSet;
for (BasicBlock *Succ : successors(LastBlock)) {
if (Succ != NextCase && SuccSet.insert(Succ).second)
DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
}
Switch->eraseFromParent();
BranchInst::Create(NextCase, LastBlock);
DTU->applyUpdates(DTUpdates);
}
void cleanPhiNodes(BasicBlock *BB) {
if (pred_empty(BB)) {
std::vector<PHINode *> PhiToRemove;
for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
PhiToRemove.push_back(Phi);
}
for (PHINode *PN : PhiToRemove) {
PN->replaceAllUsesWith(PoisonValue::get(PN->getType()));
PN->eraseFromParent();
}
return;
}
for (auto II = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
std::vector<BasicBlock *> BlocksToRemove;
for (BasicBlock *IncomingBB : Phi->blocks()) {
if (!isPredecessor(BB, IncomingBB))
BlocksToRemove.push_back(IncomingBB);
}
for (BasicBlock *BB : BlocksToRemove)
Phi->removeIncomingValue(BB);
}
}
BasicBlock *getClonedBB(BasicBlock *BB, uint64_t NextState,
DuplicateBlockMap &DuplicateMap) {
CloneList ClonedBBs = DuplicateMap[BB];
auto It = llvm::find_if(ClonedBBs, [NextState](const ClonedBlock &C) {
return C.State == NextState;
});
return It != ClonedBBs.end() ? (*It).BB : nullptr;
}
BasicBlock *getNextCaseSuccessor(SwitchInst *Switch, uint64_t NextState) {
BasicBlock *NextCase = nullptr;
for (auto Case : Switch->cases()) {
if (Case.getCaseValue()->getZExtValue() == NextState) {
NextCase = Case.getCaseSuccessor();
break;
}
}
if (!NextCase)
NextCase = Switch->getDefaultDest();
return NextCase;
}
bool isPredecessor(BasicBlock *BB, BasicBlock *IncomingBB) {
return llvm::is_contained(predecessors(BB), IncomingBB);
}
AllSwitchPaths *SwitchPaths;
DominatorTree *DT;
AssumptionCache *AC;
TargetTransformInfo *TTI;
OptimizationRemarkEmitter *ORE;
SmallPtrSet<const Value *, 32> EphValues;
std::vector<ThreadingPath> TPaths;
};
/// Entry point of the pass for \p F.
///
/// Scans the function for a switch accepted by MainSwitch, unfolds the select
/// instructions feeding it and collects its threading paths; the scan stops at
/// the first switch that has at least one threading path. The collected paths
/// are then handed to TransformDFA. Returns true when the function may have
/// been modified (selects were unfolded or a transform was attempted).
bool DFAJumpThreading::run(Function &F) {
  LLVM_DEBUG(dbgs() << "\nDFA Jump threading: " << F.getName() << "\n");

  if (F.hasOptSize()) {
    LLVM_DEBUG(dbgs() << "Skipping due to the 'minsize' attribute\n");
    return false;
  }

  if (ClViewCfgBefore)
    F.viewCFG();

  SmallVector<AllSwitchPaths, 2> ThreadableLoops;
  bool MadeChanges = false;

  for (BasicBlock &BB : F) {
    auto *SI = dyn_cast<SwitchInst>(BB.getTerminator());
    if (!SI)
      continue;

    LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName()
                      << " is a candidate\n");
    MainSwitch Switch(SI, ORE);

    if (!Switch.getInstr())
      continue;

    LLVM_DEBUG(dbgs() << "\nSwitchInst in BB " << BB.getName() << " is a "
                      << "candidate for jump threading\n");
    LLVM_DEBUG(SI->dump());

    // getSelectInsts() returns by value; fetch the list once and reuse it.
    const SmallVector<SelectInstToUnfold, 4> SelectInsts =
        Switch.getSelectInsts();
    unfoldSelectInstrs(DT, SelectInsts);
    if (!SelectInsts.empty())
      MadeChanges = true;

    AllSwitchPaths SwitchPaths(&Switch, ORE);
    SwitchPaths.run();

    if (SwitchPaths.getNumThreadingPaths() > 0) {
      ThreadableLoops.push_back(SwitchPaths);

      // Only the first threadable switch found is transformed per run.
      break;
    }
  }

  // Ephemeral values are only needed by the cost model, so they are collected
  // only when there is something to transform.
  SmallPtrSet<const Value *, 32> EphValues;
  if (ThreadableLoops.size() > 0)
    CodeMetrics::collectEphemeralValues(&F, AC, EphValues);

  // Iterate by reference: TransformDFA works through a pointer to the element
  // and ThreadableLoops is not read again afterwards, so copying each
  // AllSwitchPaths (and all of its paths) is unnecessary.
  for (AllSwitchPaths &SwitchPaths : ThreadableLoops) {
    TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
    Transform.run();
    // Conservatively reported as a change even if the transform was rejected.
    MadeChanges = true;
  }

#ifdef EXPENSIVE_CHECKS
  assert(DT->verify(DominatorTree::VerificationLevel::Full));
  verifyFunction(F, &dbgs());
#endif

  return MadeChanges;
}
}
/// New pass manager entry point: obtains the analyses for \p F from \p AM and
/// runs the implementation. All analyses are reported as preserved when
/// nothing changed; otherwise only the dominator tree is.
PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter ORE(&F);

  DFAJumpThreading Impl(&AC, &DT, &TTI, &ORE);
  const bool Changed = Impl.run(F);
  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    return PA;
  }
  return PreservedAnalyses::all();
}