#include "llvm/Transforms/Scalar/MergeICmps.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>
using namespace llvm;
namespace {
#define DEBUG_TYPE "mergeicmps"
// One operand of a comparison, described as a load at a constant offset from
// a base pointer.
//
// The base pointer itself is not stored; it is represented by `BaseId`, the
// integer handed out by BaseIdentifier. A default-constructed atom keeps
// BaseId == 0, which visitICmp() treats as "no atom" (BaseIdentifier starts
// numbering at 1).
//
// The type is move-only: copying is deleted.
struct BCEAtom {
  BCEAtom() = default;
  // `Offset` is a by-value sink parameter; it is moved into the member so the
  // APInt is not copied a second time.
  BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset)
      : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(std::move(Offset)) {}

  BCEAtom(const BCEAtom &) = delete;
  BCEAtom &operator=(const BCEAtom &) = delete;

  BCEAtom(BCEAtom &&that) = default;
  // Hand-written so that self-move-assignment is an explicit no-op.
  BCEAtom &operator=(BCEAtom &&that) {
    if (this == &that)
      return *this;
    GEP = that.GEP;
    LoadI = that.LoadI;
    BaseId = that.BaseId;
    Offset = std::move(that.Offset);
    return *this;
  }

  // Orders atoms by BaseId first, then by Offset using a signed comparison.
  bool operator<(const BCEAtom &O) const {
    return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
  }

  // The GEP feeding the load, or nullptr when the load reads the base directly.
  GetElementPtrInst *GEP = nullptr;
  // The load instruction this atom stands for.
  LoadInst *LoadI = nullptr;
  // Identifier of the base pointer; 0 means "invalid / not set".
  unsigned BaseId = 0;
  // Constant offset from the base pointer.
  APInt Offset;
};
// Maps each distinct base pointer to an integer id. Ids are assigned in order
// of first appearance and start at 1, so 0 is never returned.
class BaseIdentifier {
public:
  // Returns the id already associated with `Base`, or assigns and returns the
  // next unused id when `Base` is seen for the first time.
  int getBaseId(const Value *Base) {
    assert(Base && "invalid base");
    const auto Lookup = BaseToIndex.try_emplace(Base, Order);
    const bool IsNewBase = Lookup.second;
    if (IsNewBase)
      ++Order; // The id was consumed; advance to the next free one.
    return Lookup.first->second;
  }

private:
  // Next id to hand out.
  unsigned Order = 1;
  DenseMap<const Value*, int> BaseToIndex;
};
// Tries to describe `Val` as a BCEAtom, i.e. a load from a base pointer plus a
// constant offset. Returns a default-constructed atom (BaseId == 0) when `Val`
// does not qualify.
//
// The value is rejected when it is not a load, when the load (or the GEP
// feeding it) is used outside of the load's block, when the load is not
// simple, when the address is not in address space 0, when the address is not
// known to be dereferenceable, or when the GEP offset is not constant.
BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
  auto *const Load = dyn_cast<LoadInst>(Val);
  if (Load == nullptr)
    return {};
  LLVM_DEBUG(dbgs() << "load\n");

  BasicBlock *const LoadBlock = Load->getParent();
  if (Load->isUsedOutsideOfBlock(LoadBlock)) {
    LLVM_DEBUG(dbgs() << "used outside of block\n");
    return {};
  }
  if (!Load->isSimple()) {
    LLVM_DEBUG(dbgs() << "volatile or atomic\n");
    return {};
  }

  Value *const Pointer = Load->getOperand(0);
  if (Pointer->getType()->getPointerAddressSpace() != 0) {
    LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n");
    return {};
  }
  const auto &DL = Load->getModule()->getDataLayout();
  if (!isDereferenceablePointer(Pointer, Load->getType(), DL)) {
    LLVM_DEBUG(dbgs() << "not dereferenceable\n");
    return {};
  }

  // Offset starts at zero, with the bit width of the pointer type.
  APInt Offset(DL.getPointerTypeSizeInBits(Pointer->getType()), 0);

  auto *const GEP = dyn_cast<GetElementPtrInst>(Pointer);
  if (!GEP) {
    // No GEP: the loaded address is itself the base, at offset zero.
    return BCEAtom(nullptr, Load, BaseId.getBaseId(Pointer), Offset);
  }

  LLVM_DEBUG(dbgs() << "GEP\n");
  if (GEP->isUsedOutsideOfBlock(LoadBlock)) {
    LLVM_DEBUG(dbgs() << "used outside of block\n");
    return {};
  }
  if (!GEP->accumulateConstantOffset(DL, Offset))
    return {};
  return BCEAtom(GEP, Load, BaseId.getBaseId(GEP->getPointerOperand()),
                 Offset);
}
// A comparison between two BCEAtoms together with the compared width and the
// originating icmp instruction.
//
// After construction the two operands are normalized so that Rhs is never
// ordered before Lhs (per BCEAtom::operator<).
struct BCECmp {
  BCEAtom Lhs;
  BCEAtom Rhs;
  // Width of the compared values, in bits.
  int SizeBits;
  const ICmpInst *CmpI;

