#include "llvm/Transforms/Scalar/LoopBoundSplit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#define DEBUG_TYPE "loop-bound-split"
namespace llvm {
using namespace PatternMatch;
namespace {
/// Description of one integer comparison inside the loop, filled in by
/// analyzeICmp (and possibly refined by calculateUpperBound). When only the
/// right-hand operand of the compare is an add recurrence, analyzeICmp swaps
/// the operands so that the recurrence is always held in AddRecValue.
struct ConditionInfo {
  /// Conditional branch controlled by this condition.
  BranchInst *BI = nullptr;
  /// The compare instruction that forms the condition.
  ICmpInst *ICmp = nullptr;
  /// Predicate of the compare; swapped if the operands were swapped, and may
  /// be turned from LE into LT by calculateUpperBound.
  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
  /// Operand expected to be the induction variable.
  Value *AddRecValue = nullptr;
  /// AddRecValue itself, or, when it is a loop-header PHI, the value that PHI
  /// receives from the loop latch. Initialized like every other member so it
  /// is never read indeterminate.
  Value *NonPHIAddRecValue = nullptr;
  /// Operand the induction variable is compared against.
  Value *BoundValue = nullptr;
  /// SCEV of AddRecValue, or nullptr if it is not an add recurrence.
  const SCEVAddRecExpr *AddRecSCEV = nullptr;
  /// SCEV of the bound. calculateUpperBound replaces it with the exit count
  /// for an exiting condition, or with Bound + 1 when an LE predicate is
  /// rewritten to LT.
  const SCEV *BoundSCEV = nullptr;

