#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Target/TargetMachine.h"
using namespace llvm;
#define DEBUG_TYPE "indirectbr-expand"
namespace {
/// Legacy function pass that expands `indirectbr` instructions
/// ("Expand indirectbr instructions") on subtargets that ask for it.
class IndirectBrExpandPass : public FunctionPass {
public:
  /// Pass identifier; its address is handed to FunctionPass below.
  static char ID;

  /// Registers the pass with the global registry on construction.
  IndirectBrExpandPass() : FunctionPass(ID) {
    initializeIndirectBrExpandPassPass(*PassRegistry::getPassRegistry());
  }

  /// runOnFunction keeps an available dominator tree in sync through a
  /// DomTreeUpdater, so the tree is declared preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
  }

  bool runOnFunction(Function &F) override;

private:
  /// Target lowering of the subtarget for the function being processed;
  /// assigned in runOnFunction.
  const TargetLowering *TLI = nullptr;
};
}
// Storage for the pass identifier; only its address is meaningful.
char IndirectBrExpandPass::ID = 0;

// Register the pass under DEBUG_TYPE, declaring its dependency on the
// dominator tree wrapper pass.
INITIALIZE_PASS_BEGIN(IndirectBrExpandPass, DEBUG_TYPE,
                      "Expand indirectbr instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(IndirectBrExpandPass, DEBUG_TYPE,
                    "Expand indirectbr instructions", false, false)

/// Factory used by the code generator pipeline; the caller takes ownership
/// of the returned pass.
FunctionPass *llvm::createIndirectBrExpandPass() {
  return new IndirectBrExpandPass;
}
bool IndirectBrExpandPass::runOnFunction(Function &F) {
auto &DL = F.getParent()->getDataLayout();
auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
if (!TPC)
return false;
auto &TM = TPC->getTM<TargetMachine>();
auto &STI = *TM.getSubtargetImpl(F);
if (!STI.enableIndirectBrExpand())
return false;
TLI = STI.getTargetLowering();
Optional<DomTreeUpdater> DTU;
if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
DTU.emplace(DTWP->getDomTree(), DomTreeUpdater::UpdateStrategy::Lazy);
SmallVector<IndirectBrInst *, 1> IndirectBrs;
SmallPtrSet<BasicBlock *, 4> IndirectBrSuccs;
for (BasicBlock &BB : F)
if (auto *IBr = dyn_cast<IndirectBrInst>(BB.getTerminator())) {
if (IBr->getNumSuccessors() == 0) {
(void)new UnreachableInst(F.getContext(), IBr);
IBr->eraseFromParent();
continue;
}
IndirectBrs.push_back(IBr);
for (BasicBlock *SuccBB : IBr->successors())
IndirectBrSuccs.insert(SuccBB);
}
if (IndirectBrs.empty())
return false;
SmallVector<BasicBlock *, 4> BBs;
for (BasicBlock &BB : F) {
if (!IndirectBrSuccs.count(&BB))
continue;
auto IsBlockAddressUse = [&](const Use &U) {
return isa<BlockAddress>(U.getUser());
};
auto BlockAddressUseIt = llvm::find_if(BB.uses(), IsBlockAddressUse);
if (BlockAddressUseIt == BB.use_end())
continue;
assert(std::find_if(std::next(BlockAddressUseIt), BB.use_end(),
IsBlockAddressUse) == BB.use_end() &&
"There should only ever be a single blockaddress use because it is "
"a constant and should be uniqued.");
auto *BA = cast<BlockAddress>(BlockAddressUseIt->getUser());
if (!BA->isConstantUsed())
continue;
int BBIndex = BBs.size() + 1;
BBs.push_back(&BB);
auto *ITy = cast<IntegerType>(DL.getIntPtrType(BA->getType()));
ConstantInt *BBIndexC = ConstantInt::get(ITy, BBIndex);
BA->replaceAllUsesWith(ConstantExpr::getIntToPtr(BBIndexC, BA->getType()));
}
if (BBs.empty()) {
SmallVector<DominatorTree::UpdateType, 8> Updates;
if (DTU)
Updates.reserve(IndirectBrSuccs.size());
for (auto *IBr : IndirectBrs) {
if (DTU) {
for (BasicBlock *SuccBB : IBr->successors())
Updates.push_back({DominatorTree::Delete, IBr->getParent(), SuccBB});
}
(void)new UnreachableInst(F.getContext(), IBr);
IBr->eraseFromParent();
}
if (DTU) {
assert(Updates.size() == IndirectBrSuccs.size() &&
"Got unexpected update count.");
DTU->applyUpdates(Updates);
}
return true;
}
BasicBlock *SwitchBB;
Value *SwitchValue;
IntegerType *CommonITy = nullptr;
for (auto *IBr : IndirectBrs) {
auto *ITy =
cast<IntegerType>(DL.getIntPtrType(IBr->getAddress()->getType()));
if (!CommonITy || ITy->getBitWidth() > CommonITy->getBitWidth())
CommonITy = ITy;
}
auto GetSwitchValue = [DL, CommonITy](IndirectBrInst *IBr) {
return CastInst::CreatePointerCast(
IBr->getAddress(), CommonITy,
Twine(IBr->getAddress()->getName()) + ".switch_cast", IBr);
};
SmallVector<DominatorTree::UpdateType, 8> Updates;
if (IndirectBrs.size() == 1) {
IndirectBrInst *IBr = IndirectBrs[0];
SwitchBB = IBr->getParent();
SwitchValue = GetSwitchValue(IBr);
if (DTU) {
Updates.reserve(IndirectBrSuccs.size());
for (BasicBlock *SuccBB : IBr->successors())
Updates.push_back({DominatorTree::Delete, IBr->getParent(), SuccBB});
assert(Updates.size() == IndirectBrSuccs.size() &&
"Got unexpected update count.");
}
IBr->eraseFromParent();
} else {
SwitchBB = BasicBlock::Create(F.getContext(), "switch_bb", &F);
auto *SwitchPN = PHINode::Create(CommonITy, IndirectBrs.size(),
"switch_value_phi", SwitchBB);
SwitchValue = SwitchPN;
if (DTU)
Updates.reserve(IndirectBrs.size() + 2 * IndirectBrSuccs.size());
for (auto *IBr : IndirectBrs) {
SwitchPN->addIncoming(GetSwitchValue(IBr), IBr->getParent());
BranchInst::Create(SwitchBB, IBr);
if (DTU) {
Updates.push_back({DominatorTree::Insert, IBr->getParent(), SwitchBB});
for (BasicBlock *SuccBB : IBr->successors())
Updates.push_back({DominatorTree::Delete, IBr->getParent(), SuccBB});
}
IBr->eraseFromParent();
}
}
auto *SI = SwitchInst::Create(SwitchValue, BBs[0], BBs.size(), SwitchBB);
for (int i : llvm::seq<int>(1, BBs.size()))
SI->addCase(ConstantInt::get(CommonITy, i + 1), BBs[i]);
if (DTU) {
SmallPtrSet<BasicBlock *, 8> UniqueSuccessors;
Updates.reserve(Updates.size() + BBs.size());
for (BasicBlock *BB : BBs) {
if (UniqueSuccessors.insert(BB).second)
Updates.push_back({DominatorTree::Insert, SwitchBB, BB});
}
DTU->applyUpdates(Updates);
}
return true;
}