#include "InstCombineInternal.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
using namespace llvm::PatternMatch;
#define DEBUG_TYPE "instcombine"
/// Clear the bits of the integer (or splat-integer) constant feeding operand
/// \p OpNo of \p I that are not present in \p Demanded.
///
/// \returns true if the operand was replaced by `C & Demanded`. Returns false
/// if the operand is not an APInt-matchable constant, or if the constant
/// already has no bits outside \p Demanded.
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
                                   const APInt &Demanded) {
  assert(I && "No instruction?");
  assert(OpNo < I->getNumOperands() && "Operand index too large");

  Value *Operand = I->getOperand(OpNo);
  const APInt *ConstBits = nullptr;

  // Shrinking is only possible for an integer constant that sets at least
  // one bit the caller does not care about.
  const bool HasUndemandedBits =
      match(Operand, m_APInt(ConstBits)) && !ConstBits->isSubsetOf(Demanded);
  if (HasUndemandedBits)
    I->setOperand(OpNo,
                  ConstantInt::get(Operand->getType(), *ConstBits & Demanded));

  return HasUndemandedBits;
}
/// Run demanded-bits simplification on \p Inst with every bit demanded.
///
/// \returns true if something changed: either \p Inst was modified in place,
/// or a simpler replacement was found and all uses of \p Inst were redirected
/// to it. Returns false if no simplification applied.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
  const unsigned Width = Inst.getType()->getScalarSizeInBits();
  KnownBits Known(Width);
  APInt AllDemanded = APInt::getAllOnes(Width);

  Value *Replacement =
      SimplifyDemandedUseBits(&Inst, AllDemanded, Known, /*Depth=*/0, &Inst);
  if (Replacement == nullptr)
    return false;

  // Getting Inst itself back means it was already updated in place.
  if (Replacement != &Inst)
    replaceInstUsesWith(Inst, Replacement);
  return true;
}
/// Simplify operand \p OpNo of \p I, given that only the bits in
/// \p DemandedMask of that operand are used.
///
/// \p Known is passed through to SimplifyDemandedUseBits, which fills it in
/// for the operand. On success the operand use is rewritten to the simplified
/// value and true is returned; otherwise false is returned.
bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                            const APInt &DemandedMask,
                                            KnownBits &Known, unsigned Depth) {
  Use &OperandUse = I->getOperandUse(OpNo);
  Value *Simplified =
      SimplifyDemandedUseBits(OperandUse.get(), DemandedMask, Known, Depth, I);
  if (Simplified == nullptr)
    return false;

  // Salvage debug info from the operand value that is about to be replaced.
  if (auto *OldOperand = dyn_cast<Instruction>(OperandUse))
    salvageDebugInfo(*OldOperand);

  replaceUse(OperandUse, Simplified);
  return true;
}
/// Simplify \p V, assuming only the bits set in \p DemandedMask are used by
/// the user \p CxtI.
///
/// Return value:
///   - nullptr: nothing was simplified;
///   - the instruction itself: it, or one of its operands, was changed in
///     place;
///   - any other value: a replacement for \p V that agrees with it on the
///     demanded bits.
/// \p Known receives bits of \p V known to be zero or one, in the paths that
/// compute them. It is reset before an instruction is analyzed.
/// \p Depth is the recursion depth; 0 means \p V is the root instruction.
Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownBits &Known,
unsigned Depth,
Instruction *CxtI) {
assert(V != nullptr && "Null pointer of Value???");
assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
uint32_t BitWidth = DemandedMask.getBitWidth();
Type *VTy = V->getType();
assert(
(!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
Known.getBitWidth() == BitWidth &&
"Value *V, DemandedMask and Known must have same BitWidth");
// Constants are never rewritten here; just report their known bits.
if (isa<Constant>(V)) {
computeKnownBits(V, Known, Depth, CxtI);
return nullptr;
}
Known.resetAll();
// No bit of V is used, so any value is acceptable.
if (DemandedMask.isZero()) return UndefValue::get(VTy);
if (Depth == MaxAnalysisRecursionDepth)
return nullptr;
// Scalable vectors are not handled by this routine.
if (isa<ScalableVectorType>(VTy))
return nullptr;
Instruction *I = dyn_cast<Instruction>(V);
// Non-instruction values (for example, arguments) can only be analyzed.
if (!I) {
computeKnownBits(V, Known, Depth, CxtI);
return nullptr; }
// A non-root value with several users must not be rewritten on the basis
// of one user's demand. SimplifyMultipleUseDemandedBits never modifies I; it
// only returns a value that is usable for this particular use.
if (Depth != 0 && !I->hasOneUse())
return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);
KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
// The root may have other users that need every bit. Demand all bits so its
// operands can still be simplified without changing the root's value.
if (Depth == 0 && !V->hasOneUse())
DemandedMask.setAllBits();
// Shared by Add/Sub/Mul. Result bits at or below the highest demanded bit
// depend only on operand bits at or below it, so only those low bits are
// demanded from both operands (by shrinking constants or recursing). If an
// operand changed and some high bits were dropped, nsw/nuw can no longer be
// guaranteed, so they are cleared. DemandedFromOps receives the operand mask.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
unsigned NLZ = DemandedMask.countLeadingZeros();
DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
if (NLZ > 0) {
I->setHasNoSignedWrap(false);
I->setHasNoUnsignedWrap(false);
}
return true;
}
return false;
};
switch (I->getOpcode()) {
default:
computeKnownBits(I, Known, Depth, CxtI);
break;
case Instruction::And: {
// Simplify RHS first. LHS bits where RHS is known zero cannot affect the
// result, so they are not demanded from LHS.
if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown,
Depth + 1))
return I;
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = LHSKnown & RHSKnown;
// Every demanded bit is known: fold to a constant.
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
// If, on every demanded bit, LHS is zero or RHS is one, the 'and' equals
// LHS. The symmetric case yields RHS.
if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
return I->getOperand(1);
// Clear constant-RHS bits that are undemanded or already zero in LHS.
if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero))
return I;
break;
}
case Instruction::Or: {
// Simplify RHS first. LHS bits where RHS is known one cannot affect the
// result, so they are not demanded from LHS.
if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
Depth + 1))
return I;
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = LHSKnown | RHSKnown;
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
// If, on every demanded bit, LHS is one or RHS is zero, the 'or' equals
// LHS. The symmetric case yields RHS.
if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
return I->getOperand(1);
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
break;
}
case Instruction::Xor: {
if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1))
return I;
Value *LHS, *RHS;
// Only the parity bit is demanded:
// (ctpop(X) ^ ctpop(Y)) & 1 == ctpop(X ^ Y) & 1.
if (DemandedMask == 1 &&
match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) &&
match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) {
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
auto *Xor = Builder.CreateXor(LHS, RHS);
return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor);
}
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = LHSKnown ^ RHSKnown;
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(VTy, Known.One);
// If one side is zero on every demanded bit, the xor equals the other side.
if (DemandedMask.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
// If each demanded bit is zero on at least one side, xor behaves like or.
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
Instruction *Or =
BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
I->getName());
return InsertNewInstWith(Or, *I);
}
// RHS is fully known on the demanded bits, and all of its one bits are also
// one in LHS. The xor then just clears those bits, so it can become
// 'and LHS, ~RHS.One' (restricted to demanded bits).
if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) &&
RHSKnown.One.isSubsetOf(LHSKnown.One)) {
Constant *AndC = Constant::getIntegerValue(VTy,
~RHSKnown.One & DemandedMask);
Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
return InsertNewInstWith(And, *I);
}
// Leave an all-ones RHS (a 'not') alone. If setting the undemanded bits of a
// constant RHS makes it all-ones, turn the xor into a 'not'; otherwise try
// to shrink the constant.
const APInt *C;
if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
if ((*C | ~DemandedMask).isAllOnes()) {
I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
return I;
}
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
}
// Pattern (X & AndRHS) ^ XorRHS, where the one-use 'and' and the xor are
// both known one on some demanded bits. The xor only resets those bits, so
// remove them from both constants.
if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) {
ConstantInt *AndRHS, *XorRHS;
if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
match(I->getOperand(1), m_ConstantInt(XorRHS)) &&
match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) &&
(LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);
Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
InsertNewInstWith(NewAnd, *I);
Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
return InsertNewInstWith(NewXor, *I);
}
}
break;
}
case Instruction::Select: {
// Operand 1 is the true arm (LHSKnown); operand 2 is the false arm
// (RHSKnown).
if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1))
return I;
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
// For a constant arm, prefer the constant from an 'icmp X, CmpC' condition
// when the two constants agree on the demanded bits. Bail if they are
// already equal; otherwise fall back to shrinking the constant. The icmp
// path is skipped when X is itself a constant or the bit widths differ.
auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
const APInt &DemandedMask) {
const APInt *SelC;
if (!match(I->getOperand(OpNo), m_APInt(SelC)))
return false;
Value *X;
const APInt *CmpC;
ICmpInst::Predicate Pred;
if (!match(I->getOperand(0), m_ICmp(Pred, m_Value(X), m_APInt(CmpC))) ||
isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth())
return ShrinkDemandedConstant(I, OpNo, DemandedMask);
if (*CmpC == *SelC)
return false;
if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
return true;
}
return ShrinkDemandedConstant(I, OpNo, DemandedMask);
};
if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
CanonicalizeSelectConstant(I, 2, DemandedMask))
return I;
// A bit is known only if it is known identically in both arms.
Known = KnownBits::commonBits(LHSKnown, RHSKnown);
break;
}
case Instruction::Trunc: {
// trunc (lshr X, C) --> lshr (trunc X), C when C is a valid shift amount
// for the narrow type and the top C bits of the result are not demanded.
Value *X;
const APInt *C;
if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
if (C->ult(VTy->getScalarSizeInBits()) &&
C->ule(DemandedMask.countLeadingZeros())) {
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
Value *Trunc = Builder.CreateTrunc(X, VTy);
return Builder.CreateLShr(Trunc, C->getZExtValue());
}
}
}
LLVM_FALLTHROUGH;
case Instruction::ZExt: {
// Shared by trunc and zext: demand the corresponding bits of the source,
// then convert the source's known bits to the result width.
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
KnownBits InputKnown(SrcBitWidth);
if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1))
return I;
assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
Known = InputKnown.zextOrTrunc(BitWidth);
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
}
case Instruction::BitCast:
// Only int-to-int bitcasts are handled: scalar-to-scalar, or between
// vectors with the same element count. For these, result bits map directly
// onto source bits, so the demanded mask is passed through unchanged.
if (!I->getOperand(0)->getType()->isIntOrIntVectorTy())
return nullptr;
if (auto *DstVTy = dyn_cast<VectorType>(VTy)) {
if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) {
if (cast<FixedVectorType>(DstVTy)->getNumElements() !=
cast<FixedVectorType>(SrcVTy)->getNumElements())
return nullptr;
} else
return nullptr;
} else if (I->getOperand(0)->getType()->isVectorTy())
return nullptr;
if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1))
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
case Instruction::SExt: {
// Demanding any sign-extended bit demands the source sign bit.
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth);
if (DemandedMask.getActiveBits() > SrcBitWidth)
InputDemandedBits.setBit(SrcBitWidth-1);
KnownBits InputKnown(SrcBitWidth);
if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1))
return I;
// The sext can become a zext if the source is known non-negative, or if
// none of the extended bits are demanded.
if (InputKnown.isNonNegative() ||
DemandedMask.getActiveBits() <= SrcBitWidth) {
CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
return InsertNewInstWith(NewCast, *I);
}
Known = InputKnown.sext(BitWidth);
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
}
case Instruction::Add:
// If bit 0 is not demanded, bool math on i1 extensions can become logic:
//   add (zext X), (sext Y) --> sext (~X & Y)
//   add (sext X), (sext Y) --> sext (X | Y)
if ((DemandedMask & 1) == 0) {
Value *X, *Y;
if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))),
m_OneUse(m_SExt(m_Value(Y))))) &&
X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y);
return Builder.CreateSExt(AndNot, VTy);
}
if (match(I, m_Add(m_OneUse(m_SExt(m_Value(X))),
m_OneUse(m_SExt(m_Value(Y))))) &&
X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
IRBuilderBase::InsertPointGuard Guard(Builder);
Builder.SetInsertPoint(I);
Value *Or = Builder.CreateOr(X, Y);
return Builder.CreateSExt(Or, VTy);
}
}
LLVM_FALLTHROUGH;
case Instruction::Sub: {
APInt DemandedFromOps;
if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
// Adding or subtracting zeros on all relevant low bits leaves LHS. The
// symmetric case is valid for add, and for sub only when just bit 0 is
// demanded (bit 0 of a - b is a0 ^ b0).
if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) &&
DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add,
NSW, LHSKnown, RHSKnown);
break;
}
case Instruction::Mul: {
APInt DemandedFromOps;
if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
// Only bit N is demanded and the constant is (odd << N): bit N of X * C
// equals bit 0 of X, so 'shl X, N' gives the same demanded bit.
if (DemandedMask.isPowerOf2()) {
unsigned CTZ = DemandedMask.countTrailingZeros();
const APInt *C;
if (match(I->getOperand(1), m_APInt(C)) &&
C->countTrailingZeros() == CTZ) {
Constant *ShiftC = ConstantInt::get(VTy, CTZ);
Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
return InsertNewInstWith(Shl, *I);
}
}
// For X * X, bit 0 equals bit 0 of X and bit 1 is always 0. If only the low
// two bits are demanded, use X & 1.
if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
Constant *One = ConstantInt::get(VTy, 1);
Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
return InsertNewInstWith(And1, *I);
}
computeKnownBits(I, Known, Depth, CxtI);
break;
}
case Instruction::Shl: {
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
// shl (shr X, C1), C2 may collapse to a single shift on the demanded bits.
const APInt *ShrAmt;
if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt))))
if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0)))
if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA,
DemandedMask, Known))
return R;
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
Value *X;
Constant *C;
// (C >> X) << ShiftAmt --> (C << ShiftAmt) >> X, when the low ShiftAmt bits
// are not demanded and pre-shifting C loses no set bits.
if (DemandedMask.countTrailingZeros() >= ShiftAmt &&
match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) {
Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
return InsertNewInstWith(Lshr, *I);
}
}
// Translate the demand back to the operand. nsw/nuw make the shifted-out
// high bits relevant, so demand them too (nsw also needs the new sign bit).
APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));
ShlOperator *IOp = cast<ShlOperator>(I);
if (IOp->hasNoSignedWrap())
DemandedMaskIn.setHighBits(ShiftAmt+1);
else if (IOp->hasNoUnsignedWrap())
DemandedMaskIn.setHighBits(ShiftAmt);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
bool SignBitZero = Known.Zero.isSignBitSet();
bool SignBitOne = Known.One.isSignBitSet();
Known.Zero <<= ShiftAmt;
Known.One <<= ShiftAmt;
// The shifted-in low bits are zero.
if (ShiftAmt)
Known.Zero.setLowBits(ShiftAmt);
// With nsw, the result keeps the operand's sign bit. A resulting conflict
// means the shift must overflow, so undef is returned.
if (IOp->hasNoSignedWrap()) {
if (SignBitZero)
Known.Zero.setSignBit();
else if (SignBitOne)
Known.One.setSignBit();
if (Known.hasConflict())
return UndefValue::get(VTy);
}
} else {
// Variable shift amount: undemanded high result bits are still not needed
// from the operand. Dropping them may invalidate nsw/nuw.
if (unsigned CTLZ = DemandedMask.countLeadingZeros()) {
APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) {
I->dropPoisonGeneratingFlags();
return I;
}
}
computeKnownBits(I, Known, Depth, CxtI);
}
break;
}
case Instruction::LShr: {
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
// Only bits at or below the landing position of the old sign bit are
// demanded.
if (DemandedMask.countLeadingZeros() >= ShiftAmt) {
// If every operand bit from the lowest demanded position upward copies the
// sign bit, the shift does not change any demanded bit.
unsigned NumHiDemandedBits =
BitWidth - DemandedMask.countTrailingZeros();
unsigned SignBits =
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
if (SignBits >= NumHiDemandedBits)
return I->getOperand(0);
// (C << X) >> ShiftAmt --> (C >> ShiftAmt) << X, when pre-shifting C loses
// no set bits.
Value *X;
Constant *C;
if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC);
if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) {
Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
return InsertNewInstWith(Shl, *I);
}
}
}
// 'exact' makes the shifted-out low bits relevant.
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
if (cast<LShrOperator>(I)->isExact())
DemandedMaskIn.setLowBits(ShiftAmt);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
Known.Zero.lshrInPlace(ShiftAmt);
Known.One.lshrInPlace(ShiftAmt);
// The shifted-in high bits are zero.
if (ShiftAmt)
Known.Zero.setHighBits(ShiftAmt); } else {
computeKnownBits(I, Known, Depth, CxtI);
}
break;
}
case Instruction::AShr: {
// If all demanded bits already copy the sign bit, no shift is needed.
unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros();
if (SignBits >= NumHiDemandedBits)
return I->getOperand(0);
// With only bit 0 demanded, ashr and lshr agree for any in-range amount.
if (DemandedMask.isOne()) {
Instruction *NewVal = BinaryOperator::CreateLShr(
I->getOperand(0), I->getOperand(1), I->getName());
return InsertNewInstWith(NewVal, *I);
}
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
// Demanding any shifted-in high bit demands the operand sign bit. 'exact'
// makes the shifted-out low bits relevant.
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
if (DemandedMask.countLeadingZeros() <= ShiftAmt)
DemandedMaskIn.setSignBit();
if (cast<AShrOperator>(I)->isExact())
DemandedMaskIn.setLowBits(ShiftAmt);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
return I;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
// HighBits: the result bits that are copies of the operand sign bit.
APInt HighBits(APInt::getHighBitsSet(
BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth)));
Known.Zero.lshrInPlace(ShiftAmt);
Known.One.lshrInPlace(ShiftAmt);
assert(BitWidth > ShiftAmt && "Shift amount not saturated?");
// If the sign bit is known zero, or no sign-copy bit is demanded, use lshr.
// If the sign bit is known one, all sign-copy bits are known one.
if (Known.Zero[BitWidth-ShiftAmt-1] ||
!DemandedMask.intersects(HighBits)) {
BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
I->getOperand(1));
LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
return InsertNewInstWith(LShr, *I);
} else if (Known.One[BitWidth-ShiftAmt-1]) { Known.One |= HighBits;
}
} else {
computeKnownBits(I, Known, Depth, CxtI);
}
break;
}
case Instruction::UDiv: {
// For a constant divisor C = C' << k, X udiv C == (X >> k) udiv C', so the
// low k dividend bits are not demanded. An exact udiv is left alone.
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
if (cast<UDivOperator>(I)->isExact())
break;
unsigned RHSTrailingZeros = SA->countTrailingZeros();
APInt DemandedMaskIn =
APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1))
return I;
// Leading zeros of the dividend, plus k, are leading zeros of the quotient.
Known.Zero.setHighBits(std::min(
BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros));
} else {
computeKnownBits(I, Known, Depth, CxtI);
}
break;
}
case Instruction::SRem: {
const APInt *Rem;
if (match(I->getOperand(1), m_APInt(Rem))) {
// X srem -1 is skipped (no demanded-bits reasoning applied).
if (Rem->isAllOnes())
break;
APInt RA = Rem->abs();
// |divisor| is a power of two: the low log2(|divisor|) bits of X pass
// through unchanged, so only those bits and the sign bit are needed from X.
if (RA.isPowerOf2()) {
if (DemandedMask.ult(RA)) return I->getOperand(0);
APInt LowBits = RA - 1;
APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1))
return I;
Known.Zero = LHSKnown.Zero & LowBits;
Known.One = LHSKnown.One & LowBits;
// Upper bits are zero if X is non-negative or its low bits are all zero.
if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero))
Known.Zero |= ~LowBits;
// Upper bits are one if X is negative and some low bit is one.
if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One))
Known.One |= ~LowBits;
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
}
}
// A non-negative dividend gives a non-negative remainder.
if (DemandedMask.isSignBitSet()) {
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
if (LHSKnown.isNonNegative())
Known.makeNonNegative();
}
break;
}
case Instruction::URem: {
// Both operands are simplified with every bit demanded. Known2 is
// overwritten by the second call, so it ends up describing the divisor.
// The remainder is less than the divisor and so has at least as many
// leading zeros.
KnownBits Known2(BitWidth);
APInt AllOnes = APInt::getAllOnes(BitWidth);
if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) ||
SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1))
return I;
unsigned Leaders = Known2.countMinLeadingZeros();
Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
break;
}
case Instruction::Call: {
// Set when an intrinsic case filled in Known itself.
bool KnownBitsComputed = false;
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
case Intrinsic::abs: {
// Negation preserves bit 0, so abs(X) and X agree on it.
if (DemandedMask == 1)
return II->getArgOperand(0);
break;
}
case Intrinsic::ctpop: {
// For even bit widths, parity of ctpop(~X) equals parity of ctpop(X).
Value *X;
if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 &&
match(II->getArgOperand(0), m_Not(m_Value(X)))) {
Function *Ctpop = Intrinsic::getDeclaration(
II->getModule(), Intrinsic::ctpop, VTy);
return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I);
}
break;
}
case Intrinsic::bswap: {
// If the demanded bits fall within a single byte of the result, shift the
// corresponding input byte into place instead of byte-swapping.
unsigned NLZ = DemandedMask.countLeadingZeros();
unsigned NTZ = DemandedMask.countTrailingZeros();
NLZ = alignDown(NLZ, 8);
NTZ = alignDown(NTZ, 8);
if (BitWidth - NLZ - NTZ == 8) {
Instruction *NewVal;
if (NLZ > NTZ)
NewVal = BinaryOperator::CreateLShr(
II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ));
else
NewVal = BinaryOperator::CreateShl(
II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
NewVal->takeName(I);
return InsertNewInstWith(NewVal, *I);
}
break;
}
case Intrinsic::fshr:
case Intrinsic::fshl: {
// Constant funnel shift: normalize fshr to an equivalent fshl amount, then
// split the demand between the two inputs.
const APInt *SA;
if (!match(I->getOperand(2), m_APInt(SA)))
break;
uint64_t ShiftAmt = SA->urem(BitWidth);
if (II->getIntrinsicID() == Intrinsic::fshr)
ShiftAmt = BitWidth - ShiftAmt;
APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
return I;
Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
Known.One = LHSKnown.One.shl(ShiftAmt) |
RHSKnown.One.lshr(BitWidth - ShiftAmt);
KnownBitsComputed = true;
break;
}
case Intrinsic::umax: {
// All demanded bits lie above C's highest set bit, so umax(A, C) and A agree
// on them.
const APInt *C;
unsigned CTZ = DemandedMask.countTrailingZeros();
if (match(II->getArgOperand(1), m_APInt(C)) &&
CTZ >= C->getActiveBits())
return II->getArgOperand(0);
break;
}
case Intrinsic::umin: {
// Dual of umax: all demanded bits lie above C's highest clear bit.
const APInt *C;
unsigned CTZ = DemandedMask.countTrailingZeros();
if (match(II->getArgOperand(1), m_APInt(C)) &&
CTZ >= C->getBitWidth() - C->countLeadingOnes())
return II->getArgOperand(0);
break;
}
default: {
// Target-specific intrinsics are delegated to the target hook.
Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
*II, DemandedMask, Known, KnownBitsComputed);
if (V)
return V.value();
break;
}
}
}
if (!KnownBitsComputed)
computeKnownBits(V, Known, Depth, CxtI);
break;
}
}
// For cases that break out of the switch: if every demanded bit is known,
// fold to a constant.
if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
return Constant::getIntegerValue(VTy, Known.One);
return nullptr;
}
/// Demanded-bits handling for a non-root instruction \p I that has several
/// users. \p I is never modified here, because other users may need bits this
/// user does not.
///
/// Fills \p Known with bits of \p I known to be zero or one. Returns a value
/// equivalent to \p I on the bits in \p DemandedMask, for this use only: a
/// constant, one of \p I's operands, or (for ashr(shl X, C), C) the value X.
/// Returns nullptr if no such value is found.
Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth,
Instruction *CxtI) {
unsigned BitWidth = DemandedMask.getBitWidth();
Type *ITy = I->getType();
KnownBits LHSKnown(BitWidth);
KnownBits RHSKnown(BitWidth);
switch (I->getOpcode()) {
case Instruction::And: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
CxtI);
Known = LHSKnown & RHSKnown;
// Every demanded bit is known: fold to a constant.
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
// If, on every demanded bit, one side is zero or the other side is one, the
// 'and' equals that first side.
if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
return I->getOperand(1);
break;
}
case Instruction::Or: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
CxtI);
Known = LHSKnown | RHSKnown;
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
// If, on every demanded bit, one side is one or the other side is zero, the
// 'or' equals that first side.
if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
return I->getOperand(1);
break;
}
case Instruction::Xor: {
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
CxtI);
Known = LHSKnown ^ RHSKnown;
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
// Xor with a side that is zero on all demanded bits is the other side.
if (DemandedMask.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
break;
}
case Instruction::AShr: {
computeKnownBits(I, Known, Depth, CxtI);
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
// ashr (shl X, C), C only rewrites the top C bits (a sign-extension of the
// low bits). If none of those bits are demanded, X can be used directly.
// NOTE(review): ShiftLC == ShiftRC compares the APInt pointers, not their
// values. It appears to rely on LLVM constant uniquing; confirm before
// relying on this for non-uniqued constants.
const APInt *ShiftRC;
const APInt *ShiftLC;
Value *X;
unsigned BitWidth = DemandedMask.getBitWidth();
if (match(I,
m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) &&
ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) &&
DemandedMask.isSubsetOf(APInt::getLowBitsSet(
BitWidth, BitWidth - ShiftRC->getZExtValue()))) {
return X;
}
break;
}
default:
computeKnownBits(I, Known, Depth, CxtI);
if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
return Constant::getIntegerValue(ITy, Known.One);
break;
}
return nullptr;
}
/// For the pattern Shl = shl (Shr X, ShrOp1), ShlOp1, where Shr is lshr or
/// ashr, check whether the demanded bits of the pair equal those of one
/// simpler value:
///   - X itself, if the two amounts are equal;
///   - shl X, (ShlAmt - ShrAmt), if ShrAmt < ShlAmt;
///   - lshr/ashr X, (ShrAmt - ShlAmt), otherwise.
/// Zero or out-of-range shift amounts return nullptr. The last two rewrites
/// also require Shr to have a single use. nsw/nuw are copied from Shl in the
/// shl case, and 'exact' from Shr in the right-shift case.
/// \p Known is updated with low bits known zero, restricted to
/// \p DemandedMask; Known.One is cleared.
Value *InstCombinerImpl::simplifyShrShlDemandedBits(
Instruction *Shr, const APInt &ShrOp1, Instruction *Shl,
const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) {
// A zero shift amount on either side is left to other folds.
if (!ShlOp1 || !ShrOp1)
return nullptr;
Value *VarX = Shr->getOperand(0);
Type *Ty = VarX->getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
// Out-of-range shift amounts are not handled.
if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth))
return nullptr;
unsigned ShlAmt = ShlOp1.getZExtValue();
unsigned ShrAmt = ShrOp1.getZExtValue();
// NOTE(review): the shl zeroes its low ShlAmt bits, but only ShlAmt - 1 of
// them are recorded here. That is conservative, not incorrect.
Known.One.clearAllBits();
Known.Zero.setLowBits(ShlAmt - 1);
Known.Zero &= DemandedMask;
// BitMask1: the bits of the result that carry bits of X after the shr/shl
// pair. BitMask2: the same for the candidate single shift. The rewrite is
// valid if the two agree on every demanded bit.
APInt BitMask1(APInt::getAllOnes(BitWidth));
APInt BitMask2(APInt::getAllOnes(BitWidth));
bool isLshr = (Shr->getOpcode() == Instruction::LShr);
BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
(BitMask1.ashr(ShrAmt) << ShlAmt);
if (ShrAmt <= ShlAmt) {
BitMask2 <<= (ShlAmt - ShrAmt);
} else {
BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt):
BitMask2.ashr(ShrAmt - ShlAmt);
}
if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) {
if (ShrAmt == ShlAmt)
return VarX;
// Creating a new shift only pays off if the old shr goes away.
if (!Shr->hasOneUse())
return nullptr;
BinaryOperator *New;
if (ShrAmt < ShlAmt) {
Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt);
New = BinaryOperator::CreateShl(VarX, Amt);
BinaryOperator *Orig = cast<BinaryOperator>(Shl);
New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
} else {
Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt);
New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) :
BinaryOperator::CreateAShr(VarX, Amt);
if (cast<BinaryOperator>(Shr)->isExact())
New->setIsExact(true);
}
return InsertNewInstWith(New, *Shl);
}
return nullptr;
}
/// Simplify the fixed-width vector value \p V, assuming only the lanes set in
/// \p DemandedElts are used.
///
/// On return, \p UndefElts marks lanes known to be undef/poison. Returns
/// nullptr if nothing changed, the instruction itself if it was modified in
/// place, or a replacement value for \p V. If \p AllowMultipleUsers is false
/// and \p V has several users, simplification only proceeds at the root
/// (Depth 0), with every lane treated as demanded.
Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
APInt DemandedElts,
APInt &UndefElts,
unsigned Depth,
bool AllowMultipleUsers) {
// Scalable vectors have no compile-time lane count.
if (isa<ScalableVectorType>(V->getType()))
return nullptr;
unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
APInt EltMask(APInt::getAllOnes(VWidth));
assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
if (match(V, m_Undef())) {
UndefElts = EltMask;
return nullptr;
}
// No lane is used, so the whole vector may be poison.
if (DemandedElts.isZero()) { UndefElts = EltMask;
return PoisonValue::get(V->getType());
}
UndefElts = 0;
// Constant vectors: replace undemanded lanes with poison and record lanes
// that are already undef. Returns nullptr if the constant is unchanged.
if (auto *C = dyn_cast<Constant>(V)) {
if (DemandedElts.isAllOnes())
return nullptr;
Type *EltTy = cast<VectorType>(V->getType())->getElementType();
Constant *Poison = PoisonValue::get(EltTy);
SmallVector<Constant*, 16> Elts;
for (unsigned i = 0; i != VWidth; ++i) {
if (!DemandedElts[i]) { Elts.push_back(Poison);
UndefElts.setBit(i);
continue;
}
Constant *Elt = C->getAggregateElement(i);
if (!Elt) return nullptr;
Elts.push_back(Elt);
if (isa<UndefValue>(Elt)) UndefElts.setBit(i);
}
Constant *NewCV = ConstantVector::get(Elts);
return NewCV != C ? NewCV : nullptr;
}
// Recursion limit.
if (Depth == 10)
return nullptr;
if (!AllowMultipleUsers) {
if (!V->hasOneUse()) {
// Non-root multi-use values are left alone. At the root, other users may
// need every lane.
if (Depth != 0)
return nullptr;
DemandedElts = EltMask;
}
}
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return nullptr;
bool MadeChange = false;
// Recursively simplify operand OpNum of Inst (an argument operand for
// intrinsic calls), demanding the given lanes. If it is simplified, replace
// the operand and record the change. Undef receives that operand's undef
// lanes.
auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum,
APInt Demanded, APInt &Undef) {
auto *II = dyn_cast<IntrinsicInst>(Inst);
Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum);
if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) {
replaceOperand(*Inst, OpNum, V);
MadeChange = true;
}
};
APInt UndefElts2(VWidth, 0);
APInt UndefElts3(VWidth, 0);
switch (I->getOpcode()) {
default: break;
case Instruction::GetElementPtr: {
// GEPs that may index a struct type are skipped entirely.
auto mayIndexStructType = [](GetElementPtrInst &GEP) {
for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP);
I != E; I++)
if (I.isStruct())
return true;
return false;
};
if (mayIndexStructType(cast<GetElementPtrInst>(*I)))
break;
// An undef base or a poison index makes every lane undef. Vector operands
// are simplified lane-wise, but only the base's undef lanes are propagated
// to the result.
for (unsigned i = 0; i < I->getNumOperands(); i++) {
if (i == 0 ? match(I->getOperand(i), m_Undef())
: match(I->getOperand(i), m_Poison())) {
UndefElts = EltMask;
return nullptr;
}
if (I->getOperand(i)->getType()->isVectorTy()) {
APInt UndefEltsOp(VWidth, 0);
simplifyAndSetOp(I, i, DemandedElts, UndefEltsOp);
if (i == 0)
UndefElts |= UndefEltsOp;
}
}
break;
}
case Instruction::InsertElement: {
// With a variable index, any lane may be overwritten. Demand the same lanes
// from the source vector, and do not propagate undef information.
ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
if (!Idx) {
simplifyAndSetOp(I, 0, DemandedElts, UndefElts2);
break;
}
// The inserted lane is overwritten, so it is not demanded from the source.
unsigned IdxNo = Idx->getZExtValue();
APInt PreInsertDemandedElts = DemandedElts;
if (IdxNo < VWidth)
PreInsertDemandedElts.clearBit(IdxNo);
// insertelement ?, (extractelement Vec, IdxNo), IdxNo, with only lane IdxNo
// demanded, is just Vec (when the types match). This is checked before
// simplifyAndSetOp updates UndefElts.
Value *Vec;
if (PreInsertDemandedElts == 0 &&
match(I->getOperand(1),
m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) &&
Vec->getType() == I->getType()) {
return Vec;
}
simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts);
// If the inserted lane is out of range or not demanded, the insert is dead.
if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
Worklist.push(I);
return I->getOperand(0);
}
UndefElts.clearBit(IdxNo);
break;
}
case Instruction::ShuffleVector: {
auto *Shuffle = cast<ShuffleVectorInst>(I);
assert(Shuffle->getOperand(0)->getType() ==
Shuffle->getOperand(1)->getType() &&
"Expected shuffle operands to have same type");
unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType())
->getNumElements();
// Splat of LHS lane 0 with every lane demanded: RHS is unused (replace it
// with poison unless it is already undef), and only LHS lane 0 is demanded.
if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
DemandedElts.isAllOnes()) {
if (!match(I->getOperand(1), m_Undef())) {
I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType()));
MadeChange = true;
}
APInt LeftDemanded(OpWidth, 1);
APInt LHSUndefElts(OpWidth, 0);
simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
if (LHSUndefElts[0])
UndefElts = EltMask;
else
UndefElts.clearAllBits();
break;
}
// Map each demanded output lane back to the input lane it reads.
APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
for (unsigned i = 0; i < VWidth; i++) {
if (DemandedElts[i]) {
unsigned MaskVal = Shuffle->getMaskValue(i);
if (MaskVal != -1u) {
assert(MaskVal < OpWidth * 2 &&
"shufflevector mask index out of range!");
if (MaskVal < OpWidth)
LeftDemanded.setBit(MaskVal);
else
RightDemanded.setBit(MaskVal - OpWidth);
}
}
}
APInt LHSUndefElts(OpWidth, 0);
simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
APInt RHSUndefElts(OpWidth, 0);
simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts);
// Same width, and every demanded lane i reads LHS lane i: the shuffle is an
// identity on the demanded lanes. An undef mask entry (-1u) never matches i.
if (VWidth == OpWidth) {
bool IsIdentityShuffle = true;
for (unsigned i = 0; i < VWidth; i++) {
unsigned MaskVal = Shuffle->getMaskValue(i);
if (DemandedElts[i] && i != MaskVal) {
IsIdentityShuffle = false;
break;
}
}
if (IsIdentityShuffle)
return Shuffle->getOperand(0);
}
// Compute the undef lanes of the result. Lanes that are undemanded, or that
// read an undef input lane, are newly undef. For each side also track:
//   - *Idx / *ValIdx: the single output lane reading that side and the input
//     lane it reads (OpWidth means more than one);
//   - *Uniform: whether every lane reading that side reads the same position.
bool NewUndefElts = false;
unsigned LHSIdx = -1u, LHSValIdx = -1u;
unsigned RHSIdx = -1u, RHSValIdx = -1u;
bool LHSUniform = true;
bool RHSUniform = true;
for (unsigned i = 0; i < VWidth; i++) {
unsigned MaskVal = Shuffle->getMaskValue(i);
if (MaskVal == -1u) {
UndefElts.setBit(i);
} else if (!DemandedElts[i]) {
NewUndefElts = true;
UndefElts.setBit(i);
} else if (MaskVal < OpWidth) {
if (LHSUndefElts[MaskVal]) {
NewUndefElts = true;
UndefElts.setBit(i);
} else {
LHSIdx = LHSIdx == -1u ? i : OpWidth;
LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
LHSUniform = LHSUniform && (MaskVal == i);
}
} else {
if (RHSUndefElts[MaskVal - OpWidth]) {
NewUndefElts = true;
UndefElts.setBit(i);
} else {
RHSIdx = RHSIdx == -1u ? i : OpWidth;
RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
}
}
}
// Exactly one output lane reads a ConstantVector operand, and every lane
// reading the other operand is in place: rewrite as an insertelement of that
// constant element into the other operand.
if (OpWidth ==
cast<FixedVectorType>(Shuffle->getType())->getNumElements()) {
Value *Op = nullptr;
Constant *Value = nullptr;
unsigned Idx = -1u;
if (LHSIdx < OpWidth && RHSUniform) {
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
Op = Shuffle->getOperand(1);
Value = CV->getOperand(LHSValIdx);
Idx = LHSIdx;
}
}
if (RHSIdx < OpWidth && LHSUniform) {
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
Op = Shuffle->getOperand(0);
Value = CV->getOperand(RHSValIdx);
Idx = RHSIdx;
}
}
if (Op && Value) {
Instruction *New = InsertElementInst::Create(
Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx),
Shuffle->getName());
InsertNewInstWith(New, *Shuffle);
return New;
}
}
// Write the newly discovered undef lanes back into the shuffle mask.
if (NewUndefElts) {
SmallVector<int, 16> Elts;
for (unsigned i = 0; i < VWidth; ++i) {
if (UndefElts[i])
Elts.push_back(UndefMaskElem);
else
Elts.push_back(Shuffle->getMaskValue(i));
}
Shuffle->setShuffleMask(Elts);
MadeChange = true;
}
break;
}
case Instruction::Select: {
SelectInst *Sel = cast<SelectInst>(I);
// The UndefElts produced here is overwritten below by the arms' result.
if (Sel->getCondition()->getType()->isVectorTy()) {
simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
}
// With a constant condition, a false lane does not demand the true arm and
// a true lane does not demand the false arm. Constant-expression lanes are
// skipped, because isNullValue() cannot be trusted for them.
APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts);
if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) {
for (unsigned i = 0; i < VWidth; i++) {
Constant *CElt = CV->getAggregateElement(i);
if (isa<ConstantExpr>(CElt))
continue;
if (CElt->isNullValue())
DemandedLHS.clearBit(i);
else
DemandedRHS.clearBit(i);
}
}
simplifyAndSetOp(I, 1, DemandedLHS, UndefElts2);
simplifyAndSetOp(I, 2, DemandedRHS, UndefElts3);
// A lane is undef only if it is undef in both arms.
UndefElts = UndefElts2 & UndefElts3;
break;
}
case Instruction::BitCast: {
// Only vector-to-vector bitcasts whose lane counts divide one another are
// handled. Demanded and undef lanes are mapped through the width Ratio.
VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
if (!VTy) break;
unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements();
APInt InputDemandedElts(InVWidth, 0);
UndefElts2 = APInt(InVWidth, 0);
unsigned Ratio;
if (VWidth == InVWidth) {
Ratio = 1;
InputDemandedElts = DemandedElts;
} else if ((VWidth % InVWidth) == 0) {
// Each input lane covers Ratio output lanes. It is demanded if any of them
// is.
Ratio = VWidth / InVWidth;
for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
if (DemandedElts[OutIdx])
InputDemandedElts.setBit(OutIdx / Ratio);
} else if ((InVWidth % VWidth) == 0) {
// Each output lane covers Ratio input lanes. All of them are demanded if
// the output lane is.
Ratio = InVWidth / VWidth;
for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
if (DemandedElts[InIdx / Ratio])
InputDemandedElts.setBit(InIdx);
} else {
break;
}
simplifyAndSetOp(I, 0, InputDemandedElts, UndefElts2);
if (VWidth == InVWidth) {
UndefElts = UndefElts2;
} else if ((VWidth % InVWidth) == 0) {
// An output lane is undef if its covering input lane is undef.
for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
if (UndefElts2[OutIdx / Ratio])
UndefElts.setBit(OutIdx);
} else if ((InVWidth % VWidth) == 0) {
// An output lane is undef only if all Ratio input lanes it covers are undef.
for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
if (SubUndef.countPopulation() == Ratio)
UndefElts.setBit(OutIdx);
}
} else {
llvm_unreachable("Unimp");
}
break;
}
// Lane-wise FP conversions: output lane i depends only on input lane i.
case Instruction::FPTrunc:
case Instruction::FPExt:
simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
break;
case Instruction::Call: {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
if (!II) break;
switch (II->getIntrinsicID()) {
case Intrinsic::masked_gather: case Intrinsic::masked_load: {
// Pointer lanes are demanded regardless of DemandedElts, except where the
// constant mask (operand 2) is known false. Pass-through lanes (operand 3)
// are not demanded where the mask is known true. Only masked_gather's
// operand 0 is treated as a per-lane pointer vector here.
APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
DemandedPassThrough(DemandedElts);
if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
for (unsigned i = 0; i < VWidth; i++) {
Constant *CElt = CV->getAggregateElement(i);
if (CElt->isNullValue())
DemandedPtrs.clearBit(i);
else if (CElt->isAllOnesValue())
DemandedPassThrough.clearBit(i);
}
if (II->getIntrinsicID() == Intrinsic::masked_gather)
simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2);
simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3);
UndefElts = UndefElts2 & UndefElts3;
break;
}
default: {
// Target-specific intrinsics are delegated to the target hook.
Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
*II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
simplifyAndSetOp);
if (V)
return V.value();
break;
}
} break;
} }
// Lane-wise binary operators, excluding integer div/rem and shifts. A result
// lane is treated as undef only if it is undef in both operands.
BinaryOperator *BO;
if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
simplifyAndSetOp(I, 0, DemandedElts, UndefElts);
simplifyAndSetOp(I, 1, DemandedElts, UndefElts2);
UndefElts &= UndefElts2;
}
// Every lane is undef: replace with undef.
// NOTE(review): the extra ';' on the return below is a harmless empty
// statement.
if (UndefElts.isAllOnes())
return UndefValue::get(I->getType());;
return MadeChange ? I : nullptr;
}