#include "llvm/Transforms/Scalar/BDCE.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
// Category name used by this file's LLVM_DEBUG output and STATISTIC counters.
#define DEBUG_TYPE "bdce"

// Counters incremented by bitTrackingDCE:
//  - NumRemoved:    instructions erased from the function.
//  - NumSimplified: dead integer operand uses replaced with zero.
//  - NumSExt2ZExt:  sext instructions rewritten as zext.
STATISTIC(NumRemoved, "Number of instructions removed (unused)");
STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)");
STATISTIC(NumSExt2ZExt,
          "Number of sign extension instructions converted to zero extension");
/// Drop poison-generating flags on the transitive users of \p I that may have
/// relied on bits of \p I which are about to change.
///
/// Called when \p I (or one of its operands) is being rewritten by BDCE.
/// Flags such as nsw/nuw/exact on downstream instructions may have been
/// justified by bit values the rewrite no longer guarantees, so they are
/// removed with dropPoisonGeneratingFlags().
///
/// The walk follows instruction users only. It stops at any user that is not
/// integer-typed, or whose demanded bits are all ones. Presumably a user that
/// demands every bit of its result shields everything below it from the
/// change.
///
/// \param I   integer (or integer-vector) instruction being trivialized.
/// \param DB  demanded-bits results for the enclosing function.
static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) {
  assert(I->getType()->isIntOrIntVectorTy() &&
         "Trivializing a non-integer value?");

  // The type test must come first: demanded bits are only queried for
  // integer-typed values.
  auto MayObserveDeadBits = [&DB](Instruction *Inst) {
    return Inst->getType()->isIntOrIntVectorTy() &&
           !DB.getDemandedBits(Inst).isAllOnes();
  };

  SmallPtrSet<Instruction *, 16> Seen;
  SmallVector<Instruction *, 16> Pending;

  // Seed the DFS with the direct instruction users that qualify.
  for (User *DirectUser : I->users()) {
    auto *UserInst = dyn_cast<Instruction>(DirectUser);
    if (!UserInst || !MayObserveDeadBits(UserInst))
      continue;
    Seen.insert(UserInst);
    Pending.push_back(UserInst);
  }

  // Depth-first walk over further users. Seen prevents revisiting nodes,
  // which matters for cycles through PHIs.
  while (!Pending.empty()) {
    Instruction *Current = Pending.pop_back_val();
    Current->dropPoisonGeneratingFlags();

    for (User *NextUser : Current->users()) {
      auto *NextInst = dyn_cast<Instruction>(NextUser);
      if (!NextInst || !Seen.insert(NextInst).second)
        continue;
      if (MayObserveDeadBits(NextInst))
        Pending.push_back(NextInst);
    }
  }
}
/// Run bit-tracking dead code elimination over \p F using the demanded-bits
/// results in \p DB.
///
/// For each instruction, at most one of the following applies, checked in
/// this order:
///  1. The instruction is queued for deletion when either:
///     - DemandedBits reports it dead, or
///     - it is integer-typed, has no demanded bits, and
///       wouldInstructionBeTriviallyDead() accepts it.
///  2. A sext whose extension bits are all undemanded is replaced by a zext of
///     the same operand, and the sext is queued for deletion.
///  3. Otherwise, each integer operand use is checked. The operand must be an
///     Instruction or Argument, and DemandedBits must report the use dead.
///     Each such use is replaced with a zero constant.
///
/// Queued instructions are erased only after the scan completes.
///
/// \returns true if the IR was modified.
static bool bitTrackingDCE(Function &F, DemandedBits &DB) {
  // Instructions to delete. Erasing them while walking instructions(F) would
  // invalidate the iteration, so deletion is deferred until after the scan.
  SmallVector<Instruction*, 128> Worklist;
  bool Changed = false;
  for (Instruction &I : instructions(F)) {
    // Side-effecting instructions with no uses can neither be removed here nor
    // have a result whose bits matter; skip them up front.
    if (I.mayHaveSideEffects() && I.use_empty())
      continue;

    // Queue instructions that DemandedBits considers dead, and integer
    // instructions with no demanded bits that are otherwise removable. The
    // type test precedes getDemandedBits, which is only queried for integers.
    if (DB.isInstructionDead(&I) ||
        (I.getType()->isIntOrIntVectorTy() && DB.getDemandedBits(&I).isZero() &&
         wouldInstructionBeTriviallyDead(&I))) {
      Worklist.push_back(&I);
      Changed = true;
      continue;
    }

    // sext -> zext: the two differ only in the top (DestBitSize - SrcBitSize)
    // bits of each element. If none of those bits are demanded, a zext is
    // equivalent for every user.
    if (SExtInst *SE = dyn_cast<SExtInst>(&I)) {
      APInt Demanded = DB.getDemandedBits(SE);
      const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits();
      auto *const DstTy = SE->getDestTy();
      const uint32_t DestBitSize = DstTy->getScalarSizeInBits();
      if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) {
        // Users' poison-generating flags may have relied on the old
        // (sign-extended) high bits, so drop them before rewriting.
        clearAssumptionsOfUsers(SE, DB);
        // Insert the zext right before the sext, reusing its name, and
        // redirect all uses. The now-unused sext is deleted with the rest of
        // the worklist.
        IRBuilder<> Builder(SE);
        I.replaceAllUsesWith(
            Builder.CreateZExt(SE->getOperand(0), DstTy, SE->getName()));
        Worklist.push_back(SE);
        Changed = true;
        NumSExt2ZExt++;
        continue;
      }
    }

    // Replace operand uses that contribute no demanded bits with zero.
    for (Use &U : I.operands()) {
      // DemandedBits only tracks integer uses.
      if (!U->getType()->isIntOrIntVectorTy())
        continue;

      // Only instruction and argument operands are rewritten; constants and
      // other values are left alone.
      if (!isa<Instruction>(U) && !isa<Argument>(U))
        continue;

      if (!DB.isUseDead(&U))
        continue;

      LLVM_DEBUG(dbgs() << "BDCE: Trivializing: " << U << " (all bits dead)\n");

      // I's result may change once this operand becomes zero. Users of I
      // whose flags depended on the old value must have those flags dropped.
      // This may run more than once for the same I if several operands are
      // dead.
      clearAssumptionsOfUsers(&I, DB);
      U.set(ConstantInt::get(U->getType(), 0));
      ++NumSimplified;
      Changed = true;
    }
  }

  // Phase 1 walks the queue back-to-front. For each queued instruction it
  // salvages debug info and then drops all operand references. After this,
  // no queued instruction still references another through its operands.
  for (Instruction *&I : llvm::reverse(Worklist)) {
    salvageDebugInfo(*I);
    I->dropAllReferences();
  }

  // Phase 2 erases every queued instruction.
  for (Instruction *&I : Worklist) {
    ++NumRemoved;
    I->eraseFromParent();
  }

  return Changed;
}
/// New pass-manager entry point: run BDCE on \p F using the
/// DemandedBitsAnalysis result obtained from \p AM.
///
/// \returns PreservedAnalyses::all() if nothing changed. Otherwise returns a
///          set preserving only CFG analyses, because the pass edits
///          instructions but declares the CFG preserved.
PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) {
  if (!bitTrackingDCE(F, AM.getResult<DemandedBitsAnalysis>(F)))
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved;
  Preserved.preserveSet<CFGAnalyses>();
  return Preserved;
}
namespace {
/// Legacy pass-manager wrapper that runs bitTrackingDCE on each function.
struct BDCELegacyPass : public FunctionPass {
  static char ID; // Pass identification.