  BCECmp(BCEAtom L, BCEAtom R, int SizeBits, const ICmpInst *CmpI)
      : Lhs(std::move(L)), Rhs(std::move(R)), SizeBits(SizeBits), CmpI(CmpI) {
    const bool OperandsOutOfOrder = Rhs < Lhs;
    if (OperandsOutOfOrder)
      std::swap(Lhs, Rhs);
  }
};
// A basic block whose terminating branch depends on a BCECmp.
//
// `BlockInsts` holds the instructions that make up the comparison itself; any
// other instruction in `BB` counts as "other work" (see doesOtherWork()).
class BCECmpBlock {
public:
  using InstructionSet = SmallDenseSet<const Instruction *, 8>;

  BCECmpBlock(BCECmp Cmp, BasicBlock *BB, InstructionSet BlockInsts)
      : BB(BB), BlockInsts(std::move(BlockInsts)), Cmp(std::move(Cmp)) {}

  // Read-only access to the wrapped comparison.
  const BCEAtom &Lhs() const { return Cmp.Lhs; }
  const BCEAtom &Rhs() const { return Cmp.Rhs; }
  int SizeBits() const { return Cmp.SizeBits; }

  // True if BB contains at least one instruction that is not in BlockInsts.
  bool doesOtherWork() const;

  // True if every instruction of BB that is not in BlockInsts passes
  // canSinkBCECmpInst().
  bool canSplit(AliasAnalysis &AA) const;

  // Per-instruction check used by canSplit() and split().
  bool canSinkBCECmpInst(const Instruction *, AliasAnalysis &AA) const;

  // Moves every instruction of BB that is not in BlockInsts to the beginning
  // of NewParent.
  void split(BasicBlock *NewParent, AliasAnalysis &AA) const;

