#include "llvm/Analysis/DemandedBits.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdint>
// Pull in the LLVM namespace and the IR pattern matchers (match / m_APInt are
// used by determineLiveOperandBits in this file).
using namespace llvm;
using namespace llvm::PatternMatch;
// Debug category name for this file's LLVM_DEBUG output.
#define DEBUG_TYPE "demanded-bits"
// Pass identification member of the legacy wrapper pass; it is handed to the
// FunctionPass base class by the wrapper's constructor.
char DemandedBitsWrapperPass::ID = 0;
// Legacy pass registration under the name "demanded-bits". The two dependency
// entries mirror the analyses requested in getAnalysisUsage().
INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
"Demanded bits analysis", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
"Demanded bits analysis", false, false)
/// Construct the legacy wrapper pass and run its initialization routine
/// against the global pass registry.
DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeDemandedBitsWrapperPassPass(Registry);
}
/// Declare which analyses this pass consumes and what it preserves.
void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // Inputs fetched by runOnFunction().
  AU.addRequired<AssumptionCacheTracker>();
  AU.addRequired<DominatorTreeWrapperPass>();

  // runOnFunction() only builds a result object and always reports "no
  // change", so everything is declared preserved.
  AU.setPreservesCFG();
  AU.setPreservesAll();
}
/// Print the stored demanded-bits result to \p OS. The module argument \p M
/// is not used.
void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
  auto &Result = *DB;
  Result.print(OS);
}
/// Return true if \p I must be treated as live no matter whether its result
/// is used: terminators, debug-info intrinsics, EH pads and anything that may
/// have side effects.
static bool isAlwaysLive(Instruction *I) {
  if (I->isTerminator())
    return true;
  if (isa<DbgInfoIntrinsic>(I))
    return true;
  if (I->isEHPad())
    return true;
  return I->mayHaveSideEffects();
}
/// Compute which bits of one operand of \p UserI are needed to produce the
/// live bits \p AOut of UserI's result.
///
/// \param UserI      The instruction using the operand.
/// \param Val        The operand value (used for known-bits queries).
/// \param OperandNo  Index of the operand within \p UserI.
/// \param AOut       Live bits of UserI's result.
/// \param AB         In/out: live bits of the operand. Its width defines the
///                   operand bit width. It is only overwritten for the
///                   opcodes/operands handled below; otherwise the caller's
///                   initial value is kept.
/// \param Known      Scratch known-bits storage, filled lazily (first value).
/// \param Known2     Scratch known-bits storage, filled lazily (second value).
/// \param KnownBitsComputed  Set once Known/Known2 have been filled so that
///                   later calls sharing the same storage do not recompute.
void DemandedBits::determineLiveOperandBits(
const Instruction *UserI, const Value *Val, unsigned OperandNo,
const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
bool &KnownBitsComputed) {
unsigned BitWidth = AB.getBitWidth();
// Lazily compute known bits of V1 into Known and, if given, of V2 into
// Known2, in the context of UserI. Does nothing if KnownBitsComputed is
// already set.
auto ComputeKnownBits =
[&](unsigned BitWidth, const Value *V1, const Value *V2) {
if (KnownBitsComputed)
return;
KnownBitsComputed = true;
const DataLayout &DL = UserI->getModule()->getDataLayout();
Known = KnownBits(BitWidth);
computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);
if (V2) {
Known2 = KnownBits(BitWidth);
computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
}
};
switch (UserI->getOpcode()) {
// Unhandled opcodes leave AB as provided by the caller.
default: break;
case Instruction::Call:
case Instruction::Invoke:
// Only recognized intrinsics refine AB; other calls leave it untouched.
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI)) {
switch (II->getIntrinsicID()) {
default: break;
case Intrinsic::bswap:
// The live input bits are the byte-swapped live output bits.
AB = AOut.byteSwap();
break;
case Intrinsic::bitreverse:
// The live input bits are the bit-reversed live output bits.
AB = AOut.reverseBits();
break;
case Intrinsic::ctlz:
if (OperandNo == 0) {
// Only the high bits down to (and including) the position just past the
// maximum possible number of leading zeros can influence the count.
ComputeKnownBits(BitWidth, Val, nullptr);
AB = APInt::getHighBitsSet(BitWidth,
std::min(BitWidth, Known.countMaxLeadingZeros()+1));
}
break;
case Intrinsic::cttz:
if (OperandNo == 0) {
// Mirror image of ctlz: only the low bits up to the maximum possible
// number of trailing zeros, plus one, can influence the count.
ComputeKnownBits(BitWidth, Val, nullptr);
AB = APInt::getLowBitsSet(BitWidth,
std::min(BitWidth, Known.countMaxTrailingZeros()+1));
}
break;
case Intrinsic::fshl:
case Intrinsic::fshr: {
const APInt *SA;
if (OperandNo == 2) {
// Shift-amount operand: for power-of-two widths only the low bits
// selected by the mask BitWidth - 1 are kept.
if (isPowerOf2_32(BitWidth))
AB = BitWidth - 1;
} else if (match(II->getOperand(2), m_APInt(SA))) {
// Constant shift amount: reduce it modulo the bit width and express fshr
// as the equivalent left funnel shift.
uint64_t ShiftAmt = SA->urem(BitWidth);
if (II->getIntrinsicID() == Intrinsic::fshr)
ShiftAmt = BitWidth - ShiftAmt;
// Operand 0 is shifted left by ShiftAmt, operand 1 is shifted right by
// BitWidth - ShiftAmt; undo those shifts on the live output bits.
if (OperandNo == 0)
AB = AOut.lshr(ShiftAmt);
else if (OperandNo == 1)
AB = AOut.shl(BitWidth - ShiftAmt);
}
break;
}
case Intrinsic::umax:
case Intrinsic::umin:
case Intrinsic::smax:
case Intrinsic::smin:
// Everything from the lowest live output bit upwards is kept; low bits
// below the lowest live output bit are dropped for the operands too.
AB = APInt::getBitsSetFrom(BitWidth, AOut.countTrailingZeros());
break;
}
}
break;
case Instruction::Add:
if (AOut.isMask()) {
// A low-bit mask of live output bits needs exactly the same input bits;
// no known-bits query is required.
AB = AOut;
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
}
break;
case Instruction::Sub:
if (AOut.isMask()) {
// Same shortcut as for Add.
AB = AOut;
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
}
break;
case Instruction::Mul:
// Keep all input bits up to the highest live output bit.
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
break;
case Instruction::Shl:
// Only the shifted value (operand 0) with a constant shift amount is
// refined.
if (OperandNo == 0) {
const APInt *ShiftAmtC;
if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
// Clamp the shift amount to BitWidth - 1.
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.lshr(ShiftAmt);
// With nsw/nuw the shifted-out high bits are additionally kept live
// (ShiftAmt + 1 bits for nsw, ShiftAmt bits for nuw).
const ShlOperator *S = cast<ShlOperator>(UserI);
if (S->hasNoSignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
else if (S->hasNoUnsignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
}
}
break;
case Instruction::LShr:
if (OperandNo == 0) {
const APInt *ShiftAmtC;
if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.shl(ShiftAmt);
// For an exact shift the shifted-out low bits are additionally kept live.
if (cast<LShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
}
}
break;
case Instruction::AShr:
if (OperandNo == 0) {
const APInt *ShiftAmtC;
if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.shl(ShiftAmt);
// If any of the top ShiftAmt output bits is live, the input sign bit is
// live as well.
if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
.getBoolValue())
AB.setSignBit();
// For an exact shift the shifted-out low bits are additionally kept live.
if (cast<AShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
}
}
break;
case Instruction::And:
AB = AOut;
// Known = operand 0, Known2 = operand 1. Operand 0's bits are dropped
// wherever operand 1 is known zero. Operand 1's bits are dropped only
// where operand 0 is known zero and operand 1 is not, so that a bit which
// is known zero on both sides is not dropped for both operands.
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
if (OperandNo == 0)
AB &= ~Known2.Zero;
else
AB &= ~(Known.Zero & ~Known2.Zero);
break;
case Instruction::Or:
AB = AOut;
// Same scheme as And, using known-one bits instead of known-zero bits.
ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
if (OperandNo == 0)
AB &= ~Known2.One;
else
AB &= ~(Known.One & ~Known2.One);
break;
case Instruction::Xor:
case Instruction::PHI:
// Live bits map one-to-one from result to operands.
AB = AOut;
break;
case Instruction::Trunc:
// Widen the live output mask to the operand width; the added high bits
// are zero.
AB = AOut.zext(BitWidth);
break;
case Instruction::ZExt:
// Narrow the live output mask to the operand width.
AB = AOut.trunc(BitWidth);
break;
case Instruction::SExt:
AB = AOut.trunc(BitWidth);
// If any of the extended high output bits is live, the input sign bit is
// live as well.
if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
AOut.getBitWidth() - BitWidth))
.getBoolValue())
AB.setSignBit();
break;
case Instruction::Select:
// Operand 0 (the condition) is not refined; the other operands inherit
// the live bits of the result.
if (OperandNo != 0)
AB = AOut;
break;
case Instruction::ExtractElement:
// Only operand 0 inherits the result's live bits.
if (OperandNo == 0)
AB = AOut;
break;
case Instruction::InsertElement:
case Instruction::ShuffleVector:
// Operands 0 and 1 inherit the result's live bits; further operands are
// not refined.
if (OperandNo == 0 || OperandNo == 1)
AB = AOut;
break;
}
}
/// Build the DemandedBits result object for \p F from the required analyses.
/// Always returns false, i.e. reports no change.
bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
  auto &Assumptions =
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  auto &DomTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  DB.emplace(F, Assumptions, DomTree);
  return false;
}
/// Drop the stored DemandedBits result (the object emplaced by
/// runOnFunction()).
void DemandedBitsWrapperPass::releaseMemory() {
DB.reset();
}
/// Run the demanded-bits computation for the function once.
///
/// Fills AliveBits (live bits per integer-typed instruction), Visited
/// (non-integer instructions reached by the propagation) and DeadUses
/// (integer uses for which no bit is demanded). Subsequent calls return
/// immediately.
void DemandedBits::performAnalysis() {
// The analysis has already been run.
if (Analyzed)
return;
Analyzed = true;
Visited.clear();
AliveBits.clear();
DeadUses.clear();
SmallSetVector<Instruction*, 16> Worklist;
// Seed the worklist from the "root" instructions, i.e. those that are
// always live.
for (Instruction &I : instructions(F)) {
if (!isAlwaysLive(&I))
continue;
LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
Type *T = I.getType();
// Integer-typed root: start with an empty set of live bits and queue the
// instruction itself.
if (T->isIntOrIntVectorTy()) {
if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
Worklist.insert(&I);
continue;
}
// Non-integer root: queue its instruction operands. Integer operands get
// all bits marked live, other operands are recorded in Visited.
for (Use &OI : I.operands()) {
if (Instruction *J = dyn_cast<Instruction>(OI)) {
Type *T = J->getType();
if (T->isIntOrIntVectorTy())
AliveBits[J] = APInt::getAllOnes(T->getScalarSizeInBits());
else
Visited.insert(J);
Worklist.insert(J);
}
}
// Note that the root I itself is not added to Visited here;
// isInstructionDead() checks isAlwaysLive() separately.
}
// Propagate liveness backwards from users to their operands until nothing
// changes.
while (!Worklist.empty()) {
Instruction *UserI = Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
APInt AOut;
bool InputIsKnownDead = false;
if (UserI->getType()->isIntOrIntVectorTy()) {
AOut = AliveBits[UserI];
LLVM_DEBUG(dbgs() << " Alive Out: 0x"
<< Twine::utohexstr(AOut.getLimitedValue()));
// If no output bit is live and the user is not always live, none of its
// integer input bits are live either.
InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
}
LLVM_DEBUG(dbgs() << "\n");
// Known-bits storage shared by all operands of this user; it is filled at
// most once by determineLiveOperandBits().
KnownBits Known, Known2;
bool KnownBitsComputed = false;
for (Use &OI : UserI->operands()) {
// Only instruction and argument operands are considered. Live bits are
// stored for instructions only, but dead uses are tracked for both.
Instruction *I = dyn_cast<Instruction>(OI);
if (!I && !isa<Argument>(OI))
continue;
Type *T = OI->getType();
if (T->isIntOrIntVectorTy()) {
unsigned BitWidth = T->getScalarSizeInBits();
// Default: every bit of the operand is live.
APInt AB = APInt::getAllOnes(BitWidth);
if (InputIsKnownDead) {
AB = APInt(BitWidth, 0);
} else {
determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
Known, Known2, KnownBitsComputed);
// Keep DeadUses in sync with the latest result for this use.
if (AB.isZero())
DeadUses.insert(&OI);
else
DeadUses.erase(&OI);
}
if (I) {
// OR the new live bits into the operand's stored set. Re-queue the
// operand if it is seen for the first time or if its set grew.
auto Res = AliveBits.try_emplace(I);
if (Res.second || (AB |= Res.first->second) != Res.first->second) {
Res.first->second = std::move(AB);
Worklist.insert(I);
}
}
} else if (I && Visited.insert(I).second) {
// Non-integer instruction operand seen for the first time.
Worklist.insert(I);
}
}
}
}
/// Return the demanded bits of instruction \p I. If the analysis recorded no
/// entry for \p I, a mask with every bit of its scalar type set is returned.
APInt DemandedBits::getDemandedBits(Instruction *I) {
  performAnalysis();

  auto Entry = AliveBits.find(I);
  if (Entry == AliveBits.end()) {
    // No entry recorded: report every bit as demanded.
    const DataLayout &DL = I->getModule()->getDataLayout();
    Type *ScalarTy = I->getType()->getScalarType();
    return APInt::getAllOnes(DL.getTypeSizeInBits(ScalarTy));
  }

  return Entry->second;
}
/// Return the bits of the value used through \p U that are demanded by the
/// using instruction.
///
/// Non-integer uses yield an all-ones mask, dead uses yield zero; otherwise
/// the mask is derived from the user's own demanded bits.
APInt DemandedBits::getDemandedBits(Use *U) {
  Type *OperandTy = U->get()->getType();
  Instruction *UserI = cast<Instruction>(U->getUser());
  const DataLayout &DL = UserI->getModule()->getDataLayout();
  const unsigned BitWidth = DL.getTypeSizeInBits(OperandTy->getScalarType());

  // Only integer (or integer vector) uses are tracked.
  if (!OperandTy->isIntOrIntVectorTy())
    return APInt::getAllOnes(BitWidth);

  if (isUseDead(U))
    return APInt(BitWidth, 0);

  performAnalysis();

  // Start from "all bits live" and let the per-opcode rules narrow it down
  // based on what the user itself needs.
  const APInt LiveOut = getDemandedBits(UserI);
  APInt LiveIn = APInt::getAllOnes(BitWidth);
  KnownBits Known, Known2;
  bool KnownBitsComputed = false;
  determineLiveOperandBits(UserI, U->get(), U->getOperandNo(), LiveOut, LiveIn,
                           Known, Known2, KnownBitsComputed);
  return LiveIn;
}
/// Return true if \p I was never reached by the analysis (it is neither in
/// Visited nor in AliveBits) and is not an always-live instruction.
bool DemandedBits::isInstructionDead(Instruction *I) {
  performAnalysis();

  if (Visited.count(I))
    return false;
  if (AliveBits.find(I) != AliveBits.end())
    return false;
  return !isAlwaysLive(I);
}
/// Return true if no bit of the value used through \p U is demanded by its
/// user.
bool DemandedBits::isUseDead(Use *U) {
  // Only integer uses are tracked; anything else is reported as not dead.
  if (!U->get()->getType()->isIntOrIntVectorTy())
    return false;

  // A use by an always-live instruction is never reported dead.
  Instruction *UserI = cast<Instruction>(U->getUser());
  if (isAlwaysLive(UserI))
    return false;

  performAnalysis();
  if (DeadUses.count(U))
    return true;

  // A use may be missing from DeadUses even though its user has no live
  // output bits at all; treat that case as dead too.
  if (!UserI->getType()->isIntOrIntVectorTy())
    return false;
  auto Entry = AliveBits.find(UserI);
  return Entry != AliveBits.end() && Entry->second.isZero();
}
/// Print the analysis result to \p OS: one line per instruction in AliveBits
/// followed by one line for each of its operands.
void DemandedBits::print(raw_ostream &OS) {
  // Emits "DemandedBits: 0x<hex> for [<operand> in ]<instruction>".
  auto EmitLine = [&OS](const Instruction *Inst, const APInt &Bits,
                        Value *Operand = nullptr) {
    OS << "DemandedBits: 0x" << Twine::utohexstr(Bits.getLimitedValue())
       << " for ";
    if (Operand) {
      Operand->printAsOperand(OS, false);
      OS << " in ";
    }
    OS << *Inst << '\n';
  };

  performAnalysis();

  for (auto &Entry : AliveBits) {
    Instruction *Inst = Entry.first;
    EmitLine(Inst, Entry.second);

    for (Use &Op : Inst->operands())
      EmitLine(Inst, getDemandedBits(&Op), Op);
  }
}
/// Shared implementation for add/sub: compute the live bits of operand
/// \p OperandNo of LHS + RHS + carry-in, given the live output bits \p AOut
/// and the known bits of both operands.
///
/// \p CarryZero / \p CarryOne state whether the carry-in is known to be zero
/// or known to be one; at most one of them may be set.
static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
                                              const APInt &AOut,
                                              const KnownBits &LHS,
                                              const KnownBits &RHS,
                                              bool CarryZero, bool CarryOne) {
  assert(!(CarryZero && CarryOne) &&
         "Carry can't be zero and one at the same time");

  // Bit positions where both operands are known zero or both are known one.
  const APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);

  // Derive the live carry bits from the live output bits. The bit-reversed
  // addition spreads each live output bit towards lower positions, stopping
  // at positions set in Bound.
  const APInt RevNotBound = ~Bound.reverseBits();
  const APInt RevOut = AOut.reverseBits();
  const APInt RevCarry = (RevOut + (RevOut | RevNotBound)) ^ RevNotBound;
  const APInt ACarry = RevCarry.reverseBits();

  // Pick the queried operand ("Self") and the opposite one ("Other").
  const bool QueryLHS = OperandNo == 0;
  const KnownBits &Self = QueryLHS ? LHS : RHS;
  const KnownBits &Other = QueryLHS ? RHS : LHS;
  const APInt NeededToMaintainCarryZero = Self.Zero | ~Other.Zero;
  const APInt NeededToMaintainCarryOne = Self.One | ~Other.One;

  // Largest and smallest possible sums given the known bits and carry-in.
  const APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
  const APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;

  const APInt NeededToMaintainCarry =
      (~PossibleSumZero | NeededToMaintainCarryZero) &
      (PossibleSumOne | NeededToMaintainCarryOne);

  // Live input bits: the live output bits plus whatever is needed to keep
  // the live carries intact.
  return AOut | (ACarry & NeededToMaintainCarry);
}
/// Live bits of operand \p OperandNo of an addition whose live output bits
/// are \p AOut, given the known bits of both operands.
APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
                                                const APInt &AOut,
                                                const KnownBits &LHS,
                                                const KnownBits &RHS) {
  // A plain addition has a carry-in that is known to be zero.
  const bool CarryZero = true;
  const bool CarryOne = false;
  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, CarryZero,
                                          CarryOne);
}
/// Live bits of operand \p OperandNo of a subtraction whose live output bits
/// are \p AOut, given the known bits of both operands.
APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
                                                const APInt &AOut,
                                                const KnownBits &LHS,
                                                const KnownBits &RHS) {
  // LHS - RHS is handled as LHS + ~RHS + 1: swap RHS's known-zero and
  // known-one masks to describe ~RHS, and pass a carry-in known to be one.
  KnownBits InvertedRHS;
  InvertedRHS.One = RHS.Zero;
  InvertedRHS.Zero = RHS.One;
  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, InvertedRHS,
                                          /*CarryZero=*/false,
                                          /*CarryOne=*/true);
}
/// Factory for the legacy wrapper pass. Returns a newly heap-allocated
/// DemandedBitsWrapperPass as a FunctionPass pointer.
FunctionPass *llvm::createDemandedBitsWrapperPass() {
return new DemandedBitsWrapperPass();
}
// Definition of the static AnalysisKey member of DemandedBitsAnalysis.
AnalysisKey DemandedBitsAnalysis::Key;
/// Build the DemandedBits result for \p F from the assumption and dominator
/// tree analysis results obtained through \p AM.
DemandedBits DemandedBitsAnalysis::run(Function &F,
                                       FunctionAnalysisManager &AM) {
  auto &Assumptions = AM.getResult<AssumptionAnalysis>(F);
  auto &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
  return DemandedBits(F, Assumptions, DomTree);
}
/// Print the demanded-bits result for \p F to this pass's output stream.
/// Reports all analyses as preserved.
PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
                                               FunctionAnalysisManager &AM) {
  auto &Result = AM.getResult<DemandedBitsAnalysis>(F);
  Result.print(OS);
  return PreservedAnalyses::all();
}