  ConditionInfo() = default;
};
} // namespace
/// Decompose \p ICmp into \p Cond: record the compare, its predicate, both
/// operands and their SCEVs. If only the right-hand operand is an add
/// recurrence, the operands are swapped (and the predicate with them) so that
/// Cond.AddRecValue/Cond.AddRecSCEV describe the recurrence.
/// Cond.NonPHIAddRecValue is set to the recurrence operand, or, when that
/// operand is a PHI in the header of \p L, to its incoming value from the
/// latch. The latter assumes \p L is in loop-simplify form, which both
/// callers check before getting here.
static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
                        ConditionInfo &Cond, const Loop &L) {
  Cond.ICmp = ICmp;
  if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
                         m_Value(Cond.BoundValue)))) {
    const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
    const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
    const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
    const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
    // Keep the AddRec on the LHS and the bound on the RHS.
    if (!LHSAddRecSCEV && RHSAddRecSCEV) {
      std::swap(Cond.AddRecValue, Cond.BoundValue);
      std::swap(AddRecSCEV, BoundSCEV);
      Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
    }
    Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
    Cond.BoundSCEV = BoundSCEV;
    Cond.NonPHIAddRecValue = Cond.AddRecValue;

    // Only a PHI in L's header is guaranteed to have an incoming edge from
    // L's latch. An AddRec PHI elsewhere (e.g. an outer loop's IV used inside
    // this loop) has no such edge, and getIncomingValueForBlock would assert
    // (or index out of bounds in release builds). Leave it as is then.
    auto *PN = dyn_cast<PHINode>(Cond.AddRecValue);
    if (Cond.AddRecSCEV && PN && PN->getParent() == L.getHeader())
      Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
  }
}
static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
ConditionInfo &Cond, bool IsExitCond) {
if (IsExitCond) {
const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
if (isa<SCEVCouldNotCompute>(ExitCount))
return false;
Cond.BoundSCEV = ExitCount;
return true;
}
if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
return true;
if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
return false;
if (IntegerType *BoundSCEVIntType =
dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
unsigned BitWidth = BoundSCEVIntType->getBitWidth();
APInt Max = ICmpInst::isSigned(Cond.Pred)
? APInt::getSignedMaxValue(BitWidth)
: APInt::getMaxValue(BitWidth);
const SCEV *MaxSCEV = SE.getConstant(Max);
ICmpInst::Predicate Pred =
ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
const SCEV *BoundPlusOneSCEV =
SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
Cond.BoundSCEV = BoundPlusOneSCEV;
Cond.Pred = Pred;
return true;
}
}
return false;
}
/// Analyze \p ICmp into \p Cond and check it has the form this pass handles:
/// a bound available at loop entry compared against an affine add recurrence
/// of \p L with a positive constant step, whose bound calculateUpperBound can
/// normalize. \p IsExitCond selects exiting-condition handling of the bound.
/// \returns true if \p Cond is usable.
static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
                                    ICmpInst *ICmp, ConditionInfo &Cond,
                                    bool IsExitCond) {
  analyzeICmp(SE, ICmp, Cond, L);

  // The bound must be computable before the loop starts.
  if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
    return false;

  // The compared value must be an induction variable ...
  if (!Cond.AddRecSCEV)
    return false;

  // ... of this loop. An AddRec of an enclosing loop is constant within L, so
  // splitting L's iteration space on it would be wrong.
  if (Cond.AddRecSCEV->getLoop() != &L)
    return false;

  if (!Cond.AddRecSCEV->isAffine())
    return false;

  // Only a constant, strictly positive step is supported for now.
  const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRecSCEV))
    return false;

  ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
  if (StepCI->isNegative() || StepCI->isZero())
    return false;

  if (!calculateUpperBound(L, SE, Cond, IsExitCond))
    return false;

  return true;
}
/// \returns true if \p BI is a conditional branch on an integer compare whose
/// operands SCEV can model, and whose two successors are distinct blocks.
static bool isProcessableCondBI(const ScalarEvolution &SE,
                                const BranchInst *BI) {
  ICmpInst::Predicate Pred;
  Value *LHS, *RHS;
  BasicBlock *TrueSucc = nullptr;
  BasicBlock *FalseSucc = nullptr;

  const bool IsICmpBranch =
      match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
                     m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)));
  // Short-circuit: LHS is only bound when the match succeeded.
  if (!IsICmpBranch || !SE.isSCEVable(LHS->getType()))
    return false;
  assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");

  // A branch whose both edges go to the same block carries no split point.
  return TrueSucc != FalseSucc;
}
static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
ScalarEvolution &SE, ConditionInfo &Cond) {
if (L.getHeader()->getParent()->hasOptSize())
return false;
if (!L.isInnermost())
return false;
if (!L.isLoopSimplifyForm())
return false;
if (!L.isLCSSAForm(DT))
return false;
if (!L.isSafeToClone())
return false;
BasicBlock *ExitingBB = L.getExitingBlock();
if (!ExitingBB)
return false;
BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
if (!ExitingBI)
return false;
if (!isProcessableCondBI(SE, ExitingBI))
return false;
ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
if (!hasProcessableCondition(L, SE, ICmp, Cond, true))
return false;
Cond.BI = ExitingBI;
return true;
}
static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
BasicBlock *Succ0 = BI->getSuccessor(0);
BasicBlock *Succ1 = BI->getSuccessor(1);
BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
return false;
return true;
}
/// Scan the non-latch blocks of \p L for a conditional branch on a
/// loop-variant compare that can be split against the exiting condition
/// \p ExitingCond. The chosen compare must have a bound of the same type as
/// the exiting bound, and its predicate must provably hold for the AddRec's
/// start value on loop entry (so it is true at the beginning of the
/// pre-loop). On success, fills \p SplitCandidateCond (including its BI) and
/// returns the branch; otherwise returns nullptr.
static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
                                      ConditionInfo &ExitingCond,
                                      ConditionInfo &SplitCandidateCond) {
  const BasicBlock *Latch = L.getLoopLatch();
  for (BasicBlock *Block : L.blocks()) {
    // The latch carries the back-edge condition; never split on it.
    if (Block == Latch)
      continue;

    auto *CandBranch = dyn_cast<BranchInst>(Block->getTerminator());
    if (!CandBranch || !isProcessableCondBI(SE, CandBranch) ||
        L.isLoopInvariant(CandBranch->getCondition()))
      continue;

    auto *CandCmp = cast<ICmpInst>(CandBranch->getCondition());
    if (!hasProcessableCondition(L, SE, CandCmp, SplitCandidateCond,
                                 /*IsExitCond=*/false))
      continue;

    // Both bounds are later combined with an smin/umin, so types must agree.
    if (SplitCandidateCond.BoundSCEV->getType() !=
        ExitingCond.BoundSCEV->getType())
      continue;

    const SCEV *Start = SplitCandidateCond.AddRecSCEV->getStart();
    if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, Start,
                                     SplitCandidateCond.BoundSCEV))
      continue;

    SplitCandidateCond.BI = CandBranch;
    return CandBranch;
  }
  return nullptr;
}
static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
ScalarEvolution &SE, LPMUpdater &U) {
ConditionInfo SplitCandidateCond;
ConditionInfo ExitingCond;
if (!canSplitLoopBound(L, DT, SE, ExitingCond))
return false;
if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
return false;
if (!isProfitableToTransform(L, SplitCandidateCond.BI))
return false;
SmallVector<BasicBlock *, 8> PostLoopBlocks;
Loop *PostLoop;
ValueToValueMapTy VMap;
BasicBlock *PreHeader = L.getLoopPreheader();
BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
".split", &LI, &DT, PostLoopBlocks);
remapInstructionsInBlocks(PostLoopBlocks, VMap);
BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
IRBuilder<> Builder(&PostLoopPreHeader->front());
bool isExitingLatch =
(L.getExitingBlock() == L.getLoopLatch()) ? true : false;
Value *ExitingCondLCSSAPhi = nullptr;
for (PHINode &PN : L.getHeader()->phis()) {
PHINode *LCSSAPhi =
Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
LCSSAPhi->setDebugLoc(PN.getDebugLoc());
LCSSAPhi->addIncoming(
isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
L.getExitingBlock());
PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
if (!SE.isSCEVable(PN.getType()))
continue;
const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
PN.getIncomingValueForBlock(L.getLoopLatch()))
ExitingCondLCSSAPhi = LCSSAPhi;
}
Instruction *OrigBI = PostLoopPreHeader->getTerminator();
ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
Value *Cond =
Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
OrigBI->eraseFromParent();
const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
: SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
SCEVExpander Expander(
SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
Instruction *InsertPt = SplitLoopPH->getTerminator();
Value *NewBoundValue =
Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
NewBoundValue->setName("new.bound");
ExitingCond.ICmp->setOperand(1, NewBoundValue);
LLVMContext &Context = PreHeader->getContext();
SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
BranchInst *ClonedSplitCandidateBI =
cast<BranchInst>(VMap[SplitCandidateCond.BI]);
ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
else
ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
Builder.SetInsertPoint(&PostLoopPreHeader->front());
for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
for (auto i : seq<int>(0, PN.getNumOperands())) {
if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
Value *IncomingValue = PN.getIncomingValue(i);
PHINode *LCSSAPhi =
Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
LCSSAPhi->setDebugLoc(PN.getDebugLoc());
LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
PN.setIncomingBlock(i, PostLoopPreHeader);
PN.setIncomingValue(i, LCSSAPhi);
PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
}
}
}
DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
SE.forgetLoop(&L);
simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
U.addSiblingLoops(PostLoop);
return true;
}
/// Pass entry point: try to split the bound of \p L. Returns all analyses
/// preserved if nothing changed; otherwise verifies the manually updated
/// DominatorTree and LoopInfo and reports the standard loop-pass analyses as
/// preserved.
PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
                                          LoopStandardAnalysisResults &AR,
                                          LPMUpdater &U) {
  LLVM_DEBUG(dbgs() << "Spliting bound of loop in "
                    << L.getHeader()->getParent()->getName() << ": " << L
                    << "\n");

  const bool Changed = splitLoopBound(L, AR.DT, AR.LI, AR.SE, U);
  if (!Changed)
    return PreservedAnalyses::all();

  // splitLoopBound patches DT and LI by hand; make sure they are consistent.
  assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
  AR.LI.verify(AR.DT);

  return getLoopPassPreservedAnalyses();
}
}