  BDCELegacyPass() : FunctionPass(ID) {
    initializeBDCELegacyPassPass(*PassRegistry::getPassRegistry());
  }

  /// Run BDCE on \p F unless skipFunction() says to skip it. The demanded-bits
  /// analysis is only requested when the function is actually processed.
  /// \returns true if \p F was modified.
  bool runOnFunction(Function &F) override {
    return !skipFunction(F) &&
           bitTrackingDCE(
               F, getAnalysis<DemandedBitsWrapperPass>().getDemandedBits());
  }

  /// Declares the CFG preserved, requires demanded-bits results, and keeps
  /// GlobalsAA valid.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<DemandedBitsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
  }
};
} // end anonymous namespace
// Pass identifier storage; the legacy pass manager identifies the pass by the
// address of this object rather than its value.
char BDCELegacyPass::ID = 0;

// Register the legacy pass under the name "bdce" with a dependency on
// DemandedBitsWrapperPass. The trailing `false, false` arguments are the
// INITIALIZE_PASS macros' cfg-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(BDCELegacyPass, "bdce",
                      "Bit-Tracking Dead Code Elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
INITIALIZE_PASS_END(BDCELegacyPass, "bdce",
                    "Bit-Tracking Dead Code Elimination", false, false)

// Factory for the legacy pass. Returns a heap-allocated pass; presumably
// ownership passes to the legacy pass manager it is added to (verify with
// callers).
FunctionPass *llvm::createBitTrackingDCEPass() { return new BDCELegacyPass(); }