  // The block this comparison lives in.
  BasicBlock *BB;
  // Instructions that belong to the comparison.
  InstructionSet BlockInsts;
  // Set when the non-comparison instructions must be split off before merging.
  bool RequireSplit = false;
  // Position of this block in the chain at the time it was enqueued.
  unsigned OrigOrder = 0;

private:
  BCECmp Cmp;
};
bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
AliasAnalysis &AA) const {
if (Inst->mayWriteToMemory()) {
auto MayClobber = [&](LoadInst *LI) {
return (Inst->getParent() != LI->getParent() || !Inst->comesBefore(LI)) &&
isModSet(AA.getModRefInfo(Inst, MemoryLocation::get(LI)));
};
if (MayClobber(Cmp.Lhs.LoadI) || MayClobber(Cmp.Rhs.LoadI))
return false;
}
return llvm::none_of(Inst->operands(), [&](const Value *Op) {
const Instruction *OpI = dyn_cast<Instruction>(Op);
return OpI && BlockInsts.contains(OpI);
});
}
// Moves every instruction of BB that is not part of the comparison to the
// beginning of `NewParent`, keeping their relative order.
void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const {
  // First pass: collect, so that BB is not modified while it is iterated.
  llvm::SmallVector<Instruction *, 4> ToMove;
  for (Instruction &Inst : *BB) {
    if (!BlockInsts.count(&Inst)) {
      assert(canSinkBCECmpInst(&Inst, AA) && "Split unsplittable block");
      ToMove.push_back(&Inst);
    }
  }

  // Second pass: each instruction is inserted at the very front, so going
  // last-to-first leaves them in their original order.
  for (auto It = ToMove.rbegin(), End = ToMove.rend(); It != End; ++It)
    (*It)->moveBefore(*NewParent, NewParent->begin());
}
// Returns true if every instruction in BB that is not part of the comparison
// passes canSinkBCECmpInst().
bool BCECmpBlock::canSplit(AliasAnalysis &AA) const {
  return llvm::all_of(*BB, [&](Instruction &Inst) {
    return BlockInsts.count(&Inst) || canSinkBCECmpInst(&Inst, AA);
  });
}
bool BCECmpBlock::doesOtherWork() const {
for (const Instruction &Inst : *BB) {
if (!BlockInsts.count(&Inst))
return true;
}
return false;
}
// Tries to interpret `CmpI` as a comparison between two BCEAtoms.
//
// Returns None when the icmp has more than one use, when its predicate is not
// `ExpectedPredicate`, or when either operand is not recognized by
// visitICmpLoadOperand(). The left operand is examined before the right one.
Optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
                           const ICmpInst::Predicate ExpectedPredicate,
                           BaseIdentifier &BaseId) {
  if (!CmpI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "cmp has several uses\n");
    return None;
  }
  if (CmpI->getPredicate() != ExpectedPredicate)
    return None;
  LLVM_DEBUG(dbgs() << "cmp "
                    << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne")
                    << "\n");

  Value *const LhsOperand = CmpI->getOperand(0);
  BCEAtom LhsAtom = visitICmpLoadOperand(LhsOperand, BaseId);
  if (LhsAtom.BaseId == 0)
    return None;

  Value *const RhsOperand = CmpI->getOperand(1);
  BCEAtom RhsAtom = visitICmpLoadOperand(RhsOperand, BaseId);
  if (RhsAtom.BaseId == 0)
    return None;

  // The compared width is taken from the type of the left operand.
  const auto &DL = CmpI->getModule()->getDataLayout();
  return BCECmp(std::move(LhsAtom), std::move(RhsAtom),
                DL.getTypeSizeInBits(LhsOperand->getType()), CmpI);
}
// Tries to interpret `Block` as a BCECmpBlock.
//
// `Val` is the value the phi receives from `Block`, `PhiBlock` is the block
// holding that phi. Two shapes are accepted:
//  - unconditional branch: `Val` itself must be the icmp (predicate eq);
//  - conditional branch: `Val` must be the constant zero and the branch
//    condition must be the icmp. The expected predicate is eq when the false
//    successor is `PhiBlock`, ne otherwise.
// Returns None for anything else.
Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block,
                                    const BasicBlock *const PhiBlock,
                                    BaseIdentifier &BaseId) {
  if (Block->empty())
    return None;
  auto *const Branch = dyn_cast<BranchInst>(Block->getTerminator());
  if (!Branch)
    return None;
  LLVM_DEBUG(dbgs() << "branch\n");

  // Defaults match the unconditional-branch shape.
  Value *Condition = Val;
  ICmpInst::Predicate WantedPredicate = ICmpInst::ICMP_EQ;
  if (Branch->isConditional()) {
    const auto *const IncomingConst = cast<ConstantInt>(Val);
    LLVM_DEBUG(dbgs() << "const\n");
    if (!IncomingConst->isZero())
      return None;
    LLVM_DEBUG(dbgs() << "false\n");
    assert(Branch->getNumSuccessors() == 2 && "expecting a cond branch");
    Condition = Branch->getCondition();
    if (Branch->getSuccessor(1) != PhiBlock)
      WantedPredicate = ICmpInst::ICMP_NE;
  }

  auto *const Cmp = dyn_cast<ICmpInst>(Condition);
  if (!Cmp)
    return None;
  LLVM_DEBUG(dbgs() << "icmp\n");

  Optional<BCECmp> Parsed = visitICmp(Cmp, WantedPredicate, BaseId);
  if (!Parsed)
    return None;

  // Record which instructions of the block implement the comparison.
  BCECmpBlock::InstructionSet BlockInsts(
      {Parsed->Lhs.LoadI, Parsed->Rhs.LoadI, Parsed->CmpI, Branch});
  for (const BCEAtom *Atom : {&Parsed->Lhs, &Parsed->Rhs}) {
    if (Atom->GEP)
      BlockInsts.insert(Atom->GEP);
  }
  return BCECmpBlock(std::move(*Parsed), Block, BlockInsts);
}
static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
BCECmpBlock &&Comparison) {
LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName()
<< "': Found cmp of " << Comparison.SizeBits()
<< " bits between " << Comparison.Lhs().BaseId << " + "
<< Comparison.Lhs().Offset << " and "
<< Comparison.Rhs().BaseId << " + "
<< Comparison.Rhs().Offset << "\n");
LLVM_DEBUG(dbgs() << "\n");
Comparison.OrigOrder = Comparisons.size();
Comparisons.push_back(std::move(Comparison));
}
// A chain of comparison blocks that all feed the same phi.
//
// The constructor analyses the blocks and groups them into runs of contiguous
// comparisons; simplify() then rewrites the IR.
class BCECmpChain {
public:
  using ContiguousBlocks = std::vector<BCECmpBlock>;

  BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
              AliasAnalysis &AA);

  // Rewrites the chain. Must only be called when atLeastOneMerged() is true.
  bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                DomTreeUpdater &DTU);

  // True if at least one group contains more than one block.
  bool atLeastOneMerged() const {
    return any_of(MergedBlocks_,
                  [](const auto &Blocks) { return Blocks.size() > 1; });
  }

private:
  // The phi the chain feeds into.
  PHINode &Phi_;
  // Groups of contiguous comparison blocks. Left empty when the constructor
  // bails out.
  std::vector<ContiguousBlocks> MergedBlocks_;
  // First block of the chain. Defaults to nullptr because the constructor has
  // early-return paths that never assign it; without the initializer the
  // member would hold an indeterminate value on those paths.
  BasicBlock *EntryBlock_ = nullptr;
};
// Returns true if `Second` compares the bytes that immediately follow those
// compared by `First`: same bases on both sides, and both offsets of `Second`
// equal the corresponding offset of `First` plus First's size in bytes
// (SizeBits / 8, integer division).
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
  if (First.Lhs().BaseId != Second.Lhs().BaseId)
    return false;
  if (First.Rhs().BaseId != Second.Rhs().BaseId)
    return false;
  if (First.Lhs().Offset + First.SizeBits() / 8 != Second.Lhs().Offset)
    return false;
  return First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
}
// Returns the smallest OrigOrder among `Blocks`, or the maximum unsigned value
// when `Blocks` is empty.
static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) {
  return std::accumulate(Blocks.begin(), Blocks.end(),
                         std::numeric_limits<unsigned>::max(),
                         [](unsigned Smallest, const BCECmpBlock &Block) {
                           return std::min(Smallest, Block.OrigOrder);
                         });
}
// Groups `Blocks` into runs of contiguous comparisons.
//
// The blocks are first sorted by (Lhs, Rhs); consecutive blocks that are
// contiguous (see areContiguous) are then placed in the same group. Finally
// the groups are ordered by the smallest OrigOrder they contain.
static std::vector<BCECmpChain::ContiguousBlocks>
mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
  // Lexicographic order on (Lhs, Rhs), expressed with operator< only.
  llvm::sort(Blocks, [](const BCECmpBlock &A, const BCECmpBlock &B) {
    if (A.Lhs() < B.Lhs())
      return true;
    if (B.Lhs() < A.Lhs())
      return false;
    return A.Rhs() < B.Rhs();
  });

  std::vector<BCECmpChain::ContiguousBlocks> Groups;
  for (BCECmpBlock &Block : Blocks) {
    const bool ExtendsLastGroup =
        !Groups.empty() && areContiguous(Groups.back().back(), Block);
    if (ExtendsLastGroup) {
      LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into "
                        << Groups.back().back().BB->getName() << "\n");
    } else {
      // Start a new group.
      Groups.emplace_back();
    }
    Groups.back().push_back(std::move(Block));
  }

  llvm::sort(Groups, [](const BCECmpChain::ContiguousBlocks &A,
                        const BCECmpChain::ContiguousBlocks &B) {
    return getMinOrigOrder(A) < getMinOrigOrder(B);
  });
  return Groups;
}
// Analyses `Blocks` (in chain order) and fills MergedBlocks_/EntryBlock_.
//
// The constructor returns early, leaving MergedBlocks_ empty, when a block is
// not a recognizable comparison block, or when a block does extra work after
// at least one comparison has already been queued. A block that does extra
// work while nothing is queued yet is either queued with RequireSplit set
// (when canSplit() allows it) or skipped.
BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
                         AliasAnalysis &AA)
    : Phi_(Phi) {
  assert(!Blocks.empty() && "a chain should have at least one block");
  std::vector<BCECmpBlock> Comparisons;
  BaseIdentifier BaseId;
  for (BasicBlock *const Block : Blocks) {
    assert(Block && "invalid block");
    Optional<BCECmpBlock> Comparison = visitCmpBlock(
        Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
    if (!Comparison) {
      LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
      return;
    }

    // Common case: the block contains nothing but the comparison.
    if (!Comparison->doesOtherWork()) {
      enqueueBlock(Comparisons, std::move(*Comparison));
      continue;
    }

    LLVM_DEBUG(dbgs() << "block '" << Comparison->BB->getName()
                      << "' does extra work besides compare\n");
    // Extra work is only tolerated while no comparison has been queued.
    if (!Comparisons.empty())
      return;

    if (Comparison->canSplit(AA)) {
      LLVM_DEBUG(dbgs()
                 << "Split initial block '" << Comparison->BB->getName()
                 << "' that does extra work besides compare\n");
      Comparison->RequireSplit = true;
      enqueueBlock(Comparisons, std::move(*Comparison));
    } else {
      LLVM_DEBUG(dbgs()
                 << "ignoring initial block '" << Comparison->BB->getName()
                 << "' that does extra work besides compare\n");
    }
  }

  if (Comparisons.empty()) {
    LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n");
    return;
  }
  EntryBlock_ = Comparisons[0].BB;
  MergedBlocks_ = mergeBlocks(std::move(Comparisons));
}
namespace {
// Computes the name for a block that replaces `Comparisons`.
//
// With a single comparison the name is that block's name. Otherwise it is the
// first block's name followed by "+<name>" for every later block that has a
// non-empty name; if no block has a name the result is empty.
class MergedBlockName {
  // Backing storage for `Name` when it has to be built. Declared before
  // `Name` so that it is constructed first.
  SmallString<16> Scratch;

public:
  explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons)
      : Name(makeName(Comparisons)) {}
  const StringRef Name;

private:
  StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) {
    assert(!Comparisons.empty() && "no basic block");
    const BasicBlock *const Head = Comparisons[0].BB;
    if (Comparisons.size() == 1)
      return Head->getName();

    // Total number of name characters across all blocks.
    int TotalLen = 0;
    for (const BCECmpBlock &Cmp : Comparisons)
      TotalLen = TotalLen + Cmp.BB->getName().size();
    if (TotalLen == 0)
      return StringRef("", 0);

    Scratch.clear();
    // Room for the names plus at most one separator per additional block.
    Scratch.reserve(TotalLen + Comparisons.size() - 1);
    Scratch.append(Head->getName());
    for (const BCECmpBlock &Cmp : Comparisons.drop_front()) {
      const StringRef BlockName = Cmp.BB->getName();
      if (BlockName.empty())
        continue;
      Scratch.push_back('+');
      Scratch.append(BlockName);
    }
    return Scratch.str();
  }
};
}
// Creates one new basic block that performs all of `Comparisons` and returns
// it.
//
// The block is inserted before `InsertBefore`. With a single comparison it
// reloads the two operands and compares them; with several it emits one
// memcmp() call covering the summed size and tests the result against zero.
// If `NextCmpBlock` is the phi's block, the new block branches to it
// unconditionally and feeds the comparison result to `Phi`; otherwise it
// branches to `NextCmpBlock` on equality and to the phi's block (with an
// incoming `false`) on inequality. `DTU` is told about the new edges.
static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
                                    BasicBlock *const InsertBefore,
                                    BasicBlock *const NextCmpBlock,
                                    PHINode &Phi, const TargetLibraryInfo &TLI,
                                    AliasAnalysis &AA, DomTreeUpdater &DTU) {
  assert(!Comparisons.empty() && "merging zero comparisons");
  LLVMContext &Context = NextCmpBlock->getContext();
  const BCECmpBlock &FirstCmp = Comparisons[0];

  BasicBlock *const BB =
      BasicBlock::Create(Context, MergedBlockName(Comparisons).Name,
                         NextCmpBlock->getParent(), InsertBefore);
  IRBuilder<> Builder(BB);

  // Addresses of the first comparison: a clone of its GEP when there is one,
  // otherwise the pointer operand of its load.
  Value *Lhs, *Rhs;
  if (FirstCmp.Lhs().GEP)
    Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone());
  else
    Lhs = FirstCmp.Lhs().LoadI->getPointerOperand();
  if (FirstCmp.Rhs().GEP)
    Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone());
  else
    Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();

  Value *IsEqual = nullptr;
  LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
                    << BB->getName() << "\n");

  // If one of the blocks was flagged for splitting, hoist its non-comparison
  // instructions to the front of the new block.
  const auto ToSplit = llvm::find_if(
      Comparisons, [](const BCECmpBlock &B) { return B.RequireSplit; });
  if (ToSplit != Comparisons.end()) {
    LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n");
    ToSplit->split(BB, AA);
  }

  if (Comparisons.size() == 1) {
    LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
    // Clone the original loads rather than creating fresh ones: a fresh
    // CreateLoad(Ty, Ptr) would use the default alignment for the type and
    // drop the original load's alignment and metadata, which is wrong when
    // the original load was under-aligned.
    Instruction *const LhsLoad = Builder.Insert(FirstCmp.Lhs().LoadI->clone());
    Instruction *const RhsLoad = Builder.Insert(FirstCmp.Rhs().LoadI->clone());
    // Point the clones at the addresses computed in the new block.
    LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs);
    RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs);
    IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  } else {
    const unsigned TotalSizeBits = std::accumulate(
        Comparisons.begin(), Comparisons.end(), 0u,
        [](unsigned Size, const BCECmpBlock &C) {
          return Size + C.SizeBits();
        });
    const auto &DL = Phi.getModule()->getDataLayout();
    Value *const MemCmpCall = emitMemCmp(
        Lhs, Rhs,
        ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder,
        DL, &TLI);
    assert(MemCmpCall && "failed to emit a call to memcmp");
    // Compare against a zero of the call's own result type instead of
    // assuming that memcmp returns a 32-bit integer.
    IsEqual = Builder.CreateICmpEQ(
        MemCmpCall, ConstantInt::get(MemCmpCall->getType(), 0));
  }

  BasicBlock *const PhiBB = Phi.getParent();
  if (NextCmpBlock == PhiBB) {
    // Last link of the chain: hand the comparison result to the phi.
    Builder.CreateBr(PhiBB);
    Phi.addIncoming(IsEqual, BB);
    DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
  } else {
    // Go on to the next comparison when equal, exit to the phi otherwise.
    Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
    Phi.addIncoming(ConstantInt::getFalse(Context), BB);
    DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
                      {DominatorTree::Insert, BB, PhiBB}});
  }
  return BB;
}
// Rewrites the IR for this chain: builds one new block per group in
// MergedBlocks_, redirects every predecessor of the old entry block to the
// head of the new chain, and deletes all of the original comparison blocks.
// Requires atLeastOneMerged(). Always returns true.
bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
DomTreeUpdater &DTU) {
assert(atLeastOneMerged() && "simplifying trivial BCECmpChain");
LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
<< EntryBlock_->getName() << "\n");
// Build the new blocks from the last group to the first. Each new block is
// inserted before, and continues to, the block created in the previous
// iteration (initially the phi's block).
BasicBlock *InsertBefore = EntryBlock_;
BasicBlock *NextCmpBlock = Phi_.getParent();
for (const auto &Blocks : reverse(MergedBlocks_)) {
InsertBefore = NextCmpBlock = mergeComparisons(
Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
}
// NextCmpBlock is now the head of the new chain. Retarget every predecessor
// of the old entry block to it; the loop ends once the old entry block has no
// predecessors left.
while (!pred_empty(EntryBlock_)) {
BasicBlock* const Pred = *pred_begin(EntryBlock_);
LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName()
<< "\n");
Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock);
DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
{DominatorTree::Insert, Pred, NextCmpBlock}});
}
// When the old entry block is the function's entry block and a dominator tree
// is available, make the head of the new chain the root of the tree.
const bool ChainEntryIsFnEntry = EntryBlock_->isEntryBlock();
if (ChainEntryIsFnEntry && DTU.hasDomTree()) {
LLVM_DEBUG(dbgs() << "Changing function entry from "
<< EntryBlock_->getName() << " to "
<< NextCmpBlock->getName() << "\n");
DTU.getDomTree().setNewRoot(NextCmpBlock);
DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
}
EntryBlock_ = nullptr;
// Collect every original comparison block and delete them all at once.
SmallVector<BasicBlock *, 16> DeadBlocks;
for (const auto &Blocks : MergedBlocks_) {
for (const BCECmpBlock &Block : Blocks) {
LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName()
<< "\n");
DeadBlocks.push_back(Block.BB);
}
}
DeleteDeadBlocks(DeadBlocks, &DTU);
// The stored blocks have been deleted; drop the references to them.
MergedBlocks_.clear();
return true;
}
// Walks up the single-predecessor chain starting at `LastBlock` and returns
// the `NumBlocks` blocks found, first block of the chain at index 0 and
// `LastBlock` at the last index.
//
// Returns an empty vector when a visited block (other than the one that ends
// up at index 0) has its address taken or does not have a single predecessor,
// or when a predecessor is not an incoming block of `Phi`.
std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
                                           BasicBlock *const LastBlock,
                                           int NumBlocks) {
  assert(LastBlock && "invalid last block");
  std::vector<BasicBlock *> Chain(NumBlocks);
  BasicBlock *Cursor = LastBlock;
  int Slot = NumBlocks - 1;
  while (Slot > 0) {
    if (Cursor->hasAddressTaken()) {
      LLVM_DEBUG(dbgs() << "skip: block " << Slot
                        << " has its address taken\n");
      return {};
    }
    Chain[Slot] = Cursor;

    BasicBlock *const Pred = Cursor->getSinglePredecessor();
    if (Pred == nullptr) {
      LLVM_DEBUG(dbgs() << "skip: block " << Slot
                        << " has two or more predecessors\n");
      return {};
    }
    if (Phi.getBasicBlockIndex(Pred) < 0) {
      LLVM_DEBUG(dbgs() << "skip: block " << Slot
                        << " does not link back to the phi\n");
      return {};
    }
    Cursor = Pred;
    --Slot;
  }
  // Whatever block the walk stopped on is the head of the chain.
  Chain[0] = Cursor;
  return Chain;
}
// Tries to recognize and simplify a comparison chain ending in `Phi`.
//
// The phi must have at least two incoming values, exactly one of which is not
// a ConstantInt. That value must be an icmp defined in its own incoming block
// (the "last block"), and that block's single successor must be the phi's
// block. Returns true if the IR was changed.
bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
                DomTreeUpdater &DTU) {
  LLVM_DEBUG(dbgs() << "processPhi()\n");
  const unsigned NumIncoming = Phi.getNumIncomingValues();
  if (NumIncoming <= 1) {
    LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
    return false;
  }

  // Locate the one incoming block that provides a non-constant value.
  BasicBlock *LastBlock = nullptr;
  for (unsigned I = 0; I != NumIncoming; ++I) {
    Value *const Incoming = Phi.getIncomingValue(I);
    if (isa<ConstantInt>(Incoming))
      continue;
    if (LastBlock) {
      LLVM_DEBUG(dbgs() << "skip: several non-constant values\n");
      return false;
    }
    BasicBlock *const IncomingBlock = Phi.getIncomingBlock(I);
    const auto *const Cmp = dyn_cast<ICmpInst>(Incoming);
    if (!Cmp || Cmp->getParent() != IncomingBlock) {
      LLVM_DEBUG(
          dbgs()
          << "skip: non-constant value not from cmp or not from last block.\n");
      return false;
    }
    LastBlock = IncomingBlock;
  }
  if (!LastBlock) {
    LLVM_DEBUG(dbgs() << "skip: no non-constant block\n");
    return false;
  }
  if (LastBlock->getSingleSuccessor() != Phi.getParent()) {
    LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n");
    return false;
  }

  const auto Blocks = getOrderedBlocks(Phi, LastBlock, NumIncoming);
  if (Blocks.empty())
    return false;

  BCECmpChain CmpChain(Blocks, Phi, AA);
  if (!CmpChain.atLeastOneMerged()) {
    LLVM_DEBUG(dbgs() << "skip: nothing merged\n");
    return false;
  }
  return CmpChain.simplify(TLI, AA, DTU);
}
// Shared implementation for both pass managers. Runs processPhi() on the
// first instruction of every block except the function's first one, when that
// instruction is a phi. Returns true if anything changed.
//
// Nothing is done when the target does not enable memcmp expansion or when
// memcmp is not available according to `TLI`. `DT` may be null.
static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
                    const TargetTransformInfo &TTI, AliasAnalysis &AA,
                    DominatorTree *DT) {
  LLVM_DEBUG(dbgs() << "MergeICmpsLegacyPass: " << F.getName() << "\n");
  const bool CanUseMemCmp =
      TTI.enableMemCmpExpansion(F.hasOptSize(), true) &&
      TLI.has(LibFunc_memcmp);
  if (!CanUseMemCmp)
    return false;

  DomTreeUpdater DTU(DT, nullptr,
                     DomTreeUpdater::UpdateStrategy::Eager);
  bool Changed = false;
  for (BasicBlock &BB : llvm::drop_begin(F)) {
    auto *const Phi = dyn_cast<PHINode>(&BB.front());
    if (!Phi)
      continue;
    if (processPhi(*Phi, TLI, AA, DTU))
      Changed = true;
  }
  return Changed;
}
class MergeICmpsLegacyPass : public FunctionPass {
public:
static char ID;
MergeICmpsLegacyPass() : FunctionPass(ID) {
initializeMergeICmpsLegacyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override {
if (skipFunction(F)) return false;
const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
return runImpl(F, TLI, TTI, AA, DTWP ? &DTWP->getDomTree() : nullptr);
}
private:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
}
};
}
// Definition of the legacy pass's static ID member.
char MergeICmpsLegacyPass::ID = 0;
// Registers the legacy pass under the name "mergeicmps" and lists the
// analyses it depends on (the same three that getAnalysisUsage() requires).
INITIALIZE_PASS_BEGIN(MergeICmpsLegacyPass, "mergeicmps",
"Merge contiguous icmps into a memcmp", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(MergeICmpsLegacyPass, "mergeicmps",
"Merge contiguous icmps into a memcmp", false, false)
// Factory for the legacy pass. Returns a newly allocated MergeICmpsLegacyPass.
Pass *llvm::createMergeICmpsLegacyPass() {
  return new MergeICmpsLegacyPass();
}
// New pass manager entry point. Fetches the required analyses, uses the
// dominator tree only if a cached one exists, and delegates to runImpl().
// Reports all analyses preserved when nothing changed; otherwise only the
// dominator tree analysis is marked preserved.
PreservedAnalyses MergeICmpsPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  auto &LibInfo = AM.getResult<TargetLibraryAnalysis>(F);
  auto &TargetInfo = AM.getResult<TargetIRAnalysis>(F);
  auto &AliasInfo = AM.getResult<AAManager>(F);
  auto *DomTree = AM.getCachedResult<DominatorTreeAnalysis>(F);

  if (!runImpl(F, LibInfo, TargetInfo, AliasInfo, DomTree))
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved;
  Preserved.preserve<DominatorTreeAnalysis>();
  return Preserved;
}