#include "InstCombineInternal.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
using namespace PatternMatch;
// Debug type string picked up by the LLVM debug/statistic macros in this file.
#define DEBUG_TYPE "instcombine"
// Statistic counter NumSel. It is not incremented anywhere in the code shown
// here; NOTE(review): presumably updated elsewhere in this file — confirm.
STATISTIC(NumSel, "Number of select opts");
/// Stores In1 + In2 in \p Result and returns true if the addition overflowed.
/// \p IsSigned selects signed (sadd_ov) or unsigned (uadd_ov) overflow rules.
static bool addWithOverflow(APInt &Result, const APInt &In1,
                            const APInt &In2, bool IsSigned = false) {
  bool DidOverflow;
  Result = IsSigned ? In1.sadd_ov(In2, DidOverflow)
                    : In1.uadd_ov(In2, DidOverflow);
  return DidOverflow;
}
/// Stores In1 - In2 in \p Result and returns true if the subtraction
/// overflowed. \p IsSigned selects signed (ssub_ov) or unsigned (usub_ov)
/// overflow rules.
static bool subWithOverflow(APInt &Result, const APInt &In1,
                            const APInt &In2, bool IsSigned = false) {
  bool DidOverflow;
  Result = IsSigned ? In1.ssub_ov(In2, DidOverflow)
                    : In1.usub_ov(In2, DidOverflow);
  return DidOverflow;
}
/// Returns true if at least one user of \p I is a BranchInst.
static bool hasBranchUse(ICmpInst &I) {
  bool FoundBranch = false;
  for (auto *Usr : I.users()) {
    FoundBranch = isa<BranchInst>(Usr);
    if (FoundBranch)
      break;
  }
  return FoundBranch;
}
/// Returns true if "icmp Pred X, C" with a signed \p Pred is really a sign
/// test of X against zero.
/// - C == 0: any relational signed predicate qualifies.
/// - C == 1 with slt is rewritten in place to sle (X <s 1 <=> X <=s 0).
/// - C == -1 with sgt is rewritten in place to sge (X >s -1 <=> X >=s 0).
/// The predicate keeps its signedness.
static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
  if (!ICmpInst::isSigned(Pred))
    return false;

  if (C.isZero())
    return ICmpInst::isRelational(Pred);

  // Note: for 1-bit values C can be both one and all-ones; the "one" case
  // is decided first and never falls through to the all-ones case.
  if (C.isOne()) {
    if (Pred != ICmpInst::ICMP_SLT)
      return false;
    Pred = ICmpInst::ICMP_SLE;
    return true;
  }

  if (C.isAllOnes() && Pred == ICmpInst::ICMP_SGT) {
    Pred = ICmpInst::ICMP_SGE;
    return true;
  }

  return false;
}
/// Tries to fold a compare of a load from a constant global array,
///   icmp (load (gep GV, 0, Idx {, constant indices})), RHS
/// by evaluating the compare for every array element and emitting an
/// equivalent test on Idx alone.
///
/// Depending on which elements compare true, the replacement is (in order of
/// preference): a constant, one or two eq compares, one or two ne compares,
/// an unsigned range check on Idx, or a bit test on a "magic" bitvector
/// constant.
///
/// \p AndCst, when non-null, is and-ed into every element before the
/// compare. Returns nullptr when no fold applies.
Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
    ConstantInt *AndCst) {
  if (LI->isVolatile() || LI->getType() != GEP->getResultElementType() ||
      GV->getValueType() != GEP->getSourceElementType() ||
      !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return nullptr;

  Constant *Init = GV->getInitializer();
  if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
    return nullptr;

  uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
  // Don't blow up on huge arrays.
  if (ArrayElementCount > MaxArraySizeForCombine)
    return nullptr;

  // Only the simple form is handled: GEP GV, 0, i {, constant indices},
  // where i is the (non-constant) index into the array.
  if (GEP->getNumOperands() < 3 ||
      !isa<ConstantInt>(GEP->getOperand(1)) ||
      !cast<ConstantInt>(GEP->getOperand(1))->isZero() ||
      isa<Constant>(GEP->getOperand(2)))
    return nullptr;

  // The trailing indices must be constants. Record them so that each array
  // element can be narrowed to the addressed sub-element.
  SmallVector<unsigned, 4> LaterIndices;

  Type *EltTy = Init->getType()->getArrayElementType();
  for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) {
    ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i));
    if (!Idx) return nullptr; // Variable index.

    uint64_t IdxVal = Idx->getZExtValue();
    if ((unsigned)IdxVal != IdxVal) return nullptr; // Too large array index.

    if (StructType *STy = dyn_cast<StructType>(EltTy))
      EltTy = STy->getElementType(IdxVal);
    else if (ArrayType *ATy = dyn_cast<ArrayType>(EltTy)) {
      if (IdxVal >= ATy->getNumElements()) return nullptr;
      EltTy = ATy->getElementType();
    } else {
      return nullptr; // Unknown type.
    }

    LaterIndices.push_back(IdxVal);
  }

  // State machines that track the element indices for which the compare is
  // true/false. Each starts as Undefined (nothing seen yet), then holds an
  // element index. Once the pattern it tracks no longer describes the
  // elements seen so far, it becomes Overdefined and stays that way.
  enum { Overdefined = -3, Undefined = -2 };

  // First/second element for which the compare is true ("i == 47 | i == 72").
  int FirstTrueElement = Undefined, SecondTrueElement = Undefined;
  // First/second element for which the compare is false ("i != 47 & i != 72").
  int FirstFalseElement = Undefined, SecondFalseElement = Undefined;
  // Last element of the contiguous true (resp. false) run that starts at
  // FirstTrueElement (resp. FirstFalseElement).
  int TrueRangeEnd = Undefined, FalseRangeEnd = Undefined;

  // Bit i is set if the compare is true for element i. Only elements 0..63
  // are recorded.
  uint64_t MagicBitvector = 0;

  Constant *CompareRHS = cast<Constant>(ICI.getOperand(1));
  for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) {
    Constant *Elt = Init->getAggregateElement(i);
    if (!Elt) return nullptr;

    // If this is indexing an array of structures, get the addressed element.
    if (!LaterIndices.empty()) {
      Elt = ConstantFoldExtractValueInstruction(Elt, LaterIndices);
      if (!Elt)
        return nullptr;
    }

    // If the element is masked, apply the mask.
    if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst);

    // Find out whether the comparison is true or false for the i'th element.
    Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt,
                                                  CompareRHS, DL, &TLI);
    // An undef result may be treated either way. Ignore it, but extend the
    // range state machines across it in case it sits inside a range.
    if (isa<UndefValue>(C)) {
      if (TrueRangeEnd == (int)i-1)
        TrueRangeEnd = i;
      if (FalseRangeEnd == (int)i-1)
        FalseRangeEnd = i;
      continue;
    }

    // If the result for any element cannot be computed, give up entirely.
    if (!isa<ConstantInt>(C)) return nullptr;

    bool IsTrueForElt = !cast<ConstantInt>(C)->isZero();

    if (IsTrueForElt) {
      // Update the true-element state machines.
      if (FirstTrueElement == Undefined)
        FirstTrueElement = TrueRangeEnd = i; // First true element.
      else {
        // Double-compare state machine.
        if (SecondTrueElement == Undefined)
          SecondTrueElement = i;
        else
          SecondTrueElement = Overdefined;

        // Range state machine.
        if (TrueRangeEnd == (int)i-1)
          TrueRangeEnd = i;
        else
          TrueRangeEnd = Overdefined;
      }
    } else {
      // Update the false-element state machines.
      if (FirstFalseElement == Undefined)
        FirstFalseElement = FalseRangeEnd = i; // First false element.
      else {
        // Double-compare state machine.
        if (SecondFalseElement == Undefined)
          SecondFalseElement = i;
        else
          SecondFalseElement = Overdefined;

        // Range state machine.
        if (FalseRangeEnd == (int)i-1)
          FalseRangeEnd = i;
        else
          FalseRangeEnd = Overdefined;
      }
    }

    // Record the element in the magic bitvector if it fits.
    if (i < 64 && IsTrueForElt)
      MagicBitvector |= 1ULL << i;

    // Once every state machine is overdefined, nothing but the bitvector can
    // apply, and the bitvector cannot describe element 64 or later. Bail out
    // early. To limit the cost, the check runs only for indices with bit 3
    // clear.
    if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined &&
        SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined &&
        FalseRangeEnd == Overdefined)
      return nullptr;
  }

  // The whole array has been scanned; emit the replacement compare(s), from
  // the simplest form to the most complex.
  Value *Idx = GEP->getOperand(2);

  // If the index is wider than the pointer, truncate it the way the GEP does
  // implicitly. This is not needed for an inbounds GEP, whose index cannot
  // be out of range.
  if (!GEP->isInBounds()) {
    Type *IntPtrTy = DL.getIntPtrType(GEP->getType());
    unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
    if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize)
      Idx = Builder.CreateTrunc(Idx, IntPtrTy);
  }

  // Without inbounds, Idx * ElementSize can wrap, so several Idx values may
  // address the same element. For example, with ElementSize 2 both 0x00..00
  // and 0x80..00 reach offset 0. Clearing the top
  // countTrailingZeros(ElementSize) bits of Idx makes all of them compare
  // the same.
  unsigned ElementSize =
      DL.getTypeAllocSize(Init->getType()->getArrayElementType());
  auto MaskIdx = [&](Value* Idx){
    if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) {
      Value *Mask = ConstantInt::get(Idx->getType(), -1);
      Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize));
      Idx = Builder.CreateAnd(Idx, Mask);
    }
    return Idx;
  };

  // True for at most two elements: emit direct equality compares.
  if (SecondTrueElement != Overdefined) {
    Idx = MaskIdx(Idx);
    // None true -> false.
    if (FirstTrueElement == Undefined)
      return replaceInstUsesWith(ICI, Builder.getFalse());

    Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement);

    // True for one element -> 'i == 47'.
    if (SecondTrueElement == Undefined)
      return new ICmpInst(ICmpInst::ICMP_EQ, Idx, FirstTrueIdx);

    // True for two elements -> 'i == 47 | i == 72'.
    Value *C1 = Builder.CreateICmpEQ(Idx, FirstTrueIdx);
    Value *SecondTrueIdx = ConstantInt::get(Idx->getType(), SecondTrueElement);
    Value *C2 = Builder.CreateICmpEQ(Idx, SecondTrueIdx);
    return BinaryOperator::CreateOr(C1, C2);
  }

  // False for at most two elements: emit direct inequality compares.
  if (SecondFalseElement != Overdefined) {
    Idx = MaskIdx(Idx);
    // None false -> true.
    if (FirstFalseElement == Undefined)
      return replaceInstUsesWith(ICI, Builder.getTrue());

    Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement);

    // False for one element -> 'i != 47'.
    if (SecondFalseElement == Undefined)
      return new ICmpInst(ICmpInst::ICMP_NE, Idx, FirstFalseIdx);

    // False for two elements -> 'i != 47 & i != 72'.
    Value *C1 = Builder.CreateICmpNE(Idx, FirstFalseIdx);
    Value *SecondFalseIdx = ConstantInt::get(Idx->getType(),SecondFalseElement);
    Value *C2 = Builder.CreateICmpNE(Idx, SecondFalseIdx);
    return BinaryOperator::CreateAnd(C1, C2);
  }

  // The true elements form one contiguous range: emit a range check.
  if (TrueRangeEnd != Overdefined) {
    assert(TrueRangeEnd != FirstTrueElement && "Should emit single compare");
    Idx = MaskIdx(Idx);

    // Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1).
    if (FirstTrueElement) {
      Value *Offs = ConstantInt::get(Idx->getType(), -FirstTrueElement);
      Idx = Builder.CreateAdd(Idx, Offs);
    }

    Value *End = ConstantInt::get(Idx->getType(),
                                  TrueRangeEnd-FirstTrueElement+1);
    return new ICmpInst(ICmpInst::ICMP_ULT, Idx, End);
  }

  // The false elements form one contiguous range: emit a range check.
  if (FalseRangeEnd != Overdefined) {
    assert(FalseRangeEnd != FirstFalseElement && "Should emit single compare");
    Idx = MaskIdx(Idx);
    // Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse).
    if (FirstFalseElement) {
      Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement);
      Idx = Builder.CreateAdd(Idx, Offs);
    }

    Value *End = ConstantInt::get(Idx->getType(),
                                  FalseRangeEnd-FirstFalseElement);
    return new ICmpInst(ICmpInst::ICMP_UGT, Idx, End);
  }

  // If the magic bitvector captures the entire comparison state of this
  // load, replace the load with:
  //   ((magic_cst >> i) & 1) != 0
  // MagicBitvector only holds elements 0..63, so this is sound only when the
  // array has at most 64 elements. Without this check, a wide index type
  // (e.g. i128) or a wide legal integer type could be chosen for a larger
  // array, and every element past 63 would be treated as comparing false.
  if (ArrayElementCount <= 64) {
    Type *Ty = nullptr;

    // Pick the type of Idx if the magic constant fits in it; otherwise pick
    // the smallest legal integer type that fits.
    if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
      Ty = Idx->getType();
    else
      Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount);

    if (Ty) {
      Idx = MaskIdx(Idx);
      Value *V = Builder.CreateIntCast(Idx, Ty, false);
      V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
      V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V);
      return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0));
    }
  }

  return nullptr;
}
/// Builds a value that crosses zero at the same point as the byte offset of
/// \p GEP from its base pointer. foldGEPICmp uses this for inbounds GEPs, to
/// compare the offset against zero.
///
/// The GEP must have exactly one variable index. The result is
/// "VariableIdx + ConstantOffset / Scale", computed modulo the pointer width,
/// where Scale is the allocation size of the type that the variable index
/// steps over. For example, "12 + i*4" becomes "3 + i". If the constant
/// offset is zero, the index is returned, truncated to the pointer width if
/// it is wider.
///
/// Returns nullptr, so the caller computes the offset the general way, when:
/// there is no variable index, or more than one; the scale is zero; or the
/// constant offset is not a multiple of the scale.
static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC,
                                          const DataLayout &DL) {
  gep_type_iterator GTI = gep_type_begin(GEP);

  // Sum the constant indices that precede the first variable index.
  unsigned i, e = GEP->getNumOperands();
  int64_t Offset = 0;
  for (i = 1; i != e; ++i, ++GTI) {
    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
      if (CI->isZero()) continue;

      // A struct index adds its field offset; any other index adds
      // index * element size.
      if (StructType *STy = GTI.getStructTypeOrNull()) {
        Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
      } else {
        uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
        Offset += Size*CI->getSExtValue();
      }
    } else {
      // Found the variable index.
      break;
    }
  }

  // No variable index: the offset is a constant; let the caller compute it.
  if (i == e) return nullptr;

  Value *VariableIdx = GEP->getOperand(i);
  // Scale of the variable index, e.g. 4 when indexing an array of i32.
  uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType());

  // Zero-sized indexed type (e.g. [0 x i32]): the variable index never moves
  // the address. Returning the index as if it were the offset, or dividing
  // by the scale below, would be wrong, so fall back to the general offset
  // computation.
  if (VariableScale == 0)
    return nullptr;

  // All remaining indices must be constant; sum them too.
  for (++i, ++GTI; i != e; ++i, ++GTI) {
    ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i));
    if (!CI) return nullptr;

    if (CI->isZero()) continue;

    if (StructType *STy = GTI.getStructTypeOrNull()) {
      Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
    } else {
      uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
      Offset += Size*CI->getSExtValue();
    }
  }

  Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType());
  unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth();
  if (Offset == 0) {
    // Truncate to the pointer width if the index is wider. An extension is
    // not needed: it does not change where the value crosses zero.
    if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() >
        IntPtrWidth) {
      VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy);
    }
    return VariableIdx;
  }

  // The computation is done modulo the pointer width.
  Offset = SignExtend64(Offset, IntPtrWidth);
  VariableScale = SignExtend64(VariableScale, IntPtrWidth);
  // The scale may reduce to 0 modulo the pointer width; never divide by it.
  if (VariableScale == 0)
    return nullptr;

  // The constant offset must be a multiple of the scale. "12 + 4*i" becomes
  // "3 + i", but "10 + 3*i" cannot be expressed in terms of i.
  int64_t NewOffs = Offset / (int64_t)VariableScale;
  if (Offset != NewOffs*(int64_t)VariableScale)
    return nullptr;

  // Convert the index to the pointer-sized integer and add the scaled
  // constant offset.
  if (VariableIdx->getType() != IntPtrTy)
    VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy,
                                           true /*Signed*/);
  Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs);
  return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset");
}
/// Returns true if \p Start can be re-expressed as an integer offset from
/// \p Base by looking only through no-op inttoptr/ptrtoint casts, PHI nodes,
/// and inbounds single-index GEPs whose source element type is \p ElemTy.
/// Every visited value, including \p Base, is added to \p Explored.
/// rewriteGEPAsOffset later iterates Explored in this insertion order.
/// Gives up once 100 values have been explored.
static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
const DataLayout &DL,
SetVector<Value *> &Explored) {
SmallVector<Value *, 16> WorkList(1, Start);
Explored.insert(Base);
// Outer loop: one round per batch of newly found PHI incoming values.
// The inner loop is a post-order walk: a non-PHI value is marked explored
// only once its pointer operand has been visited, so operands appear in
// Explored before their users.
while (!WorkList.empty()) {
SetVector<PHINode *> PHIs;
while (!WorkList.empty()) {
// Bound the amount of work.
if (Explored.size() >= 100)
return false;
Value *V = WorkList.back();
if (Explored.contains(V)) {
WorkList.pop_back();
continue;
}
// Any other kind of value (other than Base, which is already explored)
// cannot be rewritten.
if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) &&
!isa<GetElementPtrInst>(V) && !isa<PHINode>(V))
return false;
// Only no-op casts are looked through.
if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) {
auto *CI = cast<CastInst>(V);
if (!CI->isNoopCast(DL))
return false;
if (!Explored.contains(CI->getOperand(0)))
WorkList.push_back(CI->getOperand(0));
}
// Only inbounds GEPs with a single index over ElemTy are accepted.
// Inbounds rules out overflow of the offset.
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
if (GEP->getNumIndices() != 1 || !GEP->isInBounds() ||
GEP->getSourceElementType() != ElemTy)
return false;
if (!Explored.contains(GEP->getOperand(0)))
WorkList.push_back(GEP->getOperand(0));
}
// Nothing new was pushed, so V's operand is done: finish V (post-order).
if (WorkList.back() == V) {
WorkList.pop_back();
Explored.insert(V);
}
// PHIs are marked explored immediately (this breaks cycles); their
// incoming values are queued after the inner loop drains.
if (auto *PN = dyn_cast<PHINode>(V)) {
// PHIs in blocks ending in catchswitch cannot be transformed.
if (isa<CatchSwitchInst>(PN->getParent()->getTerminator()))
return false;
Explored.insert(PN);
PHIs.insert(PN);
}
}
// Explore the incoming values of the PHIs found in this round.
for (auto *PN : PHIs)
for (Value *Op : PN->incoming_values())
if (!Explored.contains(Op))
WorkList.push_back(Op);
}
// Intended check: reject the rewrite when an explored PHI uses an explored
// instruction from the same block.
// NOTE(review): Val->uses() yields Use objects, and a Use converts to
// Value* as the *used* value (Val itself), not the user. PHI and Inst are
// therefore both derived from Val, and this loop can never return false.
// users() was probably intended — confirm before relying on this check.
for (Value *Val : Explored) {
for (Value *Use : Val->uses()) {
auto *PHI = dyn_cast<PHINode>(Use);
auto *Inst = dyn_cast<Instruction>(Val);
if (Inst == Base || Inst == PHI || !Inst || !PHI ||
!Explored.contains(PHI))
continue;
if (PHI->getParent() == Inst->getParent())
return false;
}
}
return true;
}
/// Positions \p Builder for emitting code derived from \p V:
///  - PHI node: the first insertion point of the PHI's block;
///  - other instruction: right before V, or right after it when \p Before
///    is false;
///  - function argument: the first insertion point of the entry block;
///  - constant: the insertion point is left unchanged.
static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
                              bool Before = true) {
  if (isa<PHINode>(V)) {
    BasicBlock *PhiBB = cast<PHINode>(V)->getParent();
    Builder.SetInsertPoint(&*PhiBB->getFirstInsertionPt());
  } else if (auto *Inst = dyn_cast<Instruction>(V)) {
    Instruction *Pos = Before ? Inst : &*std::next(Inst->getIterator());
    Builder.SetInsertPoint(Pos);
  } else if (auto *Arg = dyn_cast<Argument>(V)) {
    BasicBlock &EntryBB = Arg->getParent()->getEntryBlock();
    Builder.SetInsertPoint(&*EntryBB.getFirstInsertionPt());
  } else {
    // Constants need no particular insertion point.
    assert(isa<Constant>(V) && "Setting insertion point for unknown value!");
  }
}
/// Rewrites every value in \p Explored (collected by canRewriteGEPAsOffset
/// for \p Start and \p Base) as an integer offset from \p Base, counted in
/// units of \p ElemTy. Every use of each explored value other than Base is
/// then replaced with "gep inbounds ElemTy, Base, offset", cast back to the
/// value's type. The original values are not erased here. Returns the
/// offset computed for \p Start.
/// Steps: 1. create index PHIs without incoming values (avoids cyclic
/// dependencies); 2. create all other offsets; 3. fill in the PHI incoming
/// values; 4. emit the GEPs and replace the uses.
static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
const DataLayout &DL,
SetVector<Value *> &Explored) {
// Integer type as wide as the index type of Start's pointer type.
Type *IndexType = IntegerType::get(
Base->getContext(), DL.getIndexTypeSizeInBits(Start->getType()));
// Maps each explored value to its offset from Base. Base itself maps to 0.
DenseMap<Value *, Value *> NewInsts;
NewInsts[Base] = ConstantInt::getNullValue(IndexType);
// Step 1: create empty ".idx" PHIs, inserted before the original PHIs.
for (Value *Val : Explored) {
if (Val == Base)
continue;
if (auto *PHI = dyn_cast<PHINode>(Val))
NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(),
PHI->getName() + ".idx", PHI);
}
IRBuilder<> Builder(Base->getContext());
// Step 2: compute offsets for casts and GEPs. Explored is in post-order,
// so operands are mapped before their users.
for (Value *Val : Explored) {
if (NewInsts.find(Val) != NewInsts.end())
continue;
if (auto *CI = dyn_cast<CastInst>(Val)) {
// A no-op cast shares its operand's offset. Keep the temporary:
// NewInsts[CI] may grow the map and invalidate a reference obtained
// from NewInsts[...] in the same expression.
Value *V = NewInsts[CI->getOperand(0)];
NewInsts[CI] = V;
continue;
}
if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
// GEP offset = offset of the pointer operand + the GEP's index (mapped
// if the index is itself a rewritten value).
// NOTE(review): operator[] inserts a null entry when the index has no
// mapping; that null then selects the raw index operand.
Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)]
: GEP->getOperand(1);
setInsertionPoint(Builder, GEP);
// A GEP sign-extends/truncates its index implicitly; do the same here.
if (Index->getType()->getScalarSizeInBits() !=
NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) {
Index = Builder.CreateSExtOrTrunc(
Index, NewInsts[GEP->getOperand(0)]->getType(),
GEP->getOperand(0)->getName() + ".sext");
}
auto *Op = NewInsts[GEP->getOperand(0)];
// Skip the add when the pointer operand's offset is the constant 0.
if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
NewInsts[GEP] = Index;
else
NewInsts[GEP] = Builder.CreateNSWAdd(
Op, Index, GEP->getOperand(0)->getName() + ".add");
continue;
}
// PHIs were created in step 1.
if (isa<PHINode>(Val))
continue;
llvm_unreachable("Unexpected instruction type");
}
// Step 3: every offset now exists, so wire up the new PHIs' incoming
// values, using the rewritten offset where one exists.
for (Value *Val : Explored) {
if (Val == Base)
continue;
if (auto *PHI = dyn_cast<PHINode>(Val)) {
PHINode *NewPhi = static_cast<PHINode *>(NewInsts[PHI]);
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) {
Value *NewIncoming = PHI->getIncomingValue(I);
if (NewInsts.find(NewIncoming) != NewInsts.end())
NewIncoming = NewInsts[NewIncoming];
NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I));
}
}
}
// Step 4: rebuild each explored value as Base + offset, emitted after its
// definition (setInsertionPoint with Before = false), and redirect all of
// its uses to the new value.
PointerType *PtrTy =
ElemTy->getPointerTo(Start->getType()->getPointerAddressSpace());
for (Value *Val : Explored) {
if (Val == Base)
continue;
setInsertionPoint(Builder, Val, false);
Value *NewVal = Builder.CreateBitOrPointerCast(
Base, PtrTy, Start->getName() + "to.ptr");
NewVal = Builder.CreateInBoundsGEP(
ElemTy, NewVal, makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr");
NewVal = Builder.CreateBitOrPointerCast(
NewVal, Val->getType(), Val->getName() + ".conv");
Val->replaceAllUsesWith(NewVal);
}
return NewInsts[Start];
}
/// Walks backwards from \p V through inbounds GEPs that have a single
/// constant index over \p ElemTy, and through no-op inttoptr/ptrtoint casts.
/// The GEP indices are summed after sign-extending or truncating each one to
/// the index width of V's type. Returns the value where the walk stopped,
/// together with the accumulated constant index.
static std::pair<Value *, Value *>
getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) {
  unsigned IndexWidth = DL.getIndexTypeSizeInBits(V->getType());
  Type *IndexType = IntegerType::get(V->getContext(), IndexWidth);
  Constant *Index = ConstantInt::getNullValue(IndexType);

  for (;;) {
    if (auto *GEP = dyn_cast<GEPOperator>(V)) {
      // Only inbounds GEPs are accepted, which rules out overflow.
      bool Foldable = GEP->isInBounds() && GEP->hasAllConstantIndices() &&
                      GEP->getNumIndices() == 1 &&
                      GEP->getSourceElementType() == ElemTy;
      if (!Foldable)
        break;
      auto *Step = static_cast<Constant *>(GEP->getOperand(1));
      Index = ConstantExpr::getAdd(
          Index, ConstantExpr::getSExtOrTrunc(Step, IndexType));
      V = GEP->getOperand(0);
      continue;
    }

    // Otherwise, only look through no-op int <-> ptr casts.
    Value *CastSrc = nullptr;
    if (auto *I2P = dyn_cast<IntToPtrInst>(V)) {
      if (I2P->isNoopCast(DL))
        CastSrc = I2P->getOperand(0);
    } else if (auto *P2I = dyn_cast<PtrToIntInst>(V)) {
      if (P2I->isNoopCast(DL))
        CastSrc = P2I->getOperand(0);
    }
    if (!CastSrc)
      break;
    V = CastSrc;
  }
  return {V, Index};
}
/// Tries to turn "icmp Cond GEPLHS, RHS" into a compare of integer offsets
/// from a common base pointer. GEPLHS must be a scalar GEP with only
/// constant indices, and RHS must be rewritable in terms of the same base
/// (canRewriteGEPAsOffset). On success, RHS and its uses are rewritten and
/// the new signed offset compare is returned; otherwise returns nullptr.
static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
                                              ICmpInst::Predicate Cond,
                                              const DataLayout &DL) {
  // FIXME: Support vector of pointers.
  if (GEPLHS->getType()->isVectorTy() || !GEPLHS->hasAllConstantIndices())
    return nullptr;

  Type *ElemTy = GEPLHS->getSourceElementType();
  std::pair<Value *, Value *> BaseAndOffset =
      getAsConstantIndexedAddress(ElemTy, GEPLHS, DL);
  Value *PtrBase = BaseAndOffset.first;

  // The values that will take part in the rewrite.
  SetVector<Value *> Nodes;
  if (!canRewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes))
    return nullptr;

  // Both sides were reached only through inbounds GEPs, so their offsets from
  // PtrBase cannot overflow and can be compared directly.
  Value *RHSOffset = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes);
  return new ICmpInst(ICmpInst::getSignedPredicate(Cond), BaseAndOffset.second,
                      RHSOffset);
}
/// Tries to simplify "icmp Cond GEPLHS, RHS" where the LHS is a GEP.
/// Signed predicates are rejected (returns nullptr). Handled shapes:
///  - RHS is GEPLHS's base pointer (inbounds): compare the offset with 0;
///  - RHS is null in an address space where null is not a valid address
///    (inbounds, eq/ne): compare GEPLHS's base with null instead;
///  - RHS is another GEP: compare the base pointers, the single differing
///    index, or the emitted offsets, depending on how the GEPs relate;
///  - otherwise, try an indexed compare through PHIs/casts
///    (transformToIndexedCompare).
/// Offset and index compares use the signed form of Cond.
Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
ICmpInst::Predicate Cond,
Instruction &I) {
// Signed compares of GEPs are never turned into index compares. Even for an
// inbounds GEP, adding the base pointer may overflow in the signed sense and
// change the result.
if (ICmpInst::isSigned(Cond))
return nullptr;
// Strip pointer casts from RHS unless it is a GEP instruction.
// NOTE(review): presumably this keeps all-zero-index GEPs, which
// stripPointerCasts would also look through — confirm.
if (!isa<GetElementPtrInst>(RHS))
RHS = RHS->stripPointerCasts();
Value *PtrBase = GEPLHS->getOperand(0);
// ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). With inbounds the address
// computation cannot wrap. FIXME: vector-of-pointer GEPs.
if (PtrBase == RHS && GEPLHS->isInBounds() &&
!GEPLHS->getType()->isVectorTy()) {
// Prefer the compact "index + constant" form; otherwise emit the full byte
// offset.
Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL);
if (!Offset)
Offset = EmitGEPOffset(GEPLHS);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
Constant::getNullValue(Offset->getType()));
}
// Inbounds GEP compared (eq/ne) against null where null is not a valid
// address. Under inbounds rules, the only inbounds address derived from null
// is null itself, so compare the GEP's base (splatted for vector GEPs with a
// scalar base) against null instead.
if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() &&
!NullPointerIsDefined(I.getFunction(),
RHS->getType()->getPointerAddressSpace())) {
auto *Base = GEPLHS->getPointerOperand();
if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount();
Base = Builder.CreateVectorSplat(EC, Base);
}
return new ICmpInst(Cond, Base,
ConstantExpr::getPointerBitCastOrAddrSpaceCast(
cast<Constant>(RHS), Base->getType()));
} else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) {
// Different base pointers. Every path in this branch returns.
if (PtrBase != GEPRHS->getOperand(0)) {
// Do the GEPs share pointer type, source element type and every index
// operand?
bool IndicesTheSame =
GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
GEPLHS->getPointerOperand()->getType() ==
GEPRHS->getPointerOperand()->getType() &&
GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType();
if (IndicesTheSame)
for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) {
IndicesTheSame = false;
break;
}
// Same indices: compare just the base pointers, provided the compare
// result type matches (scalar vs. vector).
Type *BaseType = GEPLHS->getOperand(0)->getType();
if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType())
return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0));
// The bases differ only by pointer casts, both GEPs are inbounds and cheap
// to lower (all-constant indices or a single use): compare the emitted
// offsets. FIXME: vector of pointers.
if (GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
(GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
(GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
PtrBase->stripPointerCasts() ==
GEPRHS->getOperand(0)->stripPointerCasts() &&
!GEPLHS->getType()->isVectorTy()) {
Value *LOffset = EmitGEPOffset(GEPLHS);
Value *ROffset = EmitGEPOffset(GEPRHS);
// The offsets may have different widths, e.g. after looking through an
// addrspacecast between differently sized address spaces; truncate the
// wider one.
Type *LHSIndexTy = LOffset->getType();
Type *RHSIndexTy = ROffset->getType();
if (LHSIndexTy != RHSIndexTy) {
if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() <
RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) {
ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy);
} else
LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy);
}
Value *Cmp = Builder.CreateICmp(ICmpInst::getSignedPredicate(Cond),
LOffset, ROffset);
return replaceInstUsesWith(I, Cmp);
}
// Otherwise, try an indexed compare through PHIs/casts.
return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
}
// From here on, both GEPs have the same base pointer.
// A GEP with all-zero indices is just its base pointer: recurse on that,
// swapping the predicate when the operands swap.
// FIXME: vector of pointers.
if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices())
return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
ICmpInst::getSwappedPredicate(Cond), I);
if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices())
return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) {
// Count the differing index operands (stops counting at 2) and remember the
// first one. Differing operands whose types differ in size, or that mix
// vector and scalar in a vector GEP, count as irreconcilable.
unsigned NumDifferences = 0; unsigned DiffOperand = 0; for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) {
Type *LHSType = GEPLHS->getOperand(i)->getType();
Type *RHSType = GEPRHS->getOperand(i)->getType();
if (LHSType->getPrimitiveSizeInBits() !=
RHSType->getPrimitiveSizeInBits() ||
(GEPLHS->getType()->isVectorTy() &&
(!LHSType->isVectorTy() || !RHSType->isVectorTy()))) {
NumDifferences = 2;
break;
}
if (NumDifferences++) break;
DiffOperand = i;
}
// Identical GEPs: the result depends only on whether Cond is true for
// equal operands.
if (NumDifferences == 0) return replaceInstUsesWith(I, ConstantInt::get(I.getType(), ICmpInst::isTrueWhenEqual(Cond)));
// Exactly one differing index and both GEPs inbounds: compare that index
// (signed).
else if (NumDifferences == 1 && GEPsInBounds) {
Value *LHSV = GEPLHS->getOperand(DiffOperand);
Value *RHSV = GEPRHS->getOperand(DiffOperand);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV);
}
}
// ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2)) ---> (OFFSET1 cmp OFFSET2),
// only when each GEP is a constant expression or has a single use.
if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
(isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) {
Value *L = EmitGEPOffset(GEPLHS);
Value *R = EmitGEPOffset(GEPRHS);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R);
}
}
// Last resort: an indexed compare through PHIs/casts.
return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
}
/// Folds the equality compare \p ICI involving \p Alloca to a constant
/// ("not equal") when the alloca's address does not escape.
/// Uses are followed transitively through bitcasts, GEPs, PHIs and selects.
/// The allowed terminal users are loads, stores *to* the pointer,
/// lifetime/memcpy/memmove/memset intrinsics, and at most one icmp (expected
/// to be ICI itself). Any other user, or running out of the visit budget,
/// aborts the fold.
Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
const AllocaInst *Alloca) {
assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
// Budget on the number of uses visited. It breaks cycles (e.g. through
// PHIs) and bounds the work. The worklist never holds more entries than the
// remaining budget.
unsigned MaxIter = 32;
SmallVector<const Use *, 32> Worklist;
for (const Use &U : Alloca->uses()) {
if (Worklist.size() >= MaxIter)
return nullptr;
Worklist.push_back(&U);
}
// Number of icmp users seen so far.
unsigned NumCmps = 0;
while (!Worklist.empty()) {
assert(Worklist.size() <= MaxIter);
const Use *U = Worklist.pop_back_val();
const Value *V = U->getUser();
--MaxIter;
if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) ||
isa<SelectInst>(V)) {
// Pointer-propagating user: its own uses are queued below.
} else if (isa<LoadInst>(V)) {
// Loading from the pointer does not escape it.
continue;
} else if (const auto *SI = dyn_cast<StoreInst>(V)) {
// Storing *to* the pointer is fine; storing the pointer itself escapes it.
if (SI->getValueOperand() == U->get())
return nullptr;
continue;
} else if (isa<ICmpInst>(V)) {
// Only a single compare is allowed.
if (NumCmps++)
return nullptr; continue;
} else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) {
// These intrinsics are accepted as non-escaping uses; any other intrinsic
// aborts the fold.
switch (Intrin->getIntrinsicID()) {
case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
continue;
default:
return nullptr;
}
} else {
return nullptr;
}
// Follow the uses of the pointer-propagating user, within the budget.
for (const Use &U : V->uses()) {
if (Worklist.size() >= MaxIter)
return nullptr;
Worklist.push_back(&U);
}
}
// The alloca does not escape and is compared once: treat it as unequal to
// the other operand (eq -> false, ne -> true).
auto *Res = ConstantInt::get(ICI.getType(),
!CmpInst::isTrueWhenEqual(ICI.getPredicate()));
return replaceInstUsesWith(ICI, Res);
}
/// Folds "icmp Pred (X + C), X" with nonzero C into a compare of X against a
/// constant. Because C != 0 the two sides are never equal, so each
/// "or-equal" predicate is handled like its strict form.
Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C,
                                                  ICmpInst::Predicate Pred) {
  assert(!!C && "C should not be zero!");
  auto *Ty = X->getType();
  unsigned Width = C.getBitWidth();

  switch (Pred) {
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    // (X + C) <u X  -->  X >u (UMAX - C)
    return new ICmpInst(ICmpInst::ICMP_UGT, X,
                        ConstantInt::get(Ty, APInt::getMaxValue(Width) - C));
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    // (X + C) >u X  -->  X <u (0 - C)
    return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, -C));
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE:
    // (X + C) <s X  -->  X >s (SMAX - C)
    return new ICmpInst(
        ICmpInst::ICMP_SGT, X,
        ConstantInt::get(Ty, APInt::getSignedMaxValue(Width) - C));
  default:
    // (X + C) >s X  -->  X <s (SMAX - (C - 1))
    assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE);
    return new ICmpInst(
        ICmpInst::ICMP_SLT, X,
        ConstantInt::get(Ty, APInt::getSignedMaxValue(Width) - (C - 1)));
  }
}
/// Folds the equality compare \p I, of a right shift (ashr if operand 0 of
/// I is an AShrOperator, otherwise lshr) of the constant \p AP2 by \p A
/// against \p AP1, into a compare on the shift amount A. If no shift amount
/// can make the two sides equal, I is folded to a constant. Returns nullptr
/// for the cases that are left to InstSimplify or that this fold skips.
Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
                                                     const APInt &AP1,
                                                     const APInt &AP2) {
  assert(I.isEquality() && "Cannot fold icmp gt/lt");

  // Each fold below is phrased for "eq"; for "ne" the predicate of the
  // replacement compare is inverted.
  const bool IsNE = I.getPredicate() == I.ICMP_NE;
  auto makeCmp = [IsNE](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
    return new ICmpInst(IsNE ? CmpInst::getInversePredicate(Pred) : Pred, LHS,
                        RHS);
  };

  // Cases that InstSimplify already handles.
  if (AP2.isZero())
    return nullptr;

  const bool IsAShr = isa<AShrOperator>(I.getOperand(0));
  if (IsAShr && (AP2.isAllOnes() || AP2.isNegative() != AP1.isNegative() ||
                 AP2.sgt(AP1)))
    return nullptr;

  auto *ShAmtTy = A->getType();

  // A zero result needs A large enough to shift out AP2's highest set bit.
  if (!AP1)
    return makeCmp(I.ICMP_UGT, A, ConstantInt::get(ShAmtTy, AP2.logBase2()));

  if (AP1 == AP2)
    return makeCmp(I.ICMP_EQ, A, ConstantInt::getNullValue(ShAmtTy));

  // Candidate shift amount: the difference in leading sign-run length.
  int Shift = (IsAShr && AP1.isNegative())
                  ? AP1.countLeadingOnes() - AP2.countLeadingOnes()
                  : AP1.countLeadingZeros() - AP2.countLeadingZeros();
  if (Shift > 0) {
    if (IsAShr && AP1 == AP2.ashr(Shift)) {
      // Comparing against -1 when AP2 is not a power of two has multiple
      // solutions: every shift amount >= Shift works.
      CmpInst::Predicate AmtPred = (AP1.isAllOnes() && !AP2.isPowerOf2())
                                       ? I.ICMP_UGE
                                       : I.ICMP_EQ;
      return makeCmp(AmtPred, A, ConstantInt::get(ShAmtTy, Shift));
    }
    if (AP1 == AP2.lshr(Shift))
      return makeCmp(I.ICMP_EQ, A, ConstantInt::get(ShAmtTy, Shift));
  }

  // No shift of AP2 can equal AP1.
  // FIXME: This should always be handled by InstSimplify?
  return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsNE));
}
/// Folds the equality compare \p I of "shl AP2, A" against \p AP1 into a
/// compare on the shift amount A. If no shift amount can make the two sides
/// equal, I is folded to a constant. Returns nullptr when AP2 is zero
/// (left to InstSimplify).
Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
                                                     const APInt &AP1,
                                                     const APInt &AP2) {
  assert(I.isEquality() && "Cannot fold icmp gt/lt");

  // Each fold below is phrased for "eq"; for "ne" the predicate of the
  // replacement compare is inverted.
  const bool IsNE = I.getPredicate() == I.ICMP_NE;
  auto makeCmp = [IsNE](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
    return new ICmpInst(IsNE ? CmpInst::getInversePredicate(Pred) : Pred, LHS,
                        RHS);
  };

  // Cases that InstSimplify already handles.
  if (AP2.isZero())
    return nullptr;

  auto *ShAmtTy = A->getType();
  unsigned AP2LowZeros = AP2.countTrailingZeros();

  // The result is zero once every set bit of AP2 has been shifted out.
  if (!AP1 && AP2LowZeros != 0)
    return makeCmp(I.ICMP_UGE, A,
                   ConstantInt::get(ShAmtTy, AP2.getBitWidth() - AP2LowZeros));

  if (AP1 == AP2)
    return makeCmp(I.ICMP_EQ, A, ConstantInt::getNullValue(ShAmtTy));

  // Distance between the lowest set bits of AP1 and AP2.
  int Shift = AP1.countTrailingZeros() - AP2LowZeros;
  if (Shift > 0 && AP2.shl(Shift) == AP1)
    return makeCmp(I.ICMP_EQ, A, ConstantInt::get(ShAmtTy, Shift));

  // No shift of AP2 can equal AP1.
  // FIXME: This should always be handled by InstSimplify?
  return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsNE));
}
/// Tries to rewrite
///   I = icmp ugt (add (add A, B), CI2), CI1
/// as the overflow bit of llvm.sadd.with.overflow on A and B truncated to N
/// bits, where CI2 == 2^(N-1) and CI1 == 2^N - 1 for N in {8, 16, 32}.
/// Requirements:
///  - the add-with-constant has a single use;
///  - A and B have at most N significant bits;
///  - every other user of the inner add is a truncate to at most N bits.
/// On success, the inner add is replaced by the zero-extended intrinsic
/// result and erased. The returned extractvalue of the overflow bit is a new
/// instruction that has not been inserted yet. Returns nullptr otherwise.
static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
ConstantInt *CI2, ConstantInt *CI1,
InstCombinerImpl &IC) {
// The add-with-constant must go away, so the compare must be its only use.
Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
if (!AddWithCst->hasOneUse())
return nullptr;
// CI2 must be 2^7, 2^15 or 2^31.
if (!CI2->getValue().isPowerOf2())
return nullptr;
unsigned NewWidth = CI2->getValue().countTrailingZeros();
if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
return nullptr;
// The narrow add is one bit wider than the bias.
++NewWidth;
// CI1 must be the all-ones value of NewWidth bits, in a wider type.
if (CI1->getBitWidth() == NewWidth ||
CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
return nullptr;
// This is only a signed overflow check if A and B are effectively
// sign-extended from NewWidth bits.
if (IC.ComputeMaxSignificantBits(A, 0, &I) > NewWidth ||
IC.ComputeMaxSignificantBits(B, 0, &I) > NewWidth)
return nullptr;
// Apart from the add-with-constant, the inner add may only feed truncates
// that drop the high bits. Otherwise it cannot be replaced by the narrow
// result.
Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
for (User *U : OrigAdd->users()) {
if (U == AddWithCst)
continue;
TruncInst *TI = dyn_cast<TruncInst>(U);
if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
return nullptr;
}
// Emit the narrow sadd.with.overflow before the original add, in case the
// add has uses between itself and the compare.
Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
Function *F = Intrinsic::getDeclaration(
I.getModule(), Intrinsic::sadd_with_overflow, NewType);
InstCombiner::BuilderTy &Builder = IC.Builder;
Builder.SetInsertPoint(OrigAdd);
Value *TruncA = Builder.CreateTrunc(A, NewType, A->getName() + ".trunc");
Value *TruncB = Builder.CreateTrunc(B, NewType, B->getName() + ".trunc");
CallInst *Call = Builder.CreateCall(F, {TruncA, TruncB}, "sadd");
Value *Add = Builder.CreateExtractValue(Call, 0, "sadd.result");
Value *ZExt = Builder.CreateZExt(Add, OrigAdd->getType());
// Replace the inner add with the zero-extended narrow result and erase it.
IC.replaceInstUsesWith(*OrigAdd, ZExt);
IC.eraseInstFromFunction(*OrigAdd);
// The compare becomes the overflow bit.
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}
/// Folds "icmp eq/ne (irem X, Y), 0" into "icmp eq/ne (X & (Y + -1)), 0"
/// when the remainder has a single use and Y is known to be a power of two
/// (or zero).
Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
  ICmpInst::Predicate Pred;
  Value *X, *Y, *Zero;
  // Only equality predicates are valid for this fold.
  bool Matched =
      I.isEquality() &&
      match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
                       m_CombineAnd(m_Zero(), m_Value(Zero))));
  if (!Matched || !isKnownToBeAPowerOfTwo(Y, /*OrZero=*/true, 0, &I))
    return nullptr;

  // Y is not required to be a constant, so this may add instructions.
  Value *LowMask =
      Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
  Value *LowBits = Builder.CreateAnd(X, LowMask);
  return ICmpInst::Create(Instruction::ICmp, Pred, LowBits, Zero);
}
/// Folds an equality compare against zero of either
///  - a (possibly truncated) right shift of X by bitwidth(X) - 1, or
///  - a binary operator that reassociateShiftAmtsOfTwoSameDirectionShifts
///    reduces to a sign-bit extraction of some X,
/// into a sign test on X:  eq 0 --> X >=s 0,  ne 0 --> X <s 0.
Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
  Instruction *Val;
  ICmpInst::Predicate Pred;
  if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
    return nullptr;

  Value *X;
  Constant *ShAmt;
  if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(ShAmt))))) {
    // The shift must move exactly the sign bit down to bit 0.
    unsigned Width = X->getType()->getScalarSizeInBits();
    if (!match(ShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                                         APInt(Width, Width - 1))))
      return nullptr;
  } else {
    auto *BO = dyn_cast<BinaryOperator>(Val);
    if (!BO)
      return nullptr;
    X = reassociateShiftAmtsOfTwoSameDirectionShifts(
        BO, SQ.getWithInstruction(Val),
        /*AnalyzeForSignBitExtraction=*/true);
    if (!X)
      return nullptr;
  }

  ICmpInst::Predicate SignPred =
      Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLT;
  return ICmpInst::Create(Instruction::ICmp, SignPred, X,
                          ConstantInt::getNullValue(X->getType()));
}
/// Folds "icmp Pred X, 0" in these cases:
///  - icmp sgt (smin A, B), 0 --> icmp sgt B, 0 when A is known positive
///    (and symmetrically when B is known positive);
///  - eq/ne of an irem by a power of two (foldIRemByPowerOfTwoToBitTest);
///  - icmp eq/ne (urem X, Y), 0 --> icmp eq/ne X, 0 when X has at most one
///    possibly-set bit and Y has at least two known-set bits.
Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
  Value *Zero = Cmp.getOperand(1);
  if (!match(Zero, m_Zero()))
    return nullptr;

  CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *Op0 = Cmp.getOperand(0);

  Value *A, *B;
  if (Pred == ICmpInst::ICMP_SGT &&
      match(Op0, m_SMin(m_Value(A), m_Value(B)))) {
    if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
      return new ICmpInst(Pred, B, Zero);
    if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT))
      return new ICmpInst(Pred, A, Zero);
  }

  if (Instruction *BitTest = foldIRemByPowerOfTwoToBitTest(Cmp))
    return BitTest;

  Value *X, *Y;
  if (!ICmpInst::isEquality(Pred) ||
      !match(Op0, m_URem(m_Value(X), m_Value(Y))))
    return nullptr;

  KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
  KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
  bool RemIsIdentity =
      XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2;
  return RemIsIdentity ? new ICmpInst(Pred, X, Zero) : nullptr;
}
/// Folds compares against a constant RHS:
///  - icmp ugt (add (add A, B), CI2), CI, via processUGT_ADDCST_ADD;
///  - icmp (phi C1, C2, ...), C --> phi (icmp C1, C), (icmp C2, C), ...
///    when every phi operand is a constant.
Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
  CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *Op0 = Cmp.getOperand(0);
  Value *Op1 = Cmp.getOperand(1);

  // Cmp = icmp ugt (add (add A, B), CI2), CI
  Value *A, *B;
  ConstantInt *CI, *CI2;
  if (Pred == ICmpInst::ICMP_UGT && match(Op1, m_ConstantInt(CI)) &&
      match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) {
    if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this))
      return Res;
  }

  auto *C = dyn_cast<Constant>(Op1);
  auto *Phi = dyn_cast<PHINode>(Op0);
  if (!C || !Phi)
    return nullptr;

  bool AllInputsConstant =
      all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); });
  if (!AllInputsConstant)
    return nullptr;

  // Build a boolean phi whose incoming values are the constant-folded
  // compares.
  Builder.SetInsertPoint(Phi);
  PHINode *NewPhi = Builder.CreatePHI(Cmp.getType(), Phi->getNumOperands());
  for (BasicBlock *PredBB : predecessors(Phi->getParent())) {
    auto *Incoming = cast<Constant>(Phi->getIncomingValueForBlock(PredBB));
    NewPhi->addIncoming(ConstantExpr::getCompare(Pred, Incoming, C), PredBB);
  }
  NewPhi->takeName(&Cmp);
  return replaceInstUsesWith(Cmp, NewPhi);
}
/// Simplifies \p Cmp using the condition of the conditional branch that ends
/// the single predecessor of Cmp's block (a cheap, incomplete dominance
/// check). Cmp is either folded to a constant, via isImpliedCondition or
/// constant-range reasoning, or narrowed from "icmp Pred X, C" to an eq/ne
/// against one value, when the dominating "icmp DomPred X, DomC" leaves only
/// one candidate.
Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
BasicBlock *CmpBB = Cmp.getParent();
BasicBlock *DomBB = CmpBB->getSinglePredecessor();
if (!DomBB)
return nullptr;
Value *DomCond;
BasicBlock *TrueBB, *FalseBB;
if (!match(DomBB->getTerminator(), m_Br(m_Value(DomCond), TrueBB, FalseBB)))
return nullptr;
assert((TrueBB == CmpBB || FalseBB == CmpBB) &&
"Predecessor block does not point to successor?");
// Both edges go to the same block; the branch itself should be simplified
// first, so leave Cmp alone.
if (TrueBB == FalseBB)
return nullptr;
// Is Cmp implied true or false by the branch condition on the edge into
// CmpBB?
Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB);
if (Imp)
return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp));
CmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1);
ICmpInst::Predicate DomPred;
const APInt *C, *DomC;
// Both compares test the same X against constants: reason with ranges.
if (match(DomCond, m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))) &&
match(Y, m_APInt(C))) {
// CR: values of X for which Cmp is true.
// DominatingCR: values of X possible in CmpBB (DomPred when CmpBB is the
// true successor, its inverse otherwise).
ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C);
ConstantRange DominatingCR =
(CmpBB == TrueBB) ? ConstantRange::makeExactICmpRegion(DomPred, *DomC)
: ConstantRange::makeExactICmpRegion(
CmpInst::getInversePredicate(DomPred), *DomC);
ConstantRange Intersection = DominatingCR.intersectWith(CR);
ConstantRange Difference = DominatingCR.difference(CR);
// No reachable value satisfies Cmp -> false; all of them do -> true.
if (Intersection.isEmptySet())
return replaceInstUsesWith(Cmp, Builder.getFalse());
if (Difference.isEmptySet())
return replaceInstUsesWith(Cmp, Builder.getTrue());
// Leave equality compares alone, and leave sign-bit checks that feed a
// branch as they are: a sign-bit test and branch is preferred over a
// compare against a single value.
bool UnusedBit;
bool IsSignBit = isSignBitCheck(Pred, *C, UnusedBit);
if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp)))
return nullptr;
// Avoid an infinite loop with min/max canonicalization.
if (Cmp.hasOneUse() &&
match(Cmp.user_back(), m_MaxOrMin(m_Value(), m_Value())))
return nullptr;
// Exactly one reachable value satisfies Cmp -> X == that value.
if (const APInt *EqC = Intersection.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC));
// Exactly one reachable value fails Cmp -> X != that value.
if (const APInt *NeC = Difference.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*NeC));
}
return nullptr;
}
/// Folds "icmp Pred (trunc X), C":
///  - icmp slt (trunc (signum V)), 1 --> icmp slt V, 1;
///  - eq/ne with a single-use trunc: compare "X & lowmask" against zext(C)
///    when shouldChangeType(DstBits, SrcBits) allows it; otherwise, if every
///    truncated-away high bit of X is known, compare X against C with those
///    known bits filled in;
///  - a sign-bit check of trunc (ShOp >> ShAmt), where the truncation keeps
///    exactly ShOp's original sign bit --> a sign test on ShOp.
Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
TruncInst *Trunc,
const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Trunc->getOperand(0);
// icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
if (C.isOne() && C.getBitWidth() > 1) {
Value *V = nullptr;
if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V))))
return new ICmpInst(ICmpInst::ICMP_SLT, V,
ConstantInt::get(V->getType(), 1));
}
unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
SrcBits = X->getType()->getScalarSizeInBits();
if (Cmp.isEquality() && Trunc->hasOneUse()) {
// (trunc X to iN) == C --> (X & lowmask(N)) == zext(C), scalars only.
if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
Constant *Mask = ConstantInt::get(X->getType(),
APInt::getLowBitsSet(SrcBits, DstBits));
Value *And = Builder.CreateAnd(X, Mask);
Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits));
return new ICmpInst(Pred, And, WideC);
}
// If every high bit removed by the trunc is known, compare X directly,
// using the known-one high bits in the wide constant.
KnownBits Known = computeKnownBits(X, 0, &Cmp);
if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) {
APInt NewRHS = C.zext(SrcBits);
NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS));
}
}
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] <s 0  --> ShOp <s 0
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] >s -1 --> ShOp >s -1
Value *ShOp;
const APInt *ShAmtC;
bool TrueIfSigned;
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
DstBits == SrcBits - ShAmtC->getZExtValue()) {
return TrueIfSigned
? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
ConstantInt::getNullValue(X->getType()))
: new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
ConstantInt::getAllOnesValue(X->getType()));
}
return nullptr;
}
// Folds 'icmp Pred (xor X, XorC), C' where XorC is a constant integer (or
// splat). Returns the replacement instruction, or nullptr if nothing applies.
Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Xor,
                                                   const APInt &C) {
  Value *X = Xor->getOperand(0);
  Value *Y = Xor->getOperand(1);
  const APInt *XorC;
  if (!match(Y, m_APInt(XorC)))
    return nullptr;

  Type *XTy = X->getType();
  const ICmpInst::Predicate Pred = Cmp.getPredicate();

  // Sign-bit tests (s< 0 / s> -1) of the xor result.
  bool TrueIfSigned = false;
  if (isSignBitCheck(Pred, C, TrueIfSigned)) {
    // XorC leaves the sign bit alone: just compare X instead.
    if (!XorC->isNegative())
      return replaceOperand(Cmp, 0, X);
    // XorC flips the sign bit: test the opposite sign of X.
    if (TrueIfSigned)
      return new ICmpInst(ICmpInst::ICMP_SGT, X,
                          ConstantInt::getAllOnesValue(XTy));
    return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::getNullValue(XTy));
  }

  if (Xor->hasOneUse() && !Cmp.isEquality()) {
    // (icmp u/s (xor X, SignMask), C) --> (icmp s/u X, C ^ SignMask)
    if (XorC->isSignMask())
      return new ICmpInst(Cmp.getFlippedSignednessPredicate(), X,
                          ConstantInt::get(XTy, C ^ *XorC));
    // (icmp u/s (xor X, ~SignMask), C) --> (icmp swap(s/u) X, C ^ ~SignMask)
    if (XorC->isMaxSignedValue())
      return new ICmpInst(
          ICmpInst::getSwappedPredicate(Cmp.getFlippedSignednessPredicate()),
          X, ConstantInt::get(XTy, C ^ *XorC));
  }

  // Mask-constant tricks that eliminate the xor from unsigned compares.
  switch (Pred) {
  case ICmpInst::ICMP_UGT:
    // With C a low-bit mask (C + 1 is a power of 2):
    //   (xor X, ~C) u> C --> X u< ~C
    //   (xor X, C)  u> C --> X u> C
    if ((C + 1).isPowerOf2()) {
      if (*XorC == ~C)
        return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
      if (*XorC == C)
        return new ICmpInst(ICmpInst::ICMP_UGT, X, Y);
    }
    return nullptr;
  case ICmpInst::ICMP_ULT:
    //   (xor X, -C) u< C --> X u> ~C   when C is a power of 2
    //   (xor X, C)  u< C --> X u> ~C   when -C is a power of 2
    if ((*XorC == -C && C.isPowerOf2()) || (*XorC == C && (-C).isPowerOf2()))
      return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(XTy, ~C));
    return nullptr;
  default:
    return nullptr;
  }
}
// Folds 'icmp Pred (and (Shift X, ShAmt), C2), C1', where Shift is the first
// operand of And, by moving the shift onto the constants:
//   ((X << C3)  & C2) Pred C1 --> (X & (C2 >>u C3)) Pred (C1 >>u C3)
//   ((X >>u C3) & C2) Pred C1 --> (X & (C2 << C3))  Pred (C1 << C3)
//   ((X >>s C3) & C2) Pred C1 --> (X & (C2 << C3))  Pred (C1 << C3)
// For a non-constant shift amount only the equality-with-zero form at the end
// is handled. Returns the replacement instruction, or nullptr.
Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
BinaryOperator *And,
const APInt &C1,
const APInt &C2) {
BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0));
if (!Shift || !Shift->isShift())
return nullptr;
unsigned ShiftOpcode = Shift->getOpcode();
bool IsShl = ShiftOpcode == Instruction::Shl;
const APInt *C3;
if (match(Shift->getOperand(1), m_APInt(C3))) {
// Constant shift amount C3. Compute the mask (NewAndCst) and compare
// constant (NewCmpCst) that apply to the unshifted X. Also record whether
// undoing the shift loses bits of C1.
APInt NewAndCst, NewCmpCst;
bool AnyCmpCstBitsShiftedOut;
if (ShiftOpcode == Instruction::Shl) {
// shl: undo with lshr. A signed compare is only folded when neither C1 nor
// C2 is negative.
if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative()))
return nullptr;
NewCmpCst = C1.lshr(*C3);
NewAndCst = C2.lshr(*C3);
AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1;
} else if (ShiftOpcode == Instruction::LShr) {
// lshr: undo with shl. A signed compare is only folded when neither
// shifted constant is negative.
NewCmpCst = C1.shl(*C3);
NewAndCst = C2.shl(*C3);
AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1;
if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative()))
return nullptr;
} else {
// ashr: undo with shl. Bail if the mask does not survive the shl/ashr
// round trip.
assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode");
NewCmpCst = C1.shl(*C3);
NewAndCst = C2.shl(*C3);
AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1;
if (NewAndCst.ashr(*C3) != C2)
return nullptr;
}
if (AnyCmpCstBitsShiftedOut) {
// Undoing the shift loses bits of C1, so the rewrite is not possible.
// Equality compares are folded to a constant (eq -> false, ne -> true).
// Relational compares fall through, and the code below only handles
// equality, so they end up unchanged.
if (Cmp.getPredicate() == ICmpInst::ICMP_EQ)
return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType()));
if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
} else {
// Emit (X & NewAndCst) Pred NewCmpCst.
Value *NewAnd = Builder.CreateAnd(
Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst));
return new ICmpInst(Cmp.getPredicate(),
NewAnd, ConstantInt::get(And->getType(), NewCmpCst));
}
}
// Single-use logical shift of a non-constant value, compared ==/!= 0:
// move the shift onto the mask operand instead.
//   ((X << Y)  & C2) ==/!= 0 --> (X & (C2 >>u Y)) ==/!= 0
//   ((X >>u Y) & C2) ==/!= 0 --> (X & (C2 << Y))  ==/!= 0
if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() &&
!Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) {
Value *NewShift =
IsShl ? Builder.CreateLShr(And->getOperand(1), Shift->getOperand(1))
: Builder.CreateShl(And->getOperand(1), Shift->getOperand(1));
Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift);
return replaceOperand(Cmp, 0, NewAnd);
}
return nullptr;
}
// Folds 'icmp Pred (and X, C2), C1', where C1 is the compare constant and C2
// is a constant mask matched from the 'and'. Returns the replacement
// instruction, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
BinaryOperator *And,
const APInt &C1) {
bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE;
// Vector compares only: icmp ne (and X, 1), 0 --> trunc X to the i1 vector.
if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() &&
match(And->getOperand(1), m_One()))
return new TruncInst(And->getOperand(0), Cmp.getType());
const APInt *C2;
Value *X;
if (!match(And, m_And(m_Value(X), m_APInt(C2))))
return nullptr;
// All remaining folds require that the 'and' has a single use.
if (!And->hasOneUse())
return nullptr;
if (Cmp.isEquality() && C1.isZero()) {
// (X & SignMask) == 0 --> X s>= 0 ; (X & SignMask) != 0 --> X s< 0
if (C2->isSignMask()) {
Constant *Zero = Constant::getNullValue(X->getType());
auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
return new ICmpInst(NewPred, X, Zero);
}
// Combine C2 with a high-bits mask covering X's known leading zeros, so a
// mask that is a negated power of 2 "up to known-zero bits" is recognized.
// The combination uses '+'; where C2 has no bits in that region this equals
// setting them.
APInt NewC2 = *C2;
KnownBits Know = computeKnownBits(And->getOperand(0), 0, And);
NewC2 = *C2 + APInt::getHighBitsSet(C2->getBitWidth(),
Know.countMinLeadingZeros());
// (X & -P) == 0 --> X u< P ; (X & -P) != 0 --> X u>= P  (P a power of 2)
if (NewC2.isNegatedPowerOf2()) {
Constant *NegBOC = ConstantInt::get(And->getType(), -NewC2);
auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
return new ICmpInst(NewPred, X, NegBOC);
}
}
// icmp (and (trunc W), C2), C1 --> icmp (and W, zext C2), zext C1
// Only done when the compare is an equality or both constants are
// non-negative, and only for scalar compares.
Value *W;
if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) &&
(Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) {
if (!Cmp.getType()->isVectorTy()) {
Type *WideType = W->getType();
unsigned WideScalarBits = WideType->getScalarSizeInBits();
Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits));
Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits));
Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName());
return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1);
}
}
// Try to fold away a shift feeding the 'and'.
if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2))
return I;
// For non-signed predicates:
// (icmp pred (and (or (lshr A, B), A), 1), 0) -->
// (icmp pred (and A, (or (shl 1, B), 1)), 0)
if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() &&
match(And->getOperand(1), m_One())) {
Constant *One = cast<Constant>(And->getOperand(1));
Value *Or = And->getOperand(0);
Value *A, *B, *LShr;
if (match(Or, m_Or(m_Value(LShr), m_Value(A))) &&
match(LShr, m_LShr(m_Specific(A), m_Value(B)))) {
// Count how many of And/Or/LShr are single-use (presumably they become
// dead after the rewrite). A constant B needs at least one, because the
// new mask constant-folds. A variable B needs all three, because two new
// instructions are created.
unsigned UsesRemoved = 0;
if (And->hasOneUse())
++UsesRemoved;
if (Or->hasOneUse())
++UsesRemoved;
if (LShr->hasOneUse())
++UsesRemoved;
// Build (1 << B) | 1, either as a constant or as new instructions.
Value *NewOr = nullptr;
if (auto *C = dyn_cast<Constant>(B)) {
if (UsesRemoved >= 1)
NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
} else {
// The 'true' argument to CreateShl sets the nuw flag on the new shl.
if (UsesRemoved >= 3)
NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(),
true),
One, Or->getName());
}
if (NewOr) {
Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName());
return replaceOperand(Cmp, 0, NewAnd);
}
}
}
return nullptr;
}
// Folds 'icmp Pred (and LHS, RHS), C'. Tries the both-constant folds first,
// then a few patterns with a variable or load-from-global operand. Returns the
// replacement instruction, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
                                                   BinaryOperator *And,
                                                   const APInt &C) {
  if (Instruction *Folded = foldICmpAndConstConst(Cmp, And, C))
    return Folded;

  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *LHS = And->getOperand(0);
  Value *RHS = And->getOperand(1);

  // ((X - 1) & ~X) s< 0  --> X == 0
  // ((X - 1) & ~X) s> -1 --> X != 0
  bool TrueIfNeg;
  Value *Base;
  if (isSignBitCheck(Pred, C, TrueIfNeg) &&
      match(LHS, m_Add(m_Value(Base), m_AllOnes())) &&
      match(RHS, m_Not(m_Specific(Base))))
    return new ICmpInst(TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE, Base,
                        ConstantInt::getNullValue(Base->getType()));

  // (load (gep @GV, ...)) & MaskCI compared with C: try to turn a lookup in a
  // constant global into an index computation.
  if (auto *MaskCI = dyn_cast<ConstantInt>(RHS)) {
    auto *LI = dyn_cast<LoadInst>(LHS);
    auto *GEP =
        LI ? dyn_cast<GetElementPtrInst>(LI->getOperand(0)) : nullptr;
    auto *GV = GEP ? dyn_cast<GlobalVariable>(GEP->getOperand(0)) : nullptr;
    if (GV)
      if (Instruction *Res =
              foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, MaskCI))
        return Res;
  }

  if (!Cmp.isEquality())
    return nullptr;

  // The mask is also the compare constant -P (P a power of 2):
  //   (X & -P) == -P --> X u>  ~P
  //   (X & -P) != -P --> X u<= ~P
  Value *CmpRHS = Cmp.getOperand(1);
  if (CmpRHS != RHS || !C.isNegatedPowerOf2())
    return nullptr;
  auto NewPred =
      Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE;
  return new ICmpInst(NewPred, LHS, SubOne(cast<Constant>(CmpRHS)));
}
// Folds 'icmp Pred (or Lhs, Rhs), C'. Returns the replacement instruction, or
// nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
                                                  BinaryOperator *Or,
                                                  const APInt &C) {
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  const bool IsEq = Pred == CmpInst::ICMP_EQ;

  // icmp slt (signum V), 1 --> icmp slt V, 1
  Value *SignumArg = nullptr;
  if (C.isOne() && Pred == ICmpInst::ICMP_SLT &&
      match(Or, m_Signum(m_Value(SignumArg))))
    return new ICmpInst(ICmpInst::ICMP_SLT, SignumArg,
                        ConstantInt::get(SignumArg->getType(), 1));

  Value *Lhs = Or->getOperand(0);
  Value *Rhs = Or->getOperand(1);

  const APInt *MaskC;
  if (Cmp.isEquality() && match(Rhs, m_APInt(MaskC))) {
    // With C a low-bit mask (C + 1 is a power of 2):
    //   (X | C) == C --> X u<= C ; (X | C) != C --> X u> C
    if (*MaskC == C && (C + 1).isPowerOf2())
      return new ICmpInst(IsEq ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT, Lhs,
                          Rhs);

    // Otherwise rewrite the set-bits mask as a clear-bits mask:
    //   (X | MaskC) ==/!= C --> (X & ~MaskC) ==/!= (C ^ MaskC)
    if (Or->hasOneUse()) {
      Value *Cleared = Builder.CreateAnd(Lhs, ~(*MaskC));
      return new ICmpInst(Pred, Cleared,
                          ConstantInt::get(Or->getType(), C ^ (*MaskC)));
    }
  }

  // (X | (X - 1)) s< 0  --> X s< 1
  // (X | (X - 1)) s> -1 --> X s> 0
  bool TrueIfSigned;
  Value *X;
  if (isSignBitCheck(Pred, C, TrueIfSigned) &&
      match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) {
    auto SignPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT;
    return new ICmpInst(SignPred, X,
                        ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0));
  }

  // The remaining folds need: single-use 'or', compared ==/!= 0.
  if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse())
    return nullptr;

  const auto Logic = IsEq ? Instruction::And : Instruction::Or;

  // (ptrtoint P | ptrtoint Q) ==/!= 0 --> (P ==/!= null) and/or (Q ==/!= null)
  Value *P, *Q;
  if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) {
    Value *CmpP =
        Builder.CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType()));
    Value *CmpQ =
        Builder.CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType()));
    return BinaryOperator::Create(Logic, CmpP, CmpQ);
  }

  // ((X1 ^ X2) | (X3 ^ X4)) == 0 --> (X1 == X2) & (X3 == X4)
  // ((X1 ^ X2) | (X3 ^ X4)) != 0 --> (X1 != X2) | (X3 != X4)
  Value *X1, *X2, *X3, *X4;
  if (!match(Lhs, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) ||
      !match(Rhs, m_OneUse(m_Xor(m_Value(X3), m_Value(X4)))))
    return nullptr;
  Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
  Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4);
  return BinaryOperator::Create(Logic, Cmp12, Cmp34);
}
// Folds 'icmp Pred (mul X, MulC), C' for a constant multiplier MulC.
//
// - Sign tests of an 'nsw' multiply compare X against 0 directly. The
//   predicate is swapped when MulC is negative.
// - Equality compares of an 'nsw'/'nuw' multiply divide C by MulC when the
//   remainder is zero.
// - Relational compares divide C by MulC with the rounding the predicate
//   needs: up for slt/sge/ult/uge, down for sle/sgt/ule/ugt. The signed fold
//   relies on 'nsw' and the unsigned fold on 'nuw'. Each is selected by the
//   signedness of the predicate, so a multiply that carries both flags can be
//   folded for either kind of compare.
//
// Returns the replacement compare, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Mul,
                                                   const APInt &C) {
  const APInt *MulC;
  if (!match(Mul->getOperand(1), m_APInt(MulC)))
    return nullptr;

  const bool HasNSW = Mul->hasNoSignedWrap();
  const bool HasNUW = Mul->hasNoUnsignedWrap();
  Value *X = Mul->getOperand(0);
  Type *Ty = Mul->getType();

  // (X *nsw +MulC) s< 0 --> X s< 0
  // (X *nsw -MulC) s< 0 --> X s> 0
  // Check the flag first. isSignTest may rewrite Pred (slt 1 -> sle, sgt -1 ->
  // sge), and that rewrite should only happen when this fold is taken.
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  if (HasNSW && isSignTest(Pred, C)) {
    if (MulC->isNegative())
      Pred = ICmpInst::getSwappedPredicate(Pred);
    return new ICmpInst(Pred, X, Constant::getNullValue(Ty));
  }

  if (MulC->isZero() || !(HasNSW || HasNUW))
    return nullptr;

  if (Cmp.isEquality()) {
    // (X *nsw MulC) ==/!= C --> X ==/!= C /s MulC   (when C %s MulC == 0)
    if (HasNSW && C.srem(*MulC).isZero())
      return new ICmpInst(Pred, X, ConstantInt::get(Ty, C.sdiv(*MulC)));
    // (X *nuw MulC) ==/!= C --> X ==/!= C /u MulC   (when C %u MulC == 0)
    if (HasNUW && C.urem(*MulC).isZero())
      return new ICmpInst(Pred, X, ConstantInt::get(Ty, C.udiv(*MulC)));
    return nullptr;
  }

  // Relational compares: (X * MulC) Pred C --> X Pred' (C / MulC, rounded).
  Constant *NewC = nullptr;
  if (HasNSW && ICmpInst::isSigned(Pred)) {
    // MININT / -1 overflows.
    if (C.isMinSignedValue() && MulC->isAllOnes())
      return nullptr;
    // Multiplying by a negative constant reverses the ordering.
    if (MulC->isNegative())
      Pred = ICmpInst::getSwappedPredicate(Pred);

    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE)
      NewC = ConstantInt::get(
          Ty, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
    if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT)
      NewC = ConstantInt::get(
          Ty, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
  } else if (HasNUW && ICmpInst::isUnsigned(Pred)) {
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)
      NewC = ConstantInt::get(
          Ty, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
    if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)
      NewC = ConstantInt::get(
          Ty, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
  }

  return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
}
// Folds 'icmp Pred (shl 1, Y), C' into a compare on the shift amount Y.
// Returns the replacement compare, or nullptr if no fold applies.
static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
                                   const APInt &C) {
  Value *Y;
  if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
    return nullptr;

  Type *ShiftTy = Shl->getType();
  const unsigned BitWidth = C.getBitWidth();
  ICmpInst::Predicate Pred = Cmp.getPredicate();

  if (Cmp.isEquality()) {
    // (1 << Y) ==/!= 2^K --> Y ==/!= K
    if (!C.isPowerOf2())
      return nullptr;
    return new ICmpInst(Pred, Y, ConstantInt::get(ShiftTy, C.logBase2()));
  }

  if (Cmp.isSigned()) {
    // Only sign-test-like compares (against 0 or -1) are handled:
    //   (1 << Y) s<  0, s<= 0, s<= -1 --> Y == BitWidth-1
    //   (1 << Y) s>  0, s>= 0, s>  -1 --> Y != BitWidth-1
    Constant *TopBitIdx = ConstantInt::get(ShiftTy, BitWidth - 1);
    const bool IsZero = C.isZero();
    if (!IsZero && !C.isAllOnes())
      return nullptr;
    if (Pred == ICmpInst::ICMP_SLE || (IsZero && Pred == ICmpInst::ICMP_SLT))
      return new ICmpInst(ICmpInst::ICMP_EQ, Y, TopBitIdx);
    if (Pred == ICmpInst::ICMP_SGT || (IsZero && Pred == ICmpInst::ICMP_SGE))
      return new ICmpInst(ICmpInst::ICMP_NE, Y, TopBitIdx);
    return nullptr;
  }

  if (!Cmp.isUnsigned())
    return nullptr;

  // (1 << Y) u-pred C --> Y u-pred' Log2(C)
  // For a non-power-of-2 C, tighten the bound:
  //   (1 << Y) u< 30 --> Y u<= 4 ; (1 << Y) u>= 30 --> Y u> 4
  if (!C.isPowerOf2()) {
    if (Pred == ICmpInst::ICMP_ULT)
      Pred = ICmpInst::ICMP_ULE;
    else if (Pred == ICmpInst::ICMP_UGE)
      Pred = ICmpInst::ICMP_UGT;
  }

  // When the bound's log2 is BitWidth-1, only Y == BitWidth-1 reaches it:
  //   (1 << Y) u>= 2^(BW-1) --> Y == BW-1 ; (1 << Y) u< 2^(BW-1) --> Y != BW-1
  const unsigned Log2C = C.logBase2();
  if (Log2C == BitWidth - 1) {
    if (Pred == ICmpInst::ICMP_UGE)
      Pred = ICmpInst::ICMP_EQ;
    else if (Pred == ICmpInst::ICMP_ULT)
      Pred = ICmpInst::ICMP_NE;
  }
  return new ICmpInst(Pred, Y, ConstantInt::get(ShiftTy, Log2C));
}
// Folds 'icmp Pred (shl X, ShAmt), C'. A constant shifted value (equality
// only) and a non-constant shift amount are delegated to helpers. Otherwise
// the folds below use the nsw/nuw flags, mask tricks, or narrowing to remove
// the shift. Returns the replacement instruction, or nullptr.
Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
BinaryOperator *Shl,
const APInt &C) {
const APInt *ShiftVal;
if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);
const APInt *ShiftAmt;
if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
return foldICmpShlOne(Cmp, Shl, C);
// Do not fold out-of-range (undefined) shift amounts.
unsigned TypeBits = C.getBitWidth();
if (ShiftAmt->uge(TypeBits))
return nullptr;
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Shl->getOperand(0);
Type *ShType = Shl->getType();
// nsw: only sign bits are shifted out, so C can be arithmetic-shifted right
// instead of masking.
if (Shl->hasNoSignedWrap()) {
// (X << S) s> C --> X s> (C >>s S)
if (Pred == ICmpInst::ICMP_SGT) {
APInt ShiftedC = C.ashr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// (X << S) ==/!= C --> X ==/!= (C >>s S), when C survives the round trip.
if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
C.ashr(*ShiftAmt).shl(*ShiftAmt) == C) {
APInt ShiftedC = C.ashr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// (X << S) s< C --> X s< ((C - 1) >>s S) + 1
if (Pred == ICmpInst::ICMP_SLT) {
assert(!C.isMinSignedValue() && "Unexpected icmp slt");
APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1;
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// A sign test of the shifted value is a sign test of X. isSignTest may
// rewrite Pred, which is fine here because we return immediately.
if (isSignTest(Pred, C))
return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
}
// nuw: only zero bits are shifted out, so C can be logically shifted right.
if (Shl->hasNoUnsignedWrap()) {
// (X << S) u> C --> X u> (C >>u S)
if (Pred == ICmpInst::ICMP_UGT) {
APInt ShiftedC = C.lshr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// (X << S) ==/!= C --> X ==/!= (C >>u S), when C survives the round trip.
if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
C.lshr(*ShiftAmt).shl(*ShiftAmt) == C) {
APInt ShiftedC = C.lshr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// (X << S) u< C --> X u< ((C - 1) >>u S) + 1
if (Pred == ICmpInst::ICMP_ULT) {
assert(C.ugt(0) && "ult 0 should have been eliminated");
APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1;
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
}
// Single-use equality: replace the shift with a mask of the bits that
// survive it:
// (X << S) ==/!= C --> (X & LowBits(BW - S)) ==/!= (C >>u S)
if (Cmp.isEquality() && Shl->hasOneUse()) {
Constant *Mask = ConstantInt::get(
ShType,
APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask");
Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt));
return new ICmpInst(Pred, And, LShrC);
}
// Single-use sign-bit test: test the one bit of X that lands in the sign
// position, e.g. (X << 31) s< 0 --> (X & 1) != 0.
bool TrueIfSigned = false;
if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) {
Constant *Mask = ConstantInt::get(
ShType,
APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1));
Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask");
return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
And, Constant::getNullValue(ShType));
}
// Single-use unsigned inequality --> equality test of a mask of X.
if (Cmp.isUnsigned() && Shl->hasOneUse()) {
// (X << S) u<=/u> C --> (X & (~C >>u S)) ==/!= 0, when C + 1 is a power of 2.
if ((C + 1).isPowerOf2() &&
(Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) {
Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue()));
return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ
: ICmpInst::ICMP_NE,
And, Constant::getNullValue(ShType));
}
// (X << S) u</u>= C --> (X & (~(C - 1) >>u S)) ==/!= 0, when C is a power of 2.
if (C.isPowerOf2() &&
(Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
Value *And =
Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue()));
return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ
: ICmpInst::ICMP_NE,
And, Constant::getNullValue(ShType));
}
}
// Narrow the compare when C has at least Amt trailing zeros and
// i(BW - Amt) is a legal integer type (vectors keep their element count):
// (X << Amt) Pred C --> (trunc X to i(BW-Amt)) Pred trunc(C >>s Amt)
unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1);
if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
if (auto *ShVTy = dyn_cast<VectorType>(ShType))
TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount());
Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
}
return nullptr;
}
// Folds 'icmp Pred (lshr/ashr X, ShAmt), C'. Handles a constant shifted value
// X, and a constant shift amount with relational or equality predicates.
// Returns the replacement instruction, or nullptr.
Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
BinaryOperator *Shr,
const APInt &C) {
Value *X = Shr->getOperand(0);
CmpInst::Predicate Pred = Cmp.getPredicate();
// An exact shr only shifts out zero bits:
// (shr exact X, Y) ==/!= 0 --> X ==/!= 0
if (Cmp.isEquality() && Shr->isExact() && C.isZero())
return new ICmpInst(Pred, X, Cmp.getOperand(1));
bool IsAShr = Shr->getOpcode() == Instruction::AShr;
const APInt *ShiftValC;
// Constant value shifted by a variable amount Y.
if (match(X, m_APInt(ShiftValC))) {
if (Cmp.isEquality())
return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC);
// For a negative ShiftValC, lshr keeps the sign bit only when Y == 0:
// (ShiftValC >>u Y) s< 0  --> Y == 0
// (ShiftValC >>u Y) s> -1 --> Y != 0
bool TrueIfSigned;
if (!IsAShr && ShiftValC->isNegative() &&
isSignBitCheck(Pred, C, TrueIfSigned))
return new ICmpInst(TrueIfSigned ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE,
Shr->getOperand(1),
ConstantInt::getNullValue(X->getType()));
// Power-of-2 ShiftValC: compare the shift amount directly.
// (ShiftValC >>u Y) u> C --> Y u<  LZ(C)     - LZ(ShiftValC)
// (ShiftValC >>u Y) u< C --> Y u>= LZ(C - 1) - LZ(ShiftValC)
if (!IsAShr && ShiftValC->isPowerOf2() &&
(Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_ULT)) {
bool IsUGT = Pred == CmpInst::ICMP_UGT;
assert(ShiftValC->uge(C) && "Expected simplify of compare");
assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify");
unsigned CmpLZ =
IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros();
unsigned ShiftLZ = ShiftValC->countLeadingZeros();
Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ);
auto NewPred = IsUGT ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
return new ICmpInst(NewPred, Shr->getOperand(1), NewC);
}
}
// The rest needs a constant shift amount that is in range and non-zero.
const APInt *ShiftAmtC;
if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC)))
return nullptr;
unsigned TypeBits = C.getBitWidth();
unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits);
if (ShAmtVal >= TypeBits || ShAmtVal == 0)
return nullptr;
bool IsExact = Shr->isExact();
Type *ShrTy = Shr->getType();
if (IsAShr) {
// (ashr exact X, S) Pred C or (ashr X, S) slt/ult C
//   --> X Pred (C << S), if C << S is lossless.
if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) {
APInt ShiftedC = C.shl(ShAmtVal);
if (ShiftedC.ashr(ShAmtVal) == C)
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
// (ashr X, S) s> C --> X s> ((C + 1) << S) - 1, when that does not overflow.
if (Pred == CmpInst::ICMP_SGT) {
APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() &&
(ShiftedC + 1).ashr(ShAmtVal) == (C + 1))
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
// (ashr X, S) u> C --> X u> ((C + 1) << S) - 1. The second clause accepts
// the case where (C + 1) << S becomes the signed minimum.
if (Pred == CmpInst::ICMP_UGT) {
APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
if ((ShiftedC + 1).ashr(ShAmtVal) == (C + 1) ||
(C + 1).shl(ShAmtVal).isMinSignedValue())
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
// If C has no more sign bits than the shift amount, an unsigned compare
// becomes a sign test of X:
// (ashr X, S) u> C --> X s< 0 ; (ashr X, S) u< C --> X s> -1
if (C.getBitWidth() > 2 && C.getNumSignBits() <= ShAmtVal) {
if (Pred == CmpInst::ICMP_UGT) {
return new ICmpInst(CmpInst::ICMP_SLT, X,
ConstantInt::getNullValue(ShrTy));
}
if (Pred == CmpInst::ICMP_ULT) {
return new ICmpInst(CmpInst::ICMP_SGT, X,
ConstantInt::getAllOnesValue(ShrTy));
}
}
} else {
// (lshr X, S) u< C or (lshr exact X, S) u> C
//   --> X Pred (C << S), if C << S is lossless.
if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) {
APInt ShiftedC = C.shl(ShAmtVal);
if (ShiftedC.lshr(ShAmtVal) == C)
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
// (lshr X, S) u> C --> X u> ((C + 1) << S) - 1, when that is lossless.
if (Pred == CmpInst::ICMP_UGT) {
APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1))
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
}
if (!Cmp.isEquality())
return nullptr;
// Equality with a constant shift amount. A C that cannot round-trip through
// the shift is expected to have been simplified earlier.
assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) ||
(!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) &&
"Expected icmp+shr simplify did not occur.");
// Exact: no bits were lost, so compare the unshifted value.
if (Shr->isExact())
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal));
// (shr X, S) == 0 --> X u< (1 << S) ; != 0 --> X u> (1 << S) - 1
if (C.isZero()) {
if (Pred == CmpInst::ICMP_EQ)
return new ICmpInst(CmpInst::ICMP_ULT, X,
ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal)));
else
return new ICmpInst(CmpInst::ICMP_UGT, X,
ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1));
}
// Single use: turn the shift into a mask of the high bits:
// (shr X, S) ==/!= C --> (X & HighBits(BW - S)) ==/!= (C << S)
if (Shr->hasOneUse()) {
APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
Constant *Mask = ConstantInt::get(ShrTy, Val);
Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask");
return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal));
}
return nullptr;
}
// Folds compares of a single-use 'srem X, P' (P a power of 2) into a masked
// test of X. Handles sgt/slt against 0 and eq/ne against a positive C.
// Returns the replacement compare, or nullptr.
Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
                                                    BinaryOperator *SRem,
                                                    const APInt &C) {
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  const bool IsSignTest =
      Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT;
  const bool IsEqualityTest =
      Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE;
  if (!IsSignTest && !IsEqualityTest)
    return nullptr;

  if (!SRem->hasOneUse())
    return nullptr;

  const APInt *DivisorC;
  if (!match(SRem->getOperand(1), m_Power2(DivisorC)))
    return nullptr;

  // Sign tests are only handled against 0; equality only against C > 0.
  if (IsSignTest ? !C.isZero() : !C.isStrictlyPositive())
    return nullptr;

  // Keep the sign bit and the low (remainder) bits of X.
  Type *Ty = SRem->getType();
  const APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits());
  Value *Masked = Builder.CreateAnd(
      SRem->getOperand(0), ConstantInt::get(Ty, SignBit | (*DivisorC - 1)));

  if (IsEqualityTest)
    return new ICmpInst(Pred, Masked, ConstantInt::get(Ty, C));

  // Negative remainder: sign bit set and some low bit set,
  // e.g. (i16 X % 4) s< 0 --> (X & 32771) u> 32768.
  if (Pred == ICmpInst::ICMP_SLT)
    return new ICmpInst(ICmpInst::ICMP_UGT, Masked,
                        ConstantInt::get(Ty, SignBit));

  // Positive remainder: sign bit clear and some low bit set,
  // e.g. (i8 X % 32) s> 0 --> (X & 159) s> 0.
  return new ICmpInst(ICmpInst::ICMP_SGT, Masked,
                      ConstantInt::getNullValue(Ty));
}
// Folds 'icmp u> / u< (udiv C2, Y), C' with a constant dividend C2 into a
// compare on the divisor Y. Returns the replacement compare, or nullptr.
Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                    BinaryOperator *UDiv,
                                                    const APInt &C) {
  const APInt *Dividend;
  if (!match(UDiv->getOperand(0), m_APInt(Dividend)))
    return nullptr;
  assert(*Dividend != 0 && "udiv 0, X should have been simplified already.");

  Value *Divisor = UDiv->getOperand(1);
  Type *Ty = UDiv->getType();

  switch (Cmp.getPredicate()) {
  case ICmpInst::ICMP_UGT:
    // (C2 u/ Y) u> C --> Y u<= C2 u/ (C + 1)
    assert(!C.isMaxValue() &&
           "icmp ugt X, UINT_MAX should have been simplified already.");
    return new ICmpInst(ICmpInst::ICMP_ULE, Divisor,
                        ConstantInt::get(Ty, Dividend->udiv(C + 1)));
  case ICmpInst::ICMP_ULT:
    // (C2 u/ Y) u< C --> Y u> C2 u/ C
    assert(C != 0 && "icmp ult X, 0 should have been simplified already.");
    return new ICmpInst(ICmpInst::ICMP_UGT, Divisor,
                        ConstantInt::get(Ty, Dividend->udiv(C)));
  default:
    return nullptr;
  }
}
// Folds 'icmp Pred ([us]div X, Y), C'.
// - Equality against a C with the sign bit set (for sdiv: exactly SMIN)
//   removes the division: (X / Y) ==/!= C --> (X ==/!= C) and/or (Y ==/!= 1).
// - With a constant divisor C2, the compare becomes a check of X against
//   the half-open interval [LoBound, HiBound) of dividends that produce C.
// Returns the replacement instruction, or nullptr.
Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
BinaryOperator *Div,
const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Div->getOperand(0);
Value *Y = Div->getOperand(1);
Type *Ty = Div->getType();
bool DivIsSigned = Div->getOpcode() == Instruction::SDiv;
if (Cmp.isEquality() && Div->hasOneUse() && C.isSignBitSet() &&
(!DivIsSigned || C.isMinSignedValue())) {
Value *XBig = Builder.CreateICmp(Pred, X, ConstantInt::get(Ty, C));
Value *YOne = Builder.CreateICmp(Pred, Y, ConstantInt::get(Ty, 1));
auto Logic = Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
return BinaryOperator::Create(Logic, XBig, YOne);
}
const APInt *C2;
if (!match(Y, m_APInt(C2)))
return nullptr;
// Relational compares are only handled when the compare's signedness matches
// the division's.
if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned())
return nullptr;
// The overflow computation below does not handle divisors 0, 1, or (for
// sdiv) -1.
if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes()))
return nullptr;
// Solve X / C2 == C for X: Prod = C * C2. ProdOV records whether that
// product wrapped (checked by dividing back with the same signedness).
APInt Prod = C * *C2;
bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C;
// An exact division maps only one dividend to C, so the interval size is 1.
APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;
// Interval [LoBound, HiBound) of X. Overflow markers: 0 = bound valid,
// -1 = overflowed off the bottom, +1 = overflowed off the top.
int LoOverflow = 0, HiOverflow = 0;
APInt LoBound, HiBound;
// udiv, e.g. X/5 op 3 --> [15, 20)
if (!DivIsSigned) { LoBound = Prod;
HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow) {
HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false);
}
// sdiv, positive divisor. C == 0 cannot overflow: e.g. X/2 op 0 --> [-1, 2).
// C > 0: e.g. X/5 op 3 --> [15, 20).
} else if (C2->isStrictlyPositive()) { if (C.isZero()) { LoBound = -(RangeSize - 1);
HiBound = RangeSize;
} else if (C.isStrictlyPositive()) { LoBound = Prod; HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
// C < 0: e.g. X/5 op -3 --> [-19, -14). Overflow here is off the bottom (-1).
} else { HiBound = Prod + 1;
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
if (!LoOverflow) {
APInt DivNeg = -RangeSize;
LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
}
}
// sdiv, negative divisor. An exact division uses a range size of -1.
} else if (C2->isNegative()) { if (Div->isExact())
RangeSize.negate();
// C == 0: e.g. X/-5 op 0 --> [-4, 5). If HiBound equals the divisor (only
// possible when C2 is INT_MIN), the upper end has overflowed.
if (C.isZero()) { LoBound = RangeSize + 1;
HiBound = -RangeSize;
if (HiBound == *C2) { HiOverflow = 1; HiBound = APInt(); }
// C > 0: e.g. X/-5 op 3 --> [-19, -14).
} else if (C.isStrictlyPositive()) { HiBound = Prod + 1;
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
LoOverflow =
addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1 : 0;
// C < 0: e.g. X/-5 op -3 --> [15, 20).
} else { LoBound = Prod; LoOverflow = HiOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true);
}
// Dividing by a negative value reverses the ordering.
Pred = ICmpInst::getSwappedPredicate(Pred);
}
// Emit the interval check for the (possibly swapped) predicate.
switch (Pred) {
default:
llvm_unreachable("Unhandled icmp predicate!");
case ICmpInst::ICMP_EQ:
if (LoOverflow && HiOverflow)
return replaceInstUsesWith(Cmp, Builder.getFalse());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE,
X, ConstantInt::get(Ty, LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
X, ConstantInt::get(Ty, HiBound));
return replaceInstUsesWith(
Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true));
case ICmpInst::ICMP_NE:
if (LoOverflow && HiOverflow)
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
X, ConstantInt::get(Ty, LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE,
X, ConstantInt::get(Ty, HiBound));
return replaceInstUsesWith(
Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, false));
// X / C2 < C  <==>  X < LoBound.
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_SLT:
if (LoOverflow == +1) return replaceInstUsesWith(Cmp, Builder.getTrue());
if (LoOverflow == -1) return replaceInstUsesWith(Cmp, Builder.getFalse());
return new ICmpInst(Pred, X, ConstantInt::get(Ty, LoBound));
// X / C2 > C  <==>  X >= HiBound.
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_SGT:
if (HiOverflow == +1) return replaceInstUsesWith(Cmp, Builder.getFalse());
if (HiOverflow == -1) return replaceInstUsesWith(Cmp, Builder.getTrue());
if (Pred == ICmpInst::ICMP_UGT)
return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, HiBound));
return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, HiBound));
}
return nullptr;
}
// Folds 'icmp Pred (sub X, Y), C'. Returns the replacement instruction, or
// nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
BinaryOperator *Sub,
const APInt &C) {
Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
ICmpInst::Predicate Pred = Cmp.getPredicate();
Type *Ty = Sub->getType();
// (SubC - Y) ==/!= C --> Y ==/!= (SubC - C)
Constant *SubC;
if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) {
return new ICmpInst(Pred, Y,
ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C)));
}
// (C2 - Y) Pred C --> Y swap(Pred) (C2 - C), when the sub has the no-wrap
// flag matching the compare's signedness and C2 - C does not overflow.
const APInt *C2;
APInt SubResult;
ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate();
bool HasNSW = Sub->hasNoSignedWrap();
bool HasNUW = Sub->hasNoUnsignedWrap();
if (match(X, m_APInt(C2)) &&
((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) &&
!subWithOverflow(SubResult, *C2, C, Cmp.isSigned()))
return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult));
// (X - Y) ==/!= 0 --> X ==/!= Y. Allowed with multiple uses as long as no
// user is a PHI node (presumably to avoid a codegen regression in loop
// tests).
if (Cmp.isEquality() && C.isZero() &&
none_of((Sub->users()), [](const User *U) { return isa<PHINode>(U); }))
return new ICmpInst(Pred, X, Y);
// Everything below requires the icmp to be the only user of the sub.
if (!Sub->hasOneUse())
return nullptr;
if (Sub->hasNoSignedWrap()) {
// (X -nsw Y) s> -1 --> X s>= Y
if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes())
return new ICmpInst(ICmpInst::ICMP_SGE, X, Y);
// (X -nsw Y) s> 0 --> X s> Y
if (Pred == ICmpInst::ICMP_SGT && C.isZero())
return new ICmpInst(ICmpInst::ICMP_SGT, X, Y);
// (X -nsw Y) s< 0 --> X s< Y
if (Pred == ICmpInst::ICMP_SLT && C.isZero())
return new ICmpInst(ICmpInst::ICMP_SLT, X, Y);
// (X -nsw Y) s< 1 --> X s<= Y
if (Pred == ICmpInst::ICMP_SLT && C.isOne())
return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
}
if (!match(X, m_APInt(C2)))
return nullptr;
// C2 - Y u< C --> (Y | (C - 1)) == C2
//   iff C is a power of 2 and (C2 & (C - 1)) == C - 1
if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() &&
(*C2 & (C - 1)) == (C - 1))
return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X);
// C2 - Y u> C --> (Y | C) != C2
//   iff C + 1 is a power of 2 and (C2 & C) == C
if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C)
return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X);
// Otherwise canonicalize the sub to an add (the sub's nuw/nsw flags are
// copied onto the add):
// (C2 - Y) Pred C --> (Y + ~C2) swap(Pred) ~C
Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub",
HasNUW, HasNSW);
return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
}
// Folds relational 'icmp Pred (add X, C2), C' with a constant C2. Equality
// compares are not handled here. Returns the replacement instruction, or
// nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
const APInt &C) {
Value *Y = Add->getOperand(1);
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
return nullptr;
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
const CmpInst::Predicate Pred = Cmp.getPredicate();
// The add does not wrap in the compare's signedness (nsw for sgt/slt, nuw
// for ugt/ult), so subtract the constants:
// (X + C2) Pred C --> X Pred (C - C2), unless C - C2 overflows.
if ((Add->hasNoSignedWrap() &&
(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) ||
(Add->hasNoUnsignedWrap() &&
(Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) {
bool Overflow;
APInt NewC =
Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow);
if (!Overflow)
return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
}
// CR is the set of X values that satisfy the compare. If CR starts at the
// minimum (or ends just below it, i.e. wraps to the minimum) of the compare's
// signedness, a single compare of X against the other end suffices.
auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2);
const APInt &Upper = CR.getUpper();
const APInt &Lower = CR.getLower();
if (Cmp.isSigned()) {
if (Lower.isSignMask())
return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper));
if (Upper.isSignMask())
return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower));
} else {
if (Lower.isMinValue())
return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper));
if (Upper.isMinValue())
return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower));
}
// Offset compares that become an opposite-signedness compare with no offset.
const APInt SMax = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
const APInt SMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
// (X + C2) u> C --> X s< -C2       (if C == C2 + SMAX)
if (Pred == CmpInst::ICMP_UGT && C == *C2 + SMax)
return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, -(*C2)));
// (X + C2) u< C --> X s> ~C2       (if C == C2 + SMIN)
if (Pred == CmpInst::ICMP_ULT && C == *C2 + SMin)
return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::get(Ty, ~(*C2)));
// (X + C2) s> C --> X u< (SMAX - C) (if C == C2 - 1)
if (Pred == CmpInst::ICMP_SGT && C == *C2 - 1)
return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, SMax - C));
// (X + C2) s< C --> X u> (C ^ SMAX) (if C == C2)
if (Pred == CmpInst::ICMP_SLT && C == *C2)
return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax));
// The remaining folds create new instructions; require a single-use add.
if (!Add->hasOneUse())
return nullptr;
// (X + C2) u< C --> (X & -C) == -C2
//   iff C is a power of 2 and (C2 & (C - 1)) == 0
if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0)
return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C),
ConstantExpr::getNeg(cast<Constant>(Y)));
// (X + C2) u> C --> (X & ~C) != -C2
//   iff C + 1 is a power of 2 and (C2 & C) == 0
if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0)
return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C),
ConstantExpr::getNeg(cast<Constant>(Y)));
// Canonicalize the range-test idiom to the ult form:
// (X + C2) u> C --> (X + (C2 - C - 1)) u< ~C
if (Pred == ICmpInst::ICMP_UGT)
return new ICmpInst(ICmpInst::ICMP_ULT,
Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)),
ConstantInt::get(Ty, ~C));
return nullptr;
}
// Recognizes the three-way compare idiom
//   select (LHS == RHS), Equal, (select (LHS s< RHS), Less, Greater)
// where Equal/Less/Greater are integer constants. Also accepts an outer 'ne'
// (arms swapped), a swapped inner compare, and 'LHS s> C-1' in place of
// 'LHS s< C'. On success returns true and fills in all out-parameters.
bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
                                               Value *&RHS, ConstantInt *&Less,
                                               ConstantInt *&Equal,
                                               ConstantInt *&Greater) {
  ICmpInst::Predicate OuterPred;
  if (!match(SI->getCondition(),
             m_ICmp(OuterPred, m_Value(LHS), m_Value(RHS))))
    return false;
  if (!ICmpInst::isEquality(OuterPred))
    return false;

  // Pick out the arm taken when LHS == RHS, and the other arm.
  const bool IsNE = OuterPred == ICmpInst::ICMP_NE;
  Value *EqualArm = IsNE ? SI->getFalseValue() : SI->getTrueValue();
  Value *OtherArm = IsNE ? SI->getTrueValue() : SI->getFalseValue();
  if (!match(EqualArm, m_ConstantInt(Equal)))
    return false;

  ICmpInst::Predicate InnerPred;
  Value *InnerL, *InnerR;
  if (!match(OtherArm,
             m_Select(m_ICmp(InnerPred, m_Value(InnerL), m_Value(InnerR)),
                      m_ConstantInt(Less), m_ConstantInt(Greater))))
    return false;

  // Orient the inner compare so that its left operand is LHS.
  if (InnerL != LHS) {
    std::swap(InnerL, InnerR);
    InnerPred = ICmpInst::getSwappedPredicate(InnerPred);
    if (InnerL != LHS)
      return false;
  }

  // 'LHS s> C-1' means 'LHS s>= C', i.e. !(LHS s< C): rewrite it as slt and
  // exchange the Less/Greater results.
  if (InnerPred == ICmpInst::ICMP_SGT && isa<Constant>(InnerR)) {
    auto Flipped = InstCombiner::getFlippedStrictnessPredicateAndConstant(
        InnerPred, cast<Constant>(InnerR));
    if (!Flipped)
      return false;
    assert(Flipped->first == ICmpInst::ICMP_SGE &&
           "basic correctness failure");
    InnerR = Flipped->second;
    InnerPred = ICmpInst::ICMP_SLT;
    std::swap(Less, Greater);
  }

  return InnerPred == ICmpInst::ICMP_SLT && InnerR == RHS;
}
/// Fold "icmp Pred (select ...), C" when the select is a three-way compare of
/// two values OrigLHS/OrigRHS producing the constants Less/Equal/Greater.
/// Each outcome constant is tested against C, and the compare is rebuilt as
/// an OR of the signed relations (slt, eq, sgt) whose outcome satisfies Pred,
/// seeded with 'false'. Later InstCombine iterations may simplify the OR
/// chain. Only done when the icmp has a single use; returns nullptr otherwise.
Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
                                                      SelectInst *Select,
                                                      ConstantInt *C) {
  assert(C && "Cmp RHS should be a constant int!");

  Value *OrigLHS, *OrigRHS;
  ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan;
  if (!Cmp.hasOneUse() ||
      !matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal,
                               C3GreaterThan))
    return nullptr;
  assert(C1LessThan && C2Equal && C3GreaterThan);

  // Does "Outcome Pred C" hold? Both sides are constants, so this folds.
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  auto Satisfies = [&](ConstantInt *Outcome) {
    return ConstantExpr::getCompare(Pred, Outcome, C)->isAllOnesValue();
  };
  const bool LessSatisfies = Satisfies(C1LessThan);
  const bool EqualSatisfies = Satisfies(C2Equal);
  const bool GreaterSatisfies = Satisfies(C3GreaterThan);

  // OR in each relation that makes the original compare true. If none does,
  // the whole compare is simply false.
  Value *Cond = Builder.getFalse();
  auto OrInRelation = [&](ICmpInst::Predicate RelPred) {
    Value *RelCmp = Builder.CreateICmp(RelPred, OrigLHS, OrigRHS);
    Cond = Builder.CreateOr(Cond, RelCmp);
  };
  if (LessSatisfies)
    OrInRelation(ICmpInst::ICMP_SLT);
  if (EqualSatisfies)
    OrInRelation(ICmpInst::ICMP_EQ);
  if (GreaterSatisfies)
    OrInRelation(ICmpInst::ICMP_SGT);

  return replaceInstUsesWith(Cmp, Cond);
}
/// Fold an integer compare whose LHS is a bitcast by looking at the bitcast's
/// source. The folds are tried in order and the first match wins:
///  - bitcast of sitofp X compared with 0 (eq/ne/slt/sgt), 1 (slt) or
///    -1 (sgt), or bitcast of uitofp X compared eq/ne with 0: compare X.
///  - sign-bit test of a bitcast of fpext/fptrunc X: test the sign bit of X
///    bitcast to a same-width integer (not for ppc_fp128).
///  - pointer bitcast compared with a constant or another bitcast: compare
///    the pre-cast pointers.
///  - eq/ne of a bitcast of a freely-invertible value against -1.
///  - eq/ne of a bitcast of a zext/sext fixed vector against 0.
///  - compare of a bitcast splat shuffle against a splat-pattern constant.
/// Returns the replacement instruction, or nullptr if nothing applies.
Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
  auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0));
  if (!Bitcast)
    return nullptr;
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *Op1 = Cmp.getOperand(1);
  Value *BCSrcOp = Bitcast->getOperand(0);
  Type *SrcType = Bitcast->getSrcTy();
  Type *DstType = Bitcast->getType();
  // The FP-source folds need a bitcast that neither switches between scalar
  // and vector nor changes the element width (so the lane count is kept).
  if (SrcType->isVectorTy() == DstType->isVectorTy() &&
      SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
    Value *X;
    if (match(BCSrcOp, m_SIToFP(m_Value(X)))) {
      // icmp eq/ne/slt/sgt (bitcast (sitofp X)), 0 --> icmp pred X, 0
      if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
           Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
          match(Op1, m_Zero()))
        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
      // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
      if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
        return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
      // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
      if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
        return new ICmpInst(Pred, X,
                            ConstantInt::getAllOnesValue(X->getType()));
    }
    // icmp eq/ne (bitcast (uitofp X)), 0 --> icmp eq/ne X, 0
    if (match(BCSrcOp, m_UIToFP(m_Value(X))))
      if (Cmp.isEquality() && match(Op1, m_Zero()))
        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
    // A sign-bit test of (bitcast (fpext/fptrunc X)) becomes a sign-bit test
    // of (bitcast X to a same-width int). This relies on the FP extend or
    // truncate not changing the sign bit; ppc_fp128 is excluded below.
    // Only done when the bitcast has no other users.
    const APInt *C;
    bool TrueIfSigned;
    if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
        isSignBitCheck(Pred, *C, TrueIfSigned)) {
      if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
          match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
        Type *XType = X->getType();
        if (!(XType->isPPC_FP128Ty() || SrcType->isPPC_FP128Ty())) {
          // Integer type of X's element width, vectorized like X.
          Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
          if (auto *XVTy = dyn_cast<VectorType>(XType))
            NewType = VectorType::get(NewType, XVTy->getElementCount());
          Value *NewBitcast = Builder.CreateBitCast(X, NewType);
          // Canonical sign-bit checks: "slt 0" (sign set) / "sgt -1" (clear).
          if (TrueIfSigned)
            return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
                                ConstantInt::getNullValue(NewType));
          else
            return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast,
                                ConstantInt::getAllOnesValue(NewType));
        }
      }
    }
  }
  // Pointer bitcast compared with a constant or another bitcast: strip the
  // bitcast from operand 1 as well (if it is one), cast it to the source
  // type, and compare the source pointers.
  if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
    if (auto *BC2 = dyn_cast<BitCastInst>(Op1))
      Op1 = BC2->getOperand(0);
    Op1 = Builder.CreateBitCast(Op1, SrcType);
    return new ICmpInst(Pred, BCSrcOp, Op1);
  }
  // The remaining folds need a scalar integer result of an integer (vector)
  // source compared against a scalar or splat integer constant.
  const APInt *C;
  if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() ||
      !SrcType->isIntOrIntVectorTy())
    return nullptr;
  // icmp eq/ne (bitcast V to iN), -1 --> icmp eq/ne (bitcast (not V) to iN), 0
  // when V is free to invert. Comparing against zero is the preferred form.
  if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() &&
      isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) {
    Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType);
    return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType));
  }
  // icmp eq/ne (bitcast (zext/sext X) to iN), 0
  //   --> icmp eq/ne (bitcast X to iM), 0    where M = total bit size of X.
  // Extension does not change which lanes are zero, so drop it.
  Value *X;
  if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() &&
      match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) {
    if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) {
      Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits());
      Value *NewCast = Builder.CreateBitCast(X, NewType);
      return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType));
    }
  }
  // icmp pred (bitcast (shufflevector Vec, undef, <k, k, ..., k>) to iN), C
  //   --> icmp pred (extractelement Vec, k), trunc(C)
  // when C is a repetition of an element-width bit pattern.
  Value *Vec;
  ArrayRef<int> Mask;
  if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
    if (is_splat(Mask)) {
      auto *VecTy = cast<VectorType>(SrcType);
      auto *EltTy = cast<IntegerType>(VecTy->getElementType());
      if (C->isSplat(EltTy->getBitWidth())) {
        Value *Elem = Builder.getInt32(Mask[0]);
        Value *Extract = Builder.CreateExtractElement(Vec, Elem);
        Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
        return new ICmpInst(Pred, Extract, NewC);
      }
    }
  }
  return nullptr;
}
/// Try the folds for "icmp (inst ...), C" where C is a scalar or splat integer
/// constant, dispatching on the kind of instruction on the LHS (binary
/// operator, select, trunc, intrinsic). If none of those fires, fall back to
/// the folds that tolerate undef lanes in the constant.
Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) {
  const APInt *C;
  if (match(Cmp.getOperand(1), m_APInt(C))) {
    Instruction *Folded = nullptr;

    if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0)))
      Folded = foldICmpBinOpWithConstant(Cmp, BO, *C);

    // The select fold only supports a scalar ConstantInt on the RHS.
    if (!Folded) {
      if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0)))
        if (auto *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
          Folded = foldICmpSelectConstant(Cmp, SI, ConstRHS);
    }

    if (!Folded) {
      if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0)))
        Folded = foldICmpTruncConstant(Cmp, TI, *C);
    }

    if (!Folded) {
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
        Folded = foldICmpIntrinsicWithConstant(Cmp, II, *C);
    }

    if (Folded)
      return Folded;
  }

  if (!match(Cmp.getOperand(1), m_APIntAllowUndef(C)))
    return nullptr;
  return foldICmpInstWithConstantAllowUndef(Cmp, *C);
}
/// Fold an equality compare (eq/ne) of the binary operator BO against the
/// scalar or splat integer constant C on the compare's RHS. Non-equality
/// predicates are rejected up front. Returns the replacement instruction, or
/// nullptr when no fold applies.
Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
    ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) {
  if (!Cmp.isEquality())
    return nullptr;

  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  const bool IsNE = Pred == ICmpInst::ICMP_NE;
  Constant *CmpC = cast<Constant>(Cmp.getOperand(1));
  Value *X = BO->getOperand(0);
  Value *Y = BO->getOperand(1);
  const APInt *BOC;

  switch (BO->getOpcode()) {
  case Instruction::SRem: {
    // (X srem 2^k) ==/!= 0 --> (X urem 2^k) ==/!= 0, for 2^k > 1.
    if (!C.isZero() || !BO->hasOneUse())
      break;
    if (!match(Y, m_APInt(BOC)) || !BOC->sgt(1) || !BOC->isPowerOf2())
      break;
    Value *URem = Builder.CreateURem(X, Y, BO->getName());
    return new ICmpInst(Pred, URem, Constant::getNullValue(BO->getType()));
  }
  case Instruction::Add: {
    if (auto *AddC = dyn_cast<Constant>(Y)) {
      // (X + AddC) ==/!= C --> X ==/!= (C - AddC)
      if (!BO->hasOneUse())
        break;
      return new ICmpInst(Pred, X, ConstantExpr::getSub(CmpC, AddC));
    }
    if (!C.isZero())
      break;
    // (X + Y) ==/!= 0 --> X ==/!= -Y. Reuse an existing negation of either
    // operand; otherwise emit a negation if the add has no other users.
    if (Value *NegY = dyn_castNegVal(Y))
      return new ICmpInst(Pred, X, NegY);
    if (Value *NegX = dyn_castNegVal(X))
      return new ICmpInst(Pred, NegX, Y);
    if (!BO->hasOneUse())
      break;
    Value *Neg = Builder.CreateNeg(Y);
    Neg->takeName(BO);
    return new ICmpInst(Pred, X, Neg);
  }
  case Instruction::Xor: {
    if (!BO->hasOneUse())
      break;
    // (X ^ XorC) ==/!= C --> X ==/!= (C ^ XorC)
    if (auto *XorC = dyn_cast<Constant>(Y))
      return new ICmpInst(Pred, X, ConstantExpr::getXor(CmpC, XorC));
    // (X ^ Y) ==/!= 0 --> X ==/!= Y
    if (C.isZero())
      return new ICmpInst(Pred, X, Y);
    break;
  }
  case Instruction::Or: {
    // (X | OrC) ==/!= -1 --> (X & ~OrC) ==/!= ~OrC, removing the -1 constant.
    if (!match(Y, m_APInt(BOC)) || !BO->hasOneUse() ||
        !CmpC->isAllOnesValue())
      break;
    Constant *NotOrC = ConstantExpr::getNot(cast<Constant>(Y));
    Value *Masked = Builder.CreateAnd(X, NotOrC);
    return new ICmpInst(Pred, Masked, NotOrC);
  }
  case Instruction::And: {
    // (X & P) == P --> (X & P) != 0 (and vice versa) for a power of two P.
    if (match(Y, m_APInt(BOC)) && C == *BOC && C.isPowerOf2())
      return new ICmpInst(IsNE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO,
                          Constant::getNullValue(CmpC->getType()));
    break;
  }
  case Instruction::UDiv: {
    // (X u/ Y) == 0 --> Y u> X;  (X u/ Y) != 0 --> Y u<= X
    if (C.isZero())
      return new ICmpInst(IsNE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT, Y,
                          X);
    break;
  }
  default:
    break;
  }
  return nullptr;
}
/// Fold an equality compare (eq/ne) of an intrinsic call against the scalar or
/// splat integer constant C. The folds either move the intrinsic's effect
/// onto the constant or test its arguments directly. Returns the replacement
/// instruction, or nullptr when no fold applies.
Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
    ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
  Type *Ty = II->getType();
  const unsigned BitWidth = C.getBitWidth();
  const ICmpInst::Predicate Pred = Cmp.getPredicate();

  // NOTE: argument operands are read per case, since intrinsics reaching the
  // default case may take no arguments at all.
  switch (II->getIntrinsicID()) {
  case Intrinsic::abs:
    // abs(A) == 0 --> A == 0;  abs(A) == INT_MIN --> A == INT_MIN
    if (!C.isZero() && !C.isMinSignedValue())
      break;
    return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C));

  case Intrinsic::bswap:
    // bswap(A) == C --> A == bswap(C)
    return new ICmpInst(Pred, II->getArgOperand(0),
                        ConstantInt::get(Ty, C.byteSwap()));

  case Intrinsic::ctlz:
  case Intrinsic::cttz: {
    Value *Src = II->getArgOperand(0);
    // ctlz/cttz(A) == bitwidth --> A == 0
    if (C == BitWidth)
      return new ICmpInst(Pred, Src, ConstantInt::getNullValue(Ty));

    // Otherwise test the exact bit pattern, but only when the intrinsic has
    // no other users so the instruction count does not grow:
    //   cttz(A) == N --> (A & lowbits(N+1))  == (1 << N)
    //   ctlz(A) == N --> (A & highbits(N+1)) == (1 << (BitWidth-N-1))
    const unsigned Num = C.getLimitedValue(BitWidth);
    if (Num == BitWidth || !II->hasOneUse())
      break;
    APInt Mask, Expected;
    if (II->getIntrinsicID() == Intrinsic::cttz) {
      Mask = APInt::getLowBitsSet(BitWidth, Num + 1);
      Expected = APInt::getOneBitSet(BitWidth, Num);
    } else {
      Mask = APInt::getHighBitsSet(BitWidth, Num + 1);
      Expected = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
    }
    return new ICmpInst(Pred, Builder.CreateAnd(Src, Mask),
                        ConstantInt::get(Ty, Expected));
  }

  case Intrinsic::ctpop:
    // ctpop(A) == 0 --> A == 0
    if (C.isZero())
      return new ICmpInst(Pred, II->getArgOperand(0),
                          Constant::getNullValue(Ty));
    // ctpop(A) == bitwidth --> A == -1
    if (C == BitWidth)
      return new ICmpInst(Pred, II->getArgOperand(0),
                          Constant::getAllOnesValue(Ty));
    break;

  case Intrinsic::fshl:
  case Intrinsic::fshr: {
    // Only rotates (identical data operands) by a constant amount:
    //   rol(X, R) == C --> X == ror(C, R);  ror(X, R) == C --> X == rol(C, R)
    if (II->getArgOperand(0) != II->getArgOperand(1))
      break;
    const APInt *RotAmtC;
    if (!match(II->getArgOperand(2), m_APInt(RotAmtC)))
      break;
    const bool IsRotateLeft = II->getIntrinsicID() == Intrinsic::fshl;
    APInt Unrotated = IsRotateLeft ? C.rotr(*RotAmtC) : C.rotl(*RotAmtC);
    return new ICmpInst(Pred, II->getArgOperand(0),
                        ConstantInt::get(Ty, Unrotated));
  }

  case Intrinsic::uadd_sat: {
    // uadd.sat(a, b) == 0 --> (a | b) == 0
    if (!C.isZero())
      break;
    Value *Either = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
    return new ICmpInst(Pred, Either, Constant::getNullValue(Ty));
  }

  case Intrinsic::usub_sat: {
    // usub.sat(a, b) == 0 --> a u<= b;  usub.sat(a, b) != 0 --> a u> b
    if (!C.isZero())
      break;
    const ICmpInst::Predicate NewPred =
        Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
    return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
  }

  default:
    break;
  }
  return nullptr;
}
/// Fold an equality compare of two calls to the same intrinsic into a compare
/// of their inputs, for bswap, bitreverse, and rotates (fshl/fshr whose two
/// data operands are identical) by the same shift-amount operand on both
/// sides. Returns nullptr otherwise.
static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) {
  assert(Cmp.isEquality());

  const auto *LHSCall = dyn_cast<IntrinsicInst>(Cmp.getOperand(0));
  const auto *RHSCall = dyn_cast<IntrinsicInst>(Cmp.getOperand(1));
  if (!LHSCall || !RHSCall)
    return nullptr;
  if (LHSCall->getIntrinsicID() != RHSCall->getIntrinsicID())
    return nullptr;

  // A funnel shift whose two data operands are the same value is a rotate.
  auto IsRotate = [](const IntrinsicInst *Call) {
    return Call->getOperand(0) == Call->getOperand(1);
  };

  switch (LHSCall->getIntrinsicID()) {
  case Intrinsic::bswap:
  case Intrinsic::bitreverse:
    break;
  case Intrinsic::fshl:
  case Intrinsic::fshr:
    if (!IsRotate(LHSCall) || !IsRotate(RHSCall))
      return nullptr;
    if (LHSCall->getOperand(2) != RHSCall->getOperand(2))
      return nullptr;
    break;
  default:
    return nullptr;
  }

  return new ICmpInst(Cmp.getPredicate(), LHSCall->getOperand(0),
                      RHSCall->getOperand(0));
}
/// Folds for "icmp (intrinsic ...), C" that stay valid even when the vector
/// constant C has undef lanes. Currently only handles a rotate (fshl/fshr
/// with identical data operands) compared eq/ne against 0 or -1: both values
/// are rotation-invariant, so the rotate can be dropped.
Instruction *
InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
                                                     const APInt &C) {
  auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0));
  if (!II)
    return nullptr;

  const bool IsFunnelShift = II->getIntrinsicID() == Intrinsic::fshl ||
                             II->getIntrinsicID() == Intrinsic::fshr;
  if (!IsFunnelShift || !Cmp.isEquality())
    return nullptr;
  if (II->getArgOperand(0) != II->getArgOperand(1))
    return nullptr;
  if (!C.isZero() && !C.isAllOnes())
    return nullptr;

  // rot(X) ==/!= 0 --> X ==/!= 0;  rot(X) ==/!= -1 --> X ==/!= -1
  return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
                      Cmp.getOperand(1));
}
/// Dispatch "icmp Pred BO, C" to the opcode-specific constant fold for BO.
/// If that does not produce a replacement, fall back to the equality-only
/// folds in foldICmpBinOpEqualityWithConstant.
Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
                                                         BinaryOperator *BO,
                                                         const APInt &C) {
  Instruction *Folded = nullptr;
  switch (BO->getOpcode()) {
  case Instruction::Xor:
    Folded = foldICmpXorConstant(Cmp, BO, C);
    break;
  case Instruction::And:
    Folded = foldICmpAndConstant(Cmp, BO, C);
    break;
  case Instruction::Or:
    Folded = foldICmpOrConstant(Cmp, BO, C);
    break;
  case Instruction::Mul:
    Folded = foldICmpMulConstant(Cmp, BO, C);
    break;
  case Instruction::Shl:
    Folded = foldICmpShlConstant(Cmp, BO, C);
    break;
  case Instruction::LShr:
  case Instruction::AShr:
    Folded = foldICmpShrConstant(Cmp, BO, C);
    break;
  case Instruction::SRem:
    Folded = foldICmpSRemConstant(Cmp, BO, C);
    break;
  case Instruction::UDiv:
    // The udiv-specific fold goes first; the division fold shared with sdiv
    // is the fallback.
    Folded = foldICmpUDivConstant(Cmp, BO, C);
    if (!Folded)
      Folded = foldICmpDivConstant(Cmp, BO, C);
    break;
  case Instruction::SDiv:
    Folded = foldICmpDivConstant(Cmp, BO, C);
    break;
  case Instruction::Sub:
    Folded = foldICmpSubConstant(Cmp, BO, C);
    break;
  case Instruction::Add:
    Folded = foldICmpAddConstant(Cmp, BO, C);
    break;
  default:
    break;
  }

  if (Folded)
    return Folded;
  return foldICmpBinOpEqualityWithConstant(Cmp, BO, C);
}
/// Fold "icmp Pred (intrinsic ...), C" for a scalar or splat integer C.
/// Equality predicates go to foldICmpEqIntrinsicWithConstant. Otherwise,
/// unsigned range checks on ctpop/ctlz/cttz results are rewritten as checks on
/// the intrinsic's argument. Returns the replacement, or nullptr.
Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
                                                             IntrinsicInst *II,
                                                             const APInt &C) {
  if (Cmp.isEquality())
    return foldICmpEqIntrinsicWithConstant(Cmp, II, C);

  Type *Ty = II->getType();
  const unsigned BitWidth = C.getBitWidth();
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  const bool IsUGT = Pred == ICmpInst::ICMP_UGT;
  const bool IsULT = Pred == ICmpInst::ICMP_ULT;

  switch (II->getIntrinsicID()) {
  case Intrinsic::ctpop:
    // ctpop(X) u> BitWidth-1 --> X == -1
    if (IsUGT && C == BitWidth - 1)
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
                             II->getArgOperand(0),
                             ConstantInt::getAllOnesValue(Ty));
    // ctpop(X) u< BitWidth --> X != -1
    if (IsULT && C == BitWidth)
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
                             II->getArgOperand(0),
                             ConstantInt::getAllOnesValue(Ty));
    break;

  case Intrinsic::ctlz:
    // ctlz(X) u> N --> X u< (1 << (BitWidth-N-1)), for N < BitWidth
    if (IsUGT && C.ult(BitWidth)) {
      const unsigned Num = C.getLimitedValue();
      APInt Bound = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT,
                             II->getArgOperand(0), ConstantInt::get(Ty, Bound));
    }
    // ctlz(X) u< N --> X u> lowbits(BitWidth-N), for 1 <= N <= BitWidth
    if (IsULT && C.uge(1) && C.ule(BitWidth)) {
      const unsigned Num = C.getLimitedValue();
      APInt Bound = APInt::getLowBitsSet(BitWidth, BitWidth - Num);
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT,
                             II->getArgOperand(0), ConstantInt::get(Ty, Bound));
    }
    break;

  case Intrinsic::cttz:
    // Both rewrites emit an 'and', so require the cttz to have no other users.
    if (!II->hasOneUse())
      return nullptr;
    // cttz(X) u> N --> (X & lowbits(N+1)) == 0, for N < BitWidth
    if (IsUGT && C.ult(BitWidth)) {
      APInt LowMask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1);
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
                             Builder.CreateAnd(II->getArgOperand(0), LowMask),
                             ConstantInt::getNullValue(Ty));
    }
    // cttz(X) u< N --> (X & lowbits(N)) != 0, for 1 <= N <= BitWidth
    if (IsULT && C.uge(1) && C.ule(BitWidth)) {
      APInt LowMask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue());
      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
                             Builder.CreateAnd(II->getArgOperand(0), LowMask),
                             ConstantInt::getNullValue(Ty));
    }
    break;

  default:
    break;
  }
  return nullptr;
}
/// Fold an icmp whose RHS is any constant (not necessarily an integer) and
/// whose LHS is an instruction:
///  - an all-zero-index GEP, or an inttoptr from the pointer-sized integer
///    type, compared with null: compare the GEP/cast operand with its null;
///  - a phi in the icmp's own block: fold the compare into the phi;
///  - a load through a GEP of a global: try foldCmpLoadFromIndexedGlobal.
Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
  auto *RHSC = dyn_cast<Constant>(I.getOperand(1));
  auto *LHSI = dyn_cast<Instruction>(I.getOperand(0));
  if (!RHSC || !LHSI)
    return nullptr;

  // icmp pred (LHSI Src ...), null --> icmp pred Src, null-of-Src's-type
  auto CompareSourceWithNull = [&]() -> Instruction * {
    Value *Src = LHSI->getOperand(0);
    return new ICmpInst(I.getPredicate(), Src,
                        Constant::getNullValue(Src->getType()));
  };

  switch (LHSI->getOpcode()) {
  case Instruction::GetElementPtr:
    if (RHSC->isNullValue() &&
        cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices())
      return CompareSourceWithNull();
    break;

  case Instruction::PHI:
    // Only fold into a phi that lives in the same block as the icmp.
    if (LHSI->getParent() != I.getParent())
      break;
    if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
      return NV;
    break;

  case Instruction::IntToPtr:
    if (RHSC->isNullValue() &&
        DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType())
      return CompareSourceWithNull();
    break;

  case Instruction::Load: {
    // A compare of a value loaded from an indexed global ("A[i] > 4").
    auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0));
    if (!GEP)
      break;
    auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
    if (!GV)
      break;
    if (Instruction *Res =
            foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
      return Res;
    break;
  }

  default:
    break;
  }
  return nullptr;
}
/// Try to push "icmp Pred (select Cond, A, B), RHS" into the select arms,
/// giving "select Cond, (A Pred RHS), (B Pred RHS)".
///
/// Each arm compare is simplified with instsimplify and with what Cond
/// implies on that arm. The rewrite is done when it does not add code:
///  - both arm compares simplify, or
///  - one simplifies and the select has a single use, or
///  - one simplifies to a non-zero ConstantInt and replacedSelectWithOperand
///    succeeds in replacing the select's other uses.
/// Arm compares that did not simplify are materialized via Builder.
Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
                                              SelectInst *SI, Value *RHS,
                                              const ICmpInst &I) {
  auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * {
    if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ))
      return Res;
    if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op,
                                                 RHS, DL, SelectCondIsTrue))
      return ConstantInt::get(I.getType(), *Impl);
    return nullptr;
  };

  Value *TrueArmCmp = SimplifyOp(SI->getOperand(1), true);
  Value *FalseArmCmp = SimplifyOp(SI->getOperand(2), false);

  // The constant consulted by the multi-use case: the false-arm result when
  // it simplified, otherwise the true-arm result (if any).
  ConstantInt *CI =
      dyn_cast_or_null<ConstantInt>(FalseArmCmp ? FalseArmCmp : TrueArmCmp);

  bool Transform;
  if (TrueArmCmp && FalseArmCmp)
    Transform = true;
  else if (!TrueArmCmp && !FalseArmCmp)
    Transform = false;
  else if (SI->hasOneUse())
    Transform = true;
  else
    // Replace the select's other uses with the non-simplified arm.
    Transform = CI && !CI->isZero() &&
                replacedSelectWithOperand(SI, &I, TrueArmCmp ? 2 : 1);

  if (!Transform)
    return nullptr;

  if (!TrueArmCmp)
    TrueArmCmp = Builder.CreateICmp(Pred, SI->getOperand(1), RHS, I.getName());
  if (!FalseArmCmp)
    FalseArmCmp = Builder.CreateICmp(Pred, SI->getOperand(2), RHS, I.getName());
  return SelectInst::Create(SI->getOperand(0), TrueArmCmp, FalseArmCmp);
}
/// Fold a compare of X against X masked by a low-bit mask M:
///   icmp SrcPred (X & M), X  -->  icmp DstPred X, M
/// M is either a constant low-bit mask or one of the variable forms
/// ~(-1 << y), (1 << y) + -1, -1 >> y, or (-1 << y) >> y. Because X & M only
/// clears bits of X, "(X & M) == X" holds exactly when X u<= M. Signed
/// predicates are only handled for a constant M with no negative lanes.
/// Returns the new compare (created via Builder) or nullptr.
static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
                                          InstCombiner::BuilderTy &Builder) {
  ICmpInst::Predicate SrcPred;
  Value *X, *M, *Y;
  // Variable low-bit masks: ~(-1 << ?), (1 << ?) + -1, -1 >> ?, and
  // (-1 << y) >> y.
  auto m_VariableMask = m_CombineOr(
      m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())),
                  m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())),
      m_CombineOr(m_LShr(m_AllOnes(), m_Value()),
                  m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y))));
  auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask());
  // m_c_ICmp also matches the commuted compare; the table below treats
  // SrcPred as if (X & M) were the LHS.
  if (!match(&I, m_c_ICmp(SrcPred,
                          m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)),
                          m_Deferred(X))))
    return nullptr;
  ICmpInst::Predicate DstPred;
  switch (SrcPred) {
  case ICmpInst::Predicate::ICMP_EQ:
    // (X & M) == X  -->  X u<= M
    DstPred = ICmpInst::Predicate::ICMP_ULE;
    break;
  case ICmpInst::Predicate::ICMP_NE:
    // (X & M) != X  -->  X u> M
    DstPred = ICmpInst::Predicate::ICMP_UGT;
    break;
  case ICmpInst::Predicate::ICMP_ULT:
    // (X & M) u< X  -->  X u> M
    DstPred = ICmpInst::Predicate::ICMP_UGT;
    break;
  case ICmpInst::Predicate::ICMP_UGE:
    // (X & M) u>= X  -->  X u<= M
    DstPred = ICmpInst::Predicate::ICMP_ULE;
    break;
  case ICmpInst::Predicate::ICMP_SLT:
    // (X & M) s< X  -->  X s> M; needs a constant M with no -1 (negative)
    // lanes.
    if (!match(M, m_Constant())) return nullptr;
    if (!match(M, m_NonNegative())) return nullptr;
    DstPred = ICmpInst::Predicate::ICMP_SGT;
    break;
  case ICmpInst::Predicate::ICMP_SGE:
    // (X & M) s>= X  -->  X s<= M; same constant, non-negative requirement.
    if (!match(M, m_Constant())) return nullptr;
    if (!match(M, m_NonNegative())) return nullptr;
    DstPred = ICmpInst::Predicate::ICMP_SLE;
    break;
  case ICmpInst::Predicate::ICMP_SGT:
  case ICmpInst::Predicate::ICMP_SLE:
    // Not handled.
    return nullptr;
  case ICmpInst::Predicate::ICMP_UGT:
  case ICmpInst::Predicate::ICMP_ULE:
    // Expected to have been removed already (see the message below).
    llvm_unreachable("Instsimplify took care of commut. variant");
    break;
  default:
    llvm_unreachable("All possible folds are handled.");
  }
  // If M is a fixed vector constant with undef/poison lanes, do not carry
  // those lanes into the new compare: replace them with the first lane that
  // is not undef.
  Type *OpTy = M->getType();
  auto *VecC = dyn_cast<Constant>(M);
  auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
  if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
    Constant *SafeReplacementConstant = nullptr;
    for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
      if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
        SafeReplacementConstant = VecC->getAggregateElement(i);
        break;
      }
    }
    assert(SafeReplacementConstant && "Failed to find undef replacement");
    M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
  }
  return Builder.CreateICmp(DstPred, X, M);
}
/// Fold a check of whether X survives a sign-extension round trip from its low
/// KeptBits bits:
///   ((X << MaskedBits) a>> MaskedBits) ==/!= X
///     -->  (X + (1 << (KeptBits-1))) u< / u>= (1 << KeptBits)
/// where KeptBits = bitwidth(X) - MaskedBits. Both shift amounts must be the
/// same splat constant. The shl may have other users, but the ashr must have
/// a single use. Returns the new compare (created via Builder) or nullptr.
static Value *
foldICmpWithTruncSignExtendedVal(ICmpInst &I,
                                 InstCombiner::BuilderTy &Builder) {
  ICmpInst::Predicate SrcPred;
  Value *X;
  const APInt *ShlAmt, *AShrAmt;
  if (!match(&I, m_c_ICmp(SrcPred,
                          m_OneUse(m_AShr(m_Shl(m_Value(X), m_APInt(ShlAmt)),
                                          m_APInt(AShrAmt))),
                          m_Deferred(X))))
    return nullptr;

  if (*ShlAmt != *AShrAmt)
    return nullptr;
  const APInt &MaskedBits = *ShlAmt;
  assert(MaskedBits != 0 && "shift by zero should be folded away already.");

  // 'eq' asks "does X fit in KeptBits signed bits?", which becomes an
  // unsigned less-than on the biased value; 'ne' is its complement.
  ICmpInst::Predicate DstPred;
  if (SrcPred == ICmpInst::Predicate::ICMP_EQ)
    DstPred = ICmpInst::Predicate::ICMP_ULT;
  else if (SrcPred == ICmpInst::Predicate::ICMP_NE)
    DstPred = ICmpInst::Predicate::ICMP_UGE;
  else
    return nullptr;

  Type *XType = X->getType();
  const unsigned XBitWidth = XType->getScalarSizeInBits();
  const APInt BitWidth(XBitWidth, XBitWidth);
  assert(BitWidth.ugt(MaskedBits) && "shifts should leave some bits untouched");

  const APInt KeptBits = BitWidth - MaskedBits;
  assert(KeptBits.ugt(0) && KeptBits.ult(BitWidth) && "unreachable");
  // Upper bound (exclusive) of the biased range: 1 << KeptBits.
  const APInt ICmpCst = APInt(XBitWidth, 1).shl(KeptBits);
  assert(ICmpCst.isPowerOf2());
  // Bias that maps [-2^(KeptBits-1), 2^(KeptBits-1)) onto [0, 2^KeptBits).
  const APInt AddCst = ICmpCst.lshr(1);
  assert(AddCst.ult(ICmpCst) && AddCst.isPowerOf2());

  Value *Biased = Builder.CreateAdd(X, ConstantInt::get(XType, AddCst));
  return Builder.CreateICmp(DstPred, Biased, ConstantInt::get(XType, ICmpCst));
}
/// Fold an eq/ne-with-zero test of an 'and' of two opposite-direction logical
/// shifts into a single shift by the summed amount:
///   icmp eq/ne (and (X shift Q), (Y oppositeshift K)), 0
///     --> icmp eq/ne (and ((zext X) shift (Q+K)), (zext Y)), 0
/// One hand of the 'and' may be truncated; both values are then
/// zero-extended to the wider shift's type. Requires Q+K to constant-fold and
/// be u< that width. When the truncated hand is an lshr, extra legality
/// checks apply. Returns the new compare (created via Builder) or nullptr.
static Value *
foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
                                           InstCombiner::BuilderTy &Builder) {
  // Only "icmp eq/ne (and ...), 0" where the 'and' has no other users.
  if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) ||
      !I.getOperand(0)->hasOneUse())
    return nullptr;
  auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
  // Match an 'and' of a logical shift (XShift) and a possibly-truncated
  // logical shift (YShift under MaybeTruncation). m_c_And also covers the
  // commuted operand order.
  Instruction *XShift, *MaybeTruncation, *YShift;
  if (!match(
          I.getOperand(0),
          m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
                  m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
                                   m_AnyLogicalShift, m_Instruction(YShift))),
                               m_Instruction(MaybeTruncation)))))
    return nullptr;
  // Only YShift may have been matched through a trunc, so it has the widest
  // type; XShift has the type of the 'and'. These two names keep the original
  // identities even if XShift/YShift are swapped below.
  Instruction *WidestShift = YShift;
  Instruction *NarrowestShift = XShift;
  Type *WidestTy = WidestShift->getType();
  Type *NarrowestTy = NarrowestShift->getType();
  assert(NarrowestTy == I.getOperand(0)->getType() &&
         "We did not look past any shifts while matching XShift though.");
  bool HadTrunc = WidestTy != I.getOperand(0)->getType();
  // If YShift is an lshr, swap so that XShift is the lshr.
  if (match(YShift, m_LShr(m_Value(), m_Value())))
    std::swap(XShift, YShift);
  // Same-direction shift pairs are not handled.
  auto XShiftOpcode = XShift->getOpcode();
  if (XShiftOpcode == YShift->getOpcode())
    return nullptr;
  // Extract the shifted values and shift amounts, looking through a zext on
  // each shift amount.
  Value *X, *XShAmt, *Y, *YShAmt;
  match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
  match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));
  // If neither shifted value is a constant, the shifts will not constant-fold
  // away, so guard against increasing the instruction count.
  if (!isa<Constant>(X) && !isa<Constant>(Y)) {
    // At least one hand of the 'and' must be a one-use shift.
    if (!match(I.getOperand(0),
               m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
      return nullptr;
    if (HadTrunc) {
      // X will need widening: either the trunc or the narrow shift's amount
      // operand must be one-use.
      if (!MaybeTruncation->hasOneUse() &&
          !NarrowestShift->getOperand(1)->hasOneUse())
        return nullptr;
    }
  }
  // The (possibly pre-zext) shift amounts must have the same type.
  if (XShAmt->getType() != YShAmt->getType())
    return nullptr;
  // Having looked through zexts, the sum Q+K is formed in the shift-amount
  // type; bail out unless the largest possible total shift amount is
  // representable there.
  unsigned MaximalPossibleTotalShiftAmount =
      (WidestTy->getScalarSizeInBits() - 1) +
      (NarrowestTy->getScalarSizeInBits() - 1);
  APInt MaximalRepresentableShiftAmount =
      APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits());
  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
    return nullptr;
  // Q+K must simplify to a constant; widen it to the wide shift's type.
  auto *NewShAmt = dyn_cast_or_null<Constant>(
      simplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
                      /*isNUW=*/false, SQ.getWithInstruction(&I)));
  if (!NewShAmt)
    return nullptr;
  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
  unsigned WidestBitWidth = WidestTy->getScalarSizeInBits();
  // The combined shift amount must be u< the wide bit width.
  if (!match(NewShAmt,
             m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
                                APInt(WidestBitWidth, WidestBitWidth))))
    return nullptr;
  // Extra legality check for a truncated lshr. Only constants are analyzed,
  // and only splat shift amounts get the amount-based checks. Any one of the
  // conditions below is sufficient.
  if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) {
    auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ,
                    WidestShift]() {
      Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy()
                                    ? NewShAmt->getSplatValue()
                                    : NewShAmt;
      // Edge-case shifts by 0 or by WidestBitWidth-1 are always fine.
      if (NewShAmtSplat &&
          (NewShAmtSplat->isNullValue() ||
           NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1))
        return true;
      // Constant narrow-shift operand. Using the *minimum* known leading zeros
      // means a single outlier lane blocks the fold.
      if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) {
        KnownBits Known = computeKnownBits(C, SQ.DL);
        unsigned MinLeadZero = Known.countMinLeadingZeros();
        unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
        // At most the lowest bit may be set.
        if (MaxActiveBits <= 1)
          return true;
        // NewShAmt u<= minimum leading zeros of C.
        if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero))
          return true;
      }
      // Constant wide-shift operand.
      if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) {
        KnownBits Known = computeKnownBits(C, SQ.DL);
        unsigned MinLeadZero = Known.countMinLeadingZeros();
        unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
        // At most the lowest bit may be set.
        if (MaxActiveBits <= 1)
          return true;
        // ((WidestBitWidth-1) - NewShAmt) u<= minimum leading zeros of C.
        if (NewShAmtSplat) {
          APInt AdjNewShAmt =
              (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger();
          if (AdjNewShAmt.ule(MinLeadZero))
            return true;
        }
      }
      // Legality could not be established.
      return false;
    };
    if (!CanFold())
      return nullptr;
  }
  // Emit: icmp pred (and ((zext X) shift NewShAmt), (zext Y)), 0. X is
  // shifted with XShift's direction.
  X = Builder.CreateZExt(X, WidestTy);
  Y = Builder.CreateZExt(Y, WidestTy);
  Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
                  ? Builder.CreateLShr(X, NewShAmt)
                  : Builder.CreateShl(X, NewShAmt);
  Value *T1 = Builder.CreateAnd(T0, Y);
  return Builder.CreateICmp(I.getPredicate(), T1,
                            Constant::getNullValue(WidestTy));
}
/// Replace a hand-written multiplication overflow check with a call to
/// umul/smul.with.overflow:
///   (-1 u/ X) u<  Y              --> umul.ov(X, Y)
///   (-1 u/ X) u>= Y              --> !umul.ov(X, Y)
///   ((X * Y) [us]/ X) != Y       --> [us]mul.ov(X, Y)
///   ((X * Y) [us]/ X) == Y       --> ![us]mul.ov(X, Y)
/// The umul form is used for udiv and the smul form for sdiv. If the original
/// X * Y has other users, they are rewired to the intrinsic's value result
/// and the multiply is erased. Returns the i1 result, or nullptr.
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
  ICmpInst::Predicate Pred;
  Value *X, *Y;
  Instruction *Mul;
  Instruction *Div;
  bool NeedNegation;
  // Form 1: a relational compare of the one-use (-1 u/ X) with Y. The switch
  // below treats Pred as if the udiv were the LHS.
  if (!I.isEquality() &&
      match(&I, m_c_ICmp(Pred,
                         m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
                                      m_Instruction(Div)),
                         m_Value(Y)))) {
    Mul = nullptr;
    switch (Pred) {
    case ICmpInst::Predicate::ICMP_ULT:
      // Checking that the multiplication does overflow.
      NeedNegation = false;
      break;
    case ICmpInst::Predicate::ICMP_UGE:
      // Checking that it does not overflow.
      NeedNegation = true;
      break;
    default:
      return nullptr;
    }
  } else if (I.isEquality() &&
             match(&I,
                   m_c_ICmp(Pred, m_Value(Y),
                            m_CombineAnd(
                                m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
                                                                     m_Value(X)),
                                                             m_Instruction(Mul)),
                                                m_Deferred(X))),
                                m_Instruction(Div))))) {
    // Form 2: ((X * Y) / X) compared eq/ne with Y. '==' means no overflow.
    NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
  } else
    return nullptr;
  // Restores the builder's insertion point when this function returns.
  BuilderTy::InsertPointGuard Guard(Builder);
  // If the matched multiply has other users, insert right before it so the
  // intrinsic can replace it for those users.
  bool MulHadOtherUses = Mul && !Mul->hasOneUse();
  if (MulHadOtherUses)
    Builder.SetInsertPoint(Mul);
  Function *F = Intrinsic::getDeclaration(I.getModule(),
                                          Div->getOpcode() == Instruction::UDiv
                                              ? Intrinsic::umul_with_overflow
                                              : Intrinsic::smul_with_overflow,
                                          X->getType());
  CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");
  // Rewire the multiply's other users to the intrinsic's value result
  // (element 0) so no duplicate multiplication remains.
  if (MulHadOtherUses)
    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));
  // Element 1 is the overflow bit.
  Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
  if (NeedNegation) Res = Builder.CreateNot(Res, "mul.not.ov");
  // Erase the multiply only after the last Builder use, since it served as
  // the insertion point.
  if (MulHadOtherUses)
    eraseInstFromFunction(*Mul);
  return Res;
}
/// Fold a compare of X with its no-signed-wrap negation into a compare of X
/// with zero:
///   (0 -nsw X) pred X  -->  X pred' 0
/// For signed predicates pred' is the swapped predicate. For unsigned ones it
/// is the signed counterpart. Equality predicates are kept as-is.
static Instruction *foldICmpXNegX(ICmpInst &I) {
  CmpInst::Predicate Pred;
  Value *X;
  if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X))))
    return nullptr;

  CmpInst::Predicate NewPred = Pred;
  if (ICmpInst::isSigned(Pred))
    NewPred = ICmpInst::getSwappedPredicate(Pred);
  else if (ICmpInst::isUnsigned(Pred))
    NewPred = ICmpInst::getSignedPredicate(Pred);

  Constant *Zero = Constant::getNullValue(X->getType());
  return ICmpInst::Create(Instruction::ICmp, NewPred, X, Zero, I.getName());
}
Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
const SimplifyQuery &SQ) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
if (!BO0 && !BO1)
return nullptr;
if (Instruction *NewICmp = foldICmpXNegX(I))
return NewICmp;
const CmpInst::Predicate Pred = I.getPredicate();
Value *X;
if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) &&
(Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
return new ICmpInst(Pred, Builder.CreateNot(Op1), X);
if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) &&
(Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
return new ICmpInst(Pred, X, Builder.CreateNot(Op0));
{
Constant *C;
if (match(Op0, m_OneUse(m_Add(m_c_Add(m_Specific(Op1), m_Value(X)),
m_ImmConstant(C)))) &&
(Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
Constant *C2 = ConstantExpr::getNot(C);
return new ICmpInst(Pred, Builder.CreateSub(C2, X), Op1);
}
if (match(Op1, m_OneUse(m_Add(m_c_Add(m_Specific(Op0), m_Value(X)),
m_ImmConstant(C)))) &&
(Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) {
Constant *C2 = ConstantExpr::getNot(C);
return new ICmpInst(Pred, Op0, Builder.CreateSub(C2, X));
}
}
{
BinaryOperator *BO;
const APInt *C;
if ((Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) &&
match(Op0, m_And(m_BinOp(BO), m_LowBitMask(C))) &&
match(BO, m_Add(m_Specific(Op1), m_SpecificIntAllowUndef(*C)))) {
CmpInst::Predicate NewPred =
Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
Constant *Zero = ConstantInt::getNullValue(Op1->getType());
return new ICmpInst(NewPred, Op1, Zero);
}
if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) &&
match(Op1, m_And(m_BinOp(BO), m_LowBitMask(C))) &&
match(BO, m_Add(m_Specific(Op0), m_SpecificIntAllowUndef(*C)))) {
CmpInst::Predicate NewPred =
Pred == ICmpInst::ICMP_UGT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
Constant *Zero = ConstantInt::getNullValue(Op1->getType());
return new ICmpInst(NewPred, Op0, Zero);
}
}
bool NoOp0WrapProblem = false, NoOp1WrapProblem = false;
if (BO0 && isa<OverflowingBinaryOperator>(BO0))
NoOp0WrapProblem =
ICmpInst::isEquality(Pred) ||
(CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) ||
(CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap());
if (BO1 && isa<OverflowingBinaryOperator>(BO1))
NoOp1WrapProblem =
ICmpInst::isEquality(Pred) ||
(CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) ||
(CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap());
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
if (BO0 && BO0->getOpcode() == Instruction::Add) {
A = BO0->getOperand(0);
B = BO0->getOperand(1);
}
if (BO1 && BO1->getOpcode() == Instruction::Add) {
C = BO1->getOperand(0);
D = BO1->getOperand(1);
}
if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
return new ICmpInst(Pred, A == Op1 ? B : A,
Constant::getNullValue(Op1->getType()));
if ((C == Op0 || D == Op0) && NoOp1WrapProblem)
return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()),
C == Op0 ? D : C);
if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem &&
NoOp1WrapProblem) {
Value *Y, *Z;
if (A == C) {
Y = B;
Z = D;
} else if (A == D) {
Y = B;
Z = C;
} else if (B == C) {
Y = A;
Z = D;
} else {
assert(B == D);
Y = A;
Z = C;
}
return new ICmpInst(Pred, Y, Z);
}
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT &&
match(B, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SLE, A, Op1);
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE &&
match(B, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SGT, A, Op1);
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_SLT, A, Op1);
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
match(D, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
match(D, m_AllOnes()))
return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
(BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) {
const APInt *AP1, *AP2;
if (match(B, m_APIntAllowUndef(AP1)) && match(D, m_APIntAllowUndef(AP2)) &&
AP1->isNegative() == AP2->isNegative()) {
APInt AP1Abs = AP1->abs();
APInt AP2Abs = AP2->abs();
if (AP1Abs.uge(AP2Abs)) {
APInt Diff = *AP1 - *AP2;
bool HasNUW = BO0->hasNoUnsignedWrap() && Diff.ule(*AP1);
bool HasNSW = BO0->hasNoSignedWrap();
Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW);
return new ICmpInst(Pred, NewAdd, C);
} else {
APInt Diff = *AP2 - *AP1;
bool HasNUW = BO1->hasNoUnsignedWrap() && Diff.ule(*AP2);
bool HasNSW = BO1->hasNoSignedWrap();
Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW);
return new ICmpInst(Pred, A, NewAdd);
}
}
Constant *Cst1, *Cst2;
if (match(B, m_ImmConstant(Cst1)) && match(D, m_ImmConstant(Cst2)) &&
ICmpInst::isEquality(Pred)) {
Constant *Diff = ConstantExpr::getSub(Cst2, Cst1);
Value *NewAdd = Builder.CreateAdd(C, Diff);
return new ICmpInst(Pred, A, NewAdd);
}
}
A = nullptr;
B = nullptr;
C = nullptr;
D = nullptr;
if (BO0 && BO0->getOpcode() == Instruction::Sub) {
A = BO0->getOperand(0);
B = BO0->getOperand(1);
}
if (BO1 && BO1->getOpcode() == Instruction::Sub) {
C = BO1->getOperand(0);
D = BO1->getOperand(1);
}
if (A == Op1 && NoOp0WrapProblem)
return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
if (C == Op0 && NoOp1WrapProblem)
return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));
if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
return new ICmpInst(Pred, B, A);
if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
return new ICmpInst(Pred, C, D);
if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
isKnownNonZero(B, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
isKnownNonZero(D, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);
if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, A, C);
if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
return new ICmpInst(Pred, D, B);
if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) {
Value *X;
if (match(BO0, m_Neg(m_Value(X))))
if (Constant *RHSC = dyn_cast<Constant>(Op1))
if (RHSC->isNotMinSignedValue())
return new ICmpInst(I.getSwappedPredicate(), X,
ConstantExpr::getNeg(RHSC));
}
{
Value *X, *Y;
const APInt *C;
if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 &&
match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality())
if (!C->countTrailingZeros() ||
(BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
(BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
return new ICmpInst(Pred, X, Y);
}
BinaryOperator *SRem = nullptr;
if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1))
SRem = BO0;
else if (BO1 && BO1->getOpcode() == Instruction::SRem &&
Op0 == BO1->getOperand(1))
SRem = BO1;
if (SRem) {
switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) {
default:
break;
case ICmpInst::ICMP_EQ:
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
case ICmpInst::ICMP_NE:
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1),
Constant::getAllOnesValue(SRem->getType()));
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1),
Constant::getNullValue(SRem->getType()));
}
}
if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && BO0->hasOneUse() &&
BO1->hasOneUse() && BO0->getOperand(1) == BO1->getOperand(1)) {
switch (BO0->getOpcode()) {
default:
break;
case Instruction::Add:
case Instruction::Sub:
case Instruction::Xor: {
if (I.isEquality()) return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
const APInt *C;
if (match(BO0->getOperand(1), m_APInt(C))) {
if (C->isSignMask()) {
ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
}
if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) {
ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
NewPred = I.getSwappedPredicate(NewPred);
return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
}
}
break;
}
case Instruction::Mul: {
if (!I.isEquality())
break;
const APInt *C;
if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() &&
!C->isOne()) {
if (unsigned TZs = C->countTrailingZeros()) {
Constant *Mask = ConstantInt::get(
BO0->getType(),
APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs));
Value *And1 = Builder.CreateAnd(BO0->getOperand(0), Mask);
Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask);
return new ICmpInst(Pred, And1, And2);
}
}
break;
}
case Instruction::UDiv:
case Instruction::LShr:
if (I.isSigned() || !BO0->isExact() || !BO1->isExact())
break;
return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
case Instruction::SDiv:
if (!I.isEquality() || !BO0->isExact() || !BO1->isExact())
break;
return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
case Instruction::AShr:
if (!BO0->isExact() || !BO1->isExact())
break;
return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
case Instruction::Shl: {
bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap();
bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap();
if (!NUW && !NSW)
break;
if (!NSW && I.isSigned())
break;
return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
}
}
}
if (BO0) {
auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes());
auto BitwiseAnd = m_c_And(m_Value(), LSubOne);
if (match(BO0, BitwiseAnd) && Pred == ICmpInst::ICMP_ULT) {
auto *Zero = Constant::getNullValue(BO0->getType());
return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero);
}
}
if (Value *V = foldMultiplicationOverflowCheck(I))
return replaceInstUsesWith(I, V);
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
return replaceInstUsesWith(I, V);
if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder))
return replaceInstUsesWith(I, V);
if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder))
return replaceInstUsesWith(I, V);
return nullptr;
}
/// Fold an icmp between a min/max (as matched by m_c_SMin / m_c_SMax /
/// m_c_UMin / m_c_UMax) and one of that min/max's own operands into a direct
/// comparison of the two min/max operands, e.g.
///   smin(X, Y) == X   -->  X s<= Y
///   smin(X, Y) s< X   -->  X s>  Y
/// The min/max may appear on either side of the compare; it is moved to the
/// left-hand side (swapping the predicate) before matching. Predicates other
/// than eq/ne and the two matching relational ones yield nullptr.
static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) {
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *MinMax = Cmp.getOperand(0);
  Value *X = Cmp.getOperand(1);

  // True if V is any of the four min/max forms having Operand as an operand.
  auto IsMinMaxOf = [](Value *V, Value *Operand) {
    return match(V, m_c_SMin(m_Specific(Operand), m_Value())) ||
           match(V, m_c_SMax(m_Specific(Operand), m_Value())) ||
           match(V, m_c_UMin(m_Specific(Operand), m_Value())) ||
           match(V, m_c_UMax(m_Specific(Operand), m_Value()));
  };

  // Canonicalize: keep the min/max on the left-hand side.
  if (IsMinMaxOf(X, MinMax)) {
    std::swap(MinMax, X);
    Pred = Cmp.getSwappedPredicate();
  }

  Value *Y = nullptr;
  // EqLike is the relational predicate for which "minmax(X, Y) EqLike X" is
  // equivalent to "minmax(X, Y) == X"; Result is the X-vs-Y predicate that
  // expresses that condition. The ne / inverse(EqLike) forms map to
  // inverse(Result).
  auto FoldWith = [&](ICmpInst::Predicate EqLike,
                      ICmpInst::Predicate Result) -> Instruction * {
    if (Pred == CmpInst::ICMP_EQ || Pred == EqLike)
      return new ICmpInst(Result, X, Y);
    if (Pred == CmpInst::ICMP_NE ||
        Pred == CmpInst::getInversePredicate(EqLike))
      return new ICmpInst(CmpInst::getInversePredicate(Result), X, Y);
    return nullptr;
  };

  if (match(MinMax, m_c_SMin(m_Specific(X), m_Value(Y))))
    return FoldWith(ICmpInst::ICMP_SGE, ICmpInst::ICMP_SLE);
  if (match(MinMax, m_c_SMax(m_Specific(X), m_Value(Y))))
    return FoldWith(ICmpInst::ICMP_SLE, ICmpInst::ICMP_SGE);
  if (match(MinMax, m_c_UMin(m_Specific(X), m_Value(Y))))
    return FoldWith(ICmpInst::ICMP_UGE, ICmpInst::ICMP_ULE);
  if (match(MinMax, m_c_UMax(m_Specific(X), m_Value(Y))))
    return FoldWith(ICmpInst::ICMP_ULE, ICmpInst::ICMP_UGE);
  return nullptr;
}
/// Apply a sequence of pattern-based folds to an equality compare (eq/ne).
/// Patterns are tried in source order. The first fold whose conditions all
/// hold returns its replacement instruction. Several folds first insert new
/// instructions through Builder. Returns nullptr if I is not an equality
/// compare or no fold applies.
Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
const CmpInst::Predicate Pred = I.getPredicate();
Value *A, *B, *C, *D;
if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
// (A ^ B) == A  -->  B == 0   (either xor operand may equal Op1)
if (A == Op1 || B == Op1) { Value *OtherVal = A == Op1 ? B : A;
return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
}
if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) {
// (A ^ C1) == (C ^ C2)  -->  A == (C ^ (C1 ^ C2)). Only done when the
// right-hand xor has a single use, because a new xor is created.
ConstantInt *C1, *C2;
if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) &&
Op1->hasOneUse()) {
Constant *NC = Builder.getInt(C1->getValue() ^ C2->getValue());
Value *Xor = Builder.CreateXor(C, NC);
return new ICmpInst(Pred, A, Xor);
}
// Two xors sharing an operand: drop the common operand,
// e.g. (A ^ B) == (A ^ D)  -->  B == D.
if (A == C)
return new ICmpInst(Pred, B, D);
if (A == D)
return new ICmpInst(Pred, B, C);
if (B == C)
return new ICmpInst(Pred, A, D);
if (B == D)
return new ICmpInst(Pred, A, C);
}
}
// A == (A ^ B)  -->  B == 0   (mirror of the first fold, xor on the RHS)
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) {
Value *OtherVal = A == Op0 ? B : A;
return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
}
// (X & Z) == (Y & Z)  -->  ((X ^ Y) & Z) == 0 for two single-use ands that
// share an operand Z in any position. X/Y/Z stay null if nothing is shared.
if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) &&
match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) {
Value *X = nullptr, *Y = nullptr, *Z = nullptr;
if (A == C) {
X = B;
Y = D;
Z = A;
} else if (A == D) {
X = B;
Y = C;
Z = A;
} else if (B == C) {
X = A;
Y = D;
Z = B;
} else if (B == D) {
X = A;
Y = C;
Z = B;
}
// Op1 is reused here as a scratch variable for the new and/xor value.
if (X) { Op1 = Builder.CreateXor(X, Y);
Op1 = Builder.CreateAnd(Op1, Z);
return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType()));
}
}
// (X | C) == (Y | C)  -->  ((X ^ Y) & ~C) == 0 for two single-use ors with
// the same constant C.
{
Value *X, *Y;
Constant *C;
if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) &&
match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) {
Value *Xor = Builder.CreateXor(X, Y);
Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C));
return new ICmpInst(Pred, And, Constant::getNullValue(And->getType()));
}
}
// (zext A) == (B & Mask)  or  (B & Mask) == (zext A), where Mask + 1 is
// 2^bitwidth(A) (Mask selects exactly A's width), and the zext has a
// single use  -->  A == trunc(B).
ConstantInt *Cst1;
if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) &&
match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) ||
(Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) &&
match(Op1, m_ZExt(m_Value(A))))) {
APInt Pow2 = Cst1->getValue() + 1;
if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) &&
Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth())
return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
}
// (A >> S) == (B >> S), for a pair of single-use lshr or a pair of
// single-use ashr with constant shift amount S, 0 < S < bitwidth:
//   eq  -->  (A ^ B) u<  (1 << S)
//   ne  -->  (A ^ B) u>= (1 << S)
// If both sides are such shifts but the amounts differ, the whole function
// returns nullptr here and skips the remaining folds.
// NOTE(review): `AP1 != AP2` compares APInt pointers, not values. This
// presumably relies on constant uniquing making equal amounts share
// storage; confirm.
const APInt *AP1, *AP2;
if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowUndef(AP1)))) &&
match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowUndef(AP2))))) ||
(match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowUndef(AP1)))) &&
match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowUndef(AP2)))))) {
if (AP1 != AP2)
return nullptr;
unsigned TypeBits = AP1->getBitWidth();
unsigned ShAmt = AP1->getLimitedValue(TypeBits);
if (ShAmt < TypeBits && ShAmt != 0) {
ICmpInst::Predicate NewPred =
Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted");
APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt);
return new ICmpInst(NewPred, Xor, ConstantInt::get(A->getType(), CmpVal));
}
}
// (A << S) == (B << S) for single-use shls by the same ConstantInt S,
// 0 < S < bitwidth  -->  ((A ^ B) & (low bitwidth-S bits)) == 0.
if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
unsigned TypeBits = Cst1->getBitWidth();
unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits);
if (ShAmt < TypeBits && ShAmt != 0) {
Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted");
APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt);
Value *And = Builder.CreateAnd(Xor, Builder.getInt(AndVal),
I.getName() + ".mask");
return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType()));
}
}
// trunc(lshr(A, S)) == C  -->  (A & (TruncMask << S)) == (zext(C) << S),
// where TruncMask covers the truncated width. Only done when A has other
// uses (A->hasOneUse() is false), and only for S < bitwidth(A).
uint64_t ShAmt = 0;
if (Op0->hasOneUse() &&
match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) &&
match(Op1, m_ConstantInt(Cst1)) &&
!A->hasOneUse()) {
unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits();
if (ShAmt < ASize) {
APInt MaskV =
APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits());
MaskV <<= ShAmt;
APInt CmpV = Cst1->getValue().zext(ASize);
CmpV <<= ShAmt;
Value *Mask = Builder.CreateAnd(A, Builder.getInt(MaskV));
return new ICmpInst(Pred, Mask, Builder.getInt(CmpV));
}
}
if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I))
return ICmp;
// Power-of-two-or-zero tests become ctpop compares. First form:
// (A & (A + -1)) == 0 with a single-use and. A is reset to null unless
// this pattern matched.
if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
m_Deferred(A)))) ||
!match(Op1, m_ZeroInt()))
A = nullptr;
// Second form: (A & -A) == A, with the and on either side of the compare
// and in either operand order. When it matches, it overrides A.
if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
A = Op1;
else if (match(Op1,
m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
A = Op0;
// eq --> ctpop(A) u< 2,  ne --> ctpop(A) u> 1
if (A) {
Type *Ty = A->getType();
CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
return Pred == ICmpInst::ICMP_EQ
? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
: new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
}
// (ashr (trunc A), BW-1) == (trunc (lshr A, BW)), where A is 2*BW bits
// wide. This holds iff the top BW+1 bits of A are all equal. It is rewritten
// as (A + 2^(BW-1)) u< 2^BW (eq) or u>= 2^BW (ne). Requires at least one of
// the two compare operands to have a single use.
unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) &&
match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) &&
A->getType()->getScalarSizeInBits() == BitWidth * 2 &&
(I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) {
APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1);
Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C));
return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT
: ICmpInst::ICMP_UGE,
Add, ConstantInt::get(A->getType(), C.shl(1)));
}
return nullptr;
}
/// Canonicalize "icmp Pred (trunc X), C" into a mask-and-compare on the wide
/// value X. The trunc must have a single use and C must be a constant.
///   * anything decomposeBitTestICmp recognizes (looking through the trunc)
///       --> (X & Mask) Pred' 0
///   * (trunc X) u< C, C a negated power of two
///       --> (X & zext(C)) != zext(C)
///   * (trunc X) u> C, ~C a power of two
///       --> (X & zext(C + 1)) == zext(C + 1)
/// Returns nullptr when none of these apply.
static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
                                      InstCombiner::BuilderTy &Builder) {
  Value *TruncOp = ICmp.getOperand(0);
  Value *ConstOp = ICmp.getOperand(1);
  Value *Wide;
  const APInt *C;
  const bool IsTruncVsConst =
      match(TruncOp, m_OneUse(m_Trunc(m_Value(Wide)))) &&
      match(ConstOp, m_APInt(C));
  if (!IsTruncVsConst)
    return nullptr;

  ICmpInst::Predicate Pred = ICmp.getPredicate();
  APInt BitMask;
  // The trailing `true` asks decomposeBitTestICmp to look through the trunc,
  // so BitMask is sized for the wide value.
  if (decomposeBitTestICmp(TruncOp, ConstOp, Pred, Wide, BitMask, true)) {
    Value *Masked = Builder.CreateAnd(Wide, BitMask);
    return new ICmpInst(Pred, Masked,
                        ConstantInt::getNullValue(Wide->getType()));
  }

  Type *WideTy = Wide->getType();
  const unsigned WideBits = WideTy->getScalarSizeInBits();

  ICmpInst::Predicate MaskPred;
  APInt MaskVal;
  if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
    // C selects the high bits: true iff some of those bits of X is clear.
    MaskPred = ICmpInst::ICMP_NE;
    MaskVal = C->zext(WideBits);
  } else if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
    // C + 1 selects the high bits: true iff all of those bits of X are set.
    MaskPred = ICmpInst::ICMP_EQ;
    MaskVal = (*C + 1).zext(WideBits);
  } else {
    return nullptr;
  }

  Constant *MaskC = ConstantInt::get(WideTy, MaskVal);
  Value *Masked = Builder.CreateAnd(Wide, MaskC);
  return new ICmpInst(MaskPred, Masked, MaskC);
}
/// Fold "icmp Pred (zext|sext X), RHS" into a compare on X's narrower type.
/// Operand 0 of ICmp must be a cast (asserted).
///
/// RHS may be:
///   * another zext/sext of Y: compare X with Y directly. A mismatched
///     zext/sext pair is handled only when the zext's source is known
///     non-negative; the pair is then treated as sext/sext. Sources of
///     different widths are first extended to the wider type, which requires
///     at least one of the two extension operands to have a single use.
///   * a constant C that is unchanged by truncating to X's type and
///     re-extending: compare X with trunc(C).
///   * a ConstantInt that does not survive that round trip, with a sext LHS
///     and an unsigned predicate. Only u< and u> are expected here (asserted),
///     and they become a sign test of X.
/// Equality predicates, and signed compares of sign-extended values, keep
/// their predicate. Every other combination uses the unsigned predicate.
/// Returns nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
  assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
  auto *LHSCast = cast<CastInst>(ICmp.getOperand(0));
  Value *X;
  if (!match(LHSCast, m_ZExtOrSExt(m_Value(X))))
    return nullptr;

  const bool IsSignedCmp = ICmp.isSigned();
  bool IsSignedExt = LHSCast->getOpcode() == Instruction::SExt;

  // Build the narrowed compare. IsSignedExt is read at call time, so any
  // update made below is honoured.
  auto CreateNarrowCmp = [&](Value *L, Value *R) -> Instruction * {
    ICmpInst::Predicate NewPred = ICmp.getPredicate();
    if (!ICmp.isEquality() && !(IsSignedCmp && IsSignedExt))
      NewPred = ICmp.getUnsignedPredicate();
    return new ICmpInst(NewPred, L, R);
  };

  // Case 1: icmp Pred (ext X), (ext Y)
  Value *Y;
  if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) {
    const bool LHSIsZExt = isa<ZExtOperator>(ICmp.getOperand(0));
    const bool RHSIsZExt = isa<ZExtOperator>(ICmp.getOperand(1));
    if (LHSIsZExt != RHSIsZExt) {
      // A zext of a non-negative value equals its sext, so a provably
      // non-negative zext source lets both sides be treated as sext.
      Value *ZExtSrc = LHSIsZExt ? X : Y;
      if (!isKnownNonNegative(ZExtSrc, DL, 0, &AC, &ICmp, &DT))
        return nullptr;
      IsSignedExt = true;
    }

    Type *XTy = X->getType();
    Type *YTy = Y->getType();
    if (XTy != YTy) {
      // A new cast is created, so at least one existing extension must have
      // a single use.
      if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse())
        return nullptr;
      const unsigned XBits = XTy->getScalarSizeInBits();
      const unsigned YBits = YTy->getScalarSizeInBits();
      if (XBits == YBits)
        return nullptr;
      Instruction::CastOps ExtOp =
          IsSignedExt ? Instruction::SExt : Instruction::ZExt;
      if (XBits < YBits)
        X = Builder.CreateCast(ExtOp, X, YTy);
      else
        Y = Builder.CreateCast(ExtOp, Y, XTy);
    }
    return CreateNarrowCmp(X, Y);
  }

  // Case 2: icmp Pred (ext X), Constant
  auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
  if (!C)
    return nullptr;

  Type *SrcTy = LHSCast->getSrcTy();
  Type *DestTy = LHSCast->getDestTy();
  Constant *Narrowed = ConstantExpr::getTrunc(C, SrcTy);
  Constant *RoundTrip =
      ConstantExpr::getCast(LHSCast->getOpcode(), Narrowed, DestTy);
  if (RoundTrip == C)
    return CreateNarrowCmp(X, Narrowed);

  // C is not representable in the narrow type. Only the unsigned compare of
  // a sign-extended value against a ConstantInt is handled here.
  if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
    return nullptr;

  // icmp ult (sext X), C --> icmp sgt X, -1
  if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
    return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));

  // icmp ugt (sext X), C --> icmp slt X, 0
  assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
  return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
}
/// Fold an icmp whose operands involve casts. The steps, in order:
///   1. If simplifyIntToPtrRoundTripCast yields a replacement for either
///      operand (presumably collapsing an inttoptr/ptrtoint round trip),
///      rebuild the compare with it.
///   2. For "icmp (ptrtoint P), RHS" where the integer width equals the
///      pointer width (per element for vectors), compare pointers directly.
///      RHS may be a ptrtoint from the same address space (bitcast to P's
///      type if needed) or a constant (converted with inttoptr).
///   3. Otherwise defer to foldICmpWithTrunc, then foldICmpWithZextOrSext.
/// Only applies when operand 0 is a CastInst and operand 1 is a Constant or a
/// CastInst. Returns nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
  Value *LHS = ICmp.getOperand(0);
  Value *RHS = ICmp.getOperand(1);

  Value *NewLHS = simplifyIntToPtrRoundTripCast(LHS);
  Value *NewRHS = simplifyIntToPtrRoundTripCast(RHS);
  if (NewLHS || NewRHS)
    return new ICmpInst(ICmp.getPredicate(), NewLHS ? NewLHS : LHS,
                        NewRHS ? NewRHS : RHS);

  auto *LHSCast = dyn_cast<CastInst>(LHS);
  if (!LHSCast)
    return nullptr;
  if (!isa<Constant>(RHS) && !isa<CastInst>(RHS))
    return nullptr;

  if (LHSCast->getOpcode() == Instruction::PtrToInt) {
    Value *LHSPtr = LHSCast->getOperand(0);
    Type *PtrTy = LHSCast->getSrcTy();
    Type *IntTy = LHSCast->getDestTy();
    // Compare element widths so that vectors of pointers are handled too.
    const bool SameWidth =
        DL.getPointerTypeSizeInBits(PtrTy->getScalarType()) ==
        IntTy->getScalarType()->getIntegerBitWidth();
    if (SameWidth) {
      Value *PtrRHS = nullptr;
      if (auto *RHSPtrToInt = dyn_cast<PtrToIntOperator>(RHS)) {
        Value *RHSPtr = RHSPtrToInt->getOperand(0);
        if (RHSPtr->getType()->getPointerAddressSpace() ==
            LHSPtr->getType()->getPointerAddressSpace())
          PtrRHS = RHSPtr->getType() == LHSPtr->getType()
                       ? RHSPtr
                       : Builder.CreateBitCast(RHSPtr, LHSPtr->getType());
      } else if (auto *RHSConst = dyn_cast<Constant>(RHS)) {
        PtrRHS = ConstantExpr::getIntToPtr(RHSConst, PtrTy);
      }
      if (PtrRHS)
        return new ICmpInst(ICmp.getPredicate(), LHSPtr, PtrRHS);
    }
  }

  if (Instruction *TruncFold = foldICmpWithTrunc(ICmp, Builder))
    return TruncFold;
  return foldICmpWithZextOrSext(ICmp);
}
/// Return true if RHS is the right identity of BinaryOp, meaning
/// "X BinaryOp RHS" is just X: zero for Add and Sub, one for Mul.
/// Other opcodes are unsupported (unreachable).
static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
  if (BinaryOp == Instruction::Add || BinaryOp == Instruction::Sub)
    return match(RHS, m_Zero());
  if (BinaryOp == Instruction::Mul)
    return match(RHS, m_One());
  llvm_unreachable("Unsupported binary op");
}
/// Query whether "LHS BinaryOp RHS" can overflow in the context of CxtI.
/// Dispatches to the matching computeOverflowFor{Signed,Unsigned}{Add,Sub,Mul}
/// helper according to IsSigned. Only Add, Sub and Mul are supported; any
/// other opcode is unreachable.
OverflowResult
InstCombinerImpl::computeOverflow(Instruction::BinaryOps BinaryOp,
                                  bool IsSigned, Value *LHS, Value *RHS,
                                  Instruction *CxtI) const {
  switch (BinaryOp) {
  case Instruction::Add:
    return IsSigned ? computeOverflowForSignedAdd(LHS, RHS, CxtI)
                    : computeOverflowForUnsignedAdd(LHS, RHS, CxtI);
  case Instruction::Sub:
    return IsSigned ? computeOverflowForSignedSub(LHS, RHS, CxtI)
                    : computeOverflowForUnsignedSub(LHS, RHS, CxtI);
  case Instruction::Mul:
    return IsSigned ? computeOverflowForSignedMul(LHS, RHS, CxtI)
                    : computeOverflowForUnsignedMul(LHS, RHS, CxtI);
  default:
    llvm_unreachable("Unsupported binary op");
  }
}
/// Try to resolve "LHS BinaryOp RHS" (signed or unsigned per IsSigned) when
/// its overflow behaviour is known statically.
///
/// On success, returns true and sets:
///   - Result:   LHS itself when RHS is the neutral value. Otherwise a newly
///               built BinaryOp that takes over OrigI's name, marked
///               nsw/nuw when it provably never overflows.
///   - Overflow: a constant i1 (a vector of i1 with LHS's element count for
///               vector operands). It is true when the operation always
///               overflows and false when it never does.
/// Returns false, without touching Result or Overflow, when overflow may or
/// may not happen. For commutative OrigI, a lone constant LHS is moved to the
/// RHS first. The builder's insertion point is set to OrigI in all cases.
bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
                                             bool IsSigned, Value *LHS,
                                             Value *RHS, Instruction &OrigI,
                                             Value *&Result,
                                             Constant *&Overflow) {
  const bool OnlyLHSIsConstant = isa<Constant>(LHS) && !isa<Constant>(RHS);
  if (OrigI.isCommutative() && OnlyLHSIsConstant)
    std::swap(LHS, RHS);

  Builder.SetInsertPoint(&OrigI);

  Type *OverflowTy = Type::getInt1Ty(LHS->getContext());
  if (auto *VecTy = dyn_cast<VectorType>(LHS->getType()))
    OverflowTy = VectorType::get(OverflowTy, VecTy->getElementCount());

  // X op <identity> is X and cannot overflow.
  if (isNeutralValue(BinaryOp, RHS)) {
    Result = LHS;
    Overflow = ConstantInt::getFalse(OverflowTy);
    return true;
  }

  // Materialize the plain arithmetic and give it OrigI's name.
  auto EmitPlainOp = [&]() {
    Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
    Result->takeName(&OrigI);
  };

  switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) {
  case OverflowResult::MayOverflow:
    return false;
  case OverflowResult::AlwaysOverflowsLow:
  case OverflowResult::AlwaysOverflowsHigh:
    EmitPlainOp();
    Overflow = ConstantInt::getTrue(OverflowTy);
    return true;
  case OverflowResult::NeverOverflows:
    EmitPlainOp();
    Overflow = ConstantInt::getFalse(OverflowTy);
    // Record the proven no-wrap fact on the new instruction (the builder may
    // have folded it to a non-instruction).
    if (auto *NewOp = dyn_cast<Instruction>(Result)) {
      if (IsSigned)
        NewOp->setHasNoSignedWrap();
      else
        NewOp->setHasNoUnsignedWrap();
    }
    return true;
  }
  llvm_unreachable("Unexpected overflow result");
}
/// Recognize an icmp that checks whether MulVal = mul(zext A, zext B) exceeds
/// the range of the wider of A's and B's types. Replace it with the overflow
/// bit of llvm.umul.with.overflow computed on that width.
///
/// I is the compare, MulVal is I's mul operand and OtherVal is I's other
/// operand. With MulWidth = max(width(A), width(B)), the recognized forms are:
///   eq/ne   MulVal, (MulVal & (2^MulWidth - 1))
///   ugt/ule MulVal, 2^MulWidth - 1
///   uge/ult MulVal, 2^MulWidth
/// Every other user of MulVal must be a trunc to at most MulWidth bits, or an
/// and with a ConstantInt mask of at most MulWidth active bits. Those users
/// are rewritten to use the narrow product. Returns the replacement for I
/// (the overflow bit, or its negation), or nullptr if the idiom does not
/// apply. Side effects: IR is inserted before the mul, and affected
/// instructions are added to IC's worklist.
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
Value *OtherVal,
InstCombinerImpl &IC) {
// Only scalar integer products are handled (no pointers or vectors).
if (!isa<IntegerType>(MulVal->getType()))
return nullptr;
assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
auto *MulInstr = dyn_cast<Instruction>(MulVal);
if (!MulInstr)
return nullptr;
assert(MulInstr->getOpcode() == Instruction::Mul);
// The caller must guarantee that both mul operands are zexts; cast<> and
// the asserts below enforce this.
auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
*RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
assert(LHS->getOpcode() == Instruction::ZExt);
assert(RHS->getOpcode() == Instruction::ZExt);
Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
// The narrow multiply is performed in the wider of A's and B's types.
Type *TyA = A->getType(), *TyB = B->getType();
unsigned WidthA = TyA->getPrimitiveSizeInBits(),
WidthB = TyB->getPrimitiveSizeInBits();
unsigned MulWidth;
Type *MulType;
if (WidthB > WidthA) {
MulWidth = WidthB;
MulType = TyB;
} else {
MulWidth = WidthA;
MulType = TyA;
}
// Users other than I will be switched to the narrow product, so they may
// only observe its low MulWidth bits. Accept truncs to <= MulWidth bits and
// ands with a ConstantInt mask having <= MulWidth active bits. Anything else,
// including an and with a non-constant operand, blocks the transform.
if (MulVal->hasNUsesOrMore(2))
for (User *U : MulVal->users()) {
if (U == &I)
continue;
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
if (TruncWidth > MulWidth)
return nullptr;
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
if (BO->getOpcode() != Instruction::And)
return nullptr;
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
const APInt &CVal = CI->getValue();
if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
return nullptr;
} else {
return nullptr;
}
} else {
return nullptr;
}
}
// Check that the compare has one of the recognized shapes. A case that
// matches `break`s out of the switch; everything else returns nullptr.
switch (I.getPredicate()) {
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE:
// OtherVal must be (MulVal & Mask), where Mask + 1 == 2^MulWidth.
ConstantInt *CI;
Value *ValToMask;
if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
if (ValToMask != MulVal)
return nullptr;
const APInt &CVal = CI->getValue() + 1;
if (CVal.isPowerOf2()) {
unsigned MaskWidth = CVal.logBase2();
if (MaskWidth == MulWidth)
break; }
}
return nullptr;
case ICmpInst::ICMP_UGT:
// OtherVal must be the constant 2^MulWidth - 1.
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
APInt MaxVal = APInt::getMaxValue(MulWidth);
MaxVal = MaxVal.zext(CI->getBitWidth());
if (MaxVal.eq(CI->getValue()))
break; }
return nullptr;
case ICmpInst::ICMP_UGE:
// OtherVal must be the constant 2^MulWidth.
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
if (MaxVal.eq(CI->getValue()))
break; }
return nullptr;
case ICmpInst::ICMP_ULE:
// OtherVal must be the constant 2^MulWidth - 1.
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
APInt MaxVal = APInt::getMaxValue(MulWidth);
MaxVal = MaxVal.zext(CI->getBitWidth());
if (MaxVal.eq(CI->getValue()))
break; }
return nullptr;
case ICmpInst::ICMP_ULT:
// OtherVal must be the constant 2^MulWidth.
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
if (MaxVal.eq(CI->getValue()))
break; }
return nullptr;
default:
return nullptr;
}
// Build umul.with.overflow(A', B') on MulType just before the original mul.
// The narrower input is zero-extended to MulType first.
InstCombiner::BuilderTy &Builder = IC.Builder;
Builder.SetInsertPoint(MulInstr);
Value *MulA = A, *MulB = B;
if (WidthA < MulWidth)
MulA = Builder.CreateZExt(A, MulType);
if (WidthB < MulWidth)
MulB = Builder.CreateZExt(B, MulType);
Function *F = Intrinsic::getDeclaration(
I.getModule(), Intrinsic::umul_with_overflow, MulType);
CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
IC.addToWorklist(MulInstr);
// Redirect the remaining trunc/and users (validated above) to the narrow
// product (field 0 of the intrinsic result). I and OtherVal are skipped.
if (MulVal->hasNUsesOrMore(2)) {
Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
for (User *U : make_early_inc_range(MulVal->users())) {
if (U == &I || U == OtherVal)
continue;
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
// A trunc to exactly MulWidth is the narrow product itself.
// Narrower truncs now truncate the narrow product instead.
if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
IC.replaceInstUsesWith(*TI, Mul);
else
TI->setOperand(0, Mul);
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
// (mul & Mask) --> zext(narrow_mul & trunc(Mask))
assert(BO->getOpcode() == Instruction::And);
ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
APInt ShortMask = CI->getValue().trunc(MulWidth);
Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
IC.replaceInstUsesWith(*BO, Zext);
} else {
llvm_unreachable("Unexpected Binary operation");
}
IC.addToWorklist(cast<Instruction>(U));
}
}
if (isa<Instruction>(OtherVal))
IC.addToWorklist(cast<Instruction>(OtherVal));
// Decide whether I means "overflowed" (use the overflow bit as-is) or
// "did not overflow" (negate it). The answer depends on the predicate and,
// for the relational forms, on which side of I the mul is.
bool Inverse = false;
switch (I.getPredicate()) {
case ICmpInst::ICMP_NE:
break;
case ICmpInst::ICMP_EQ:
Inverse = true;
break;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (I.getOperand(0) == MulVal)
break;
Inverse = true;
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
if (I.getOperand(1) == MulVal)
break;
Inverse = true;
break;
default:
llvm_unreachable("Unexpected predicate");
}
if (Inverse) {
Value *Res = Builder.CreateExtractValue(Call, 1);
return BinaryOperator::CreateNot(Res);
}
return ExtractValueInst::Create(Call, 1);
}
/// Compute which bits of I's LHS can affect the result of I, assuming the RHS
/// is a constant C (APInt or splat). Without such a C, all bits are demanded.
///   * sign-bit tests (per InstCombiner::isSignBitCheck): only the sign bit.
///   * X u> C: bits below C's run of trailing ones cannot change the outcome.
///   * X u< C: bits below C's run of trailing zeros cannot change the outcome.
///   * anything else: all bits.
static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
  const APInt *RHSC;
  if (!match(I.getOperand(1), m_APInt(RHSC)))
    return APInt::getAllOnes(BitWidth);

  const ICmpInst::Predicate Pred = I.getPredicate();
  bool TrueIfSignedIgnored;
  if (InstCombiner::isSignBitCheck(Pred, *RHSC, TrueIfSignedIgnored))
    return APInt::getSignMask(BitWidth);

  if (Pred == ICmpInst::ICMP_UGT)
    return APInt::getBitsSetFrom(BitWidth, RHSC->countTrailingOnes());
  if (Pred == ICmpInst::ICMP_ULT)
    return APInt::getBitsSetFrom(BitWidth, RHSC->countTrailingZeros());
  return APInt::getAllOnes(BitWidth);
}
/// Decide whether swapping the operands of a compare of Op0 and Op1 is
/// worthwhile. Scans Op0's users and scores +1 for each "sub Op1, Op0" and -1
/// for each "sub Op0, Op1" (the first pattern takes precedence if both could
/// match). Returns true when the score is positive. Pointer-typed Op0 always
/// returns false.
static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) {
  if (Op0->getType()->isPointerTy())
    return false;

  int Score = 0;
  for (const User *U : Op0->users())
    Score += match(U, m_Sub(m_Specific(Op1), m_Specific(Op0)))   ? 1
             : match(U, m_Sub(m_Specific(Op0), m_Specific(Op1))) ? -1
                                                                 : 0;
  return Score > 0;
}
bool InstCombinerImpl::dominatesAllUses(const Instruction *DI,
const Instruction *UI,
const BasicBlock *DB) const {
assert(DI && UI && "Instruction not defined\n");
if (!DI->getParent())
return false;
if (DI->getParent() != UI->getParent())
return false;
if (DI->getParent() == DB)
return false;
for (const User *U : DI->users()) {
auto *Usr = cast<Instruction>(U);
if (Usr != UI && !DT.dominates(DB, Usr->getParent()))
return false;
}
return true;
}
static bool isChainSelectCmpBranch(const SelectInst *SI) {
const BasicBlock *BB = SI->getParent();
if (!BB)
return false;
auto *BI = dyn_cast_or_null<BranchInst>(BB->getTerminator());
if (!BI || BI->getNumSuccessors() != 2)
return false;
auto *IC = dyn_cast<ICmpInst>(BI->getCondition());
if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI))
return false;
return true;
}
/// Try to replace the uses of SI outside its own block with SI's operand
/// SIOpd (1 or 2, as asserted). This happens only if all of the following
/// hold:
///   - SI's block ends in a two-way branch on an icmp of SI
///     (isChainSelectCmpBranch);
///   - Icmp's predicate is eq;
///   - the branch's second successor has a single predecessor;
///   - dominatesAllUses(SI, Icmp, <that successor>) holds.
/// On success, increments NumSel and returns true; otherwise returns false.
bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
                                                 const ICmpInst *Icmp,
                                                 const unsigned SIOpd) {
  assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
  if (!isChainSelectCmpBranch(SI) || Icmp->getPredicate() != ICmpInst::ICMP_EQ)
    return false;

  BasicBlock *SecondSucc = SI->getParent()->getTerminator()->getSuccessor(1);
  if (!SecondSucc->getSinglePredecessor() ||
      !dominatesAllUses(SI, Icmp, SecondSucc))
    return false;

  ++NumSel;
  SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
  return true;
}
Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = Op0->getType();
ICmpInst::Predicate Pred = I.getPredicate();
unsigned BitWidth = Ty->isIntOrIntVectorTy()
? Ty->getScalarSizeInBits()
: DL.getPointerTypeSizeInBits(Ty->getScalarType());
if (!BitWidth)
return nullptr;
KnownBits Op0Known(BitWidth);
KnownBits Op1Known(BitWidth);
if (SimplifyDemandedBits(&I, 0,
getDemandedBitsLHSMask(I, BitWidth),
Op0Known, 0))
return &I;
if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
return &I;
APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
if (I.isSigned()) {
Op0Min = Op0Known.getSignedMinValue();
Op0Max = Op0Known.getSignedMaxValue();
Op1Min = Op1Known.getSignedMinValue();
Op1Max = Op1Known.getSignedMaxValue();
} else {
Op0Min = Op0Known.getMinValue();
Op0Max = Op0Known.getMaxValue();
Op1Min = Op1Known.getMinValue();
Op1Max = Op1Known.getMaxValue();
}
if (!isa<Constant>(Op0) && Op0Min == Op0Max)
return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
auto isMinMaxCmp = [&](Instruction &Cmp) {
if (!Cmp.hasOneUse())
return false;
Value *A, *B;
SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
if (!SelectPatternResult::isMinOrMax(SPF))
return false;
return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
match(Op1, m_MaxOrMin(m_Value(), m_Value()));
};
if (!isMinMaxCmp(I)) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Min + 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
Constant::getNullValue(Op1->getType()));
}
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Max - 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
return new ICmpInst(ICmpInst::ICMP_NE, Op0,
Constant::getNullValue(Op1->getType()));
}
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Min + 1) return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC - 1));
}
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
if (*CmpC == Op0Max - 1) return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
}
break;
}
}
}
switch (Pred) {
default:
llvm_unreachable("Unknown icmp opcode!");
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return replaceInstUsesWith(
I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
APInt Op0KnownZeroInverted = ~Op0Known.Zero;
if (Op1Known.isZero()) {
Value *LHS = nullptr;
const APInt *LHSC;
if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) ||
*LHSC != Op0KnownZeroInverted)
LHS = Op0;
Value *X;
const APInt *C1;
if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) {
Type *XTy = X->getType();
unsigned Log2C1 = C1->countTrailingZeros();
APInt C2 = Op0KnownZeroInverted;
APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1;
if (C2Pow2.isPowerOf2()) {
unsigned Log2C2 = C2Pow2.countTrailingZeros();
auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1);
auto NewPred =
Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT;
return new ICmpInst(NewPred, X, CmpC);
}
}
}
break;
}
case ICmpInst::ICMP_ULT: {
if (Op0Max.ult(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.uge(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op0Min.ugt(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ule(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op0Max.slt(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op0Min.sgt(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
if (Op0Min.sge(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.slt(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_SLE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
if (Op0Max.sle(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sgt(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_UGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
if (Op0Min.uge(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ult(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_ULE:
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
if (Op0Max.ule(Op1Min)) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.ugt(Op1Max)) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
}
if (I.isSigned() &&
((Op0Known.Zero.isNegative() && Op1Known.Zero.isNegative()) ||
(Op0Known.One.isNegative() && Op1Known.One.isNegative())))
return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
return nullptr;
}
/// Fold an unsigned compare of a value against a one-use extended i1:
///   X u<  zext(Y) --> (X == 0) & Y
///   X u<= sext(Y) --> (X == 0) | Y
/// Because m_c_ICmp accepts either operand order (swapping the predicate as
/// needed), Pred always reads as "X Pred ext(Y)".
static Instruction *foldICmpUsingBoolRange(ICmpInst &I,
                                           InstCombiner::BuilderTy &Builder) {
  ICmpInst::Predicate Pred;
  Value *Other, *Bool;

  // zext(i1) is 0 or 1, so 'X u< zext(Y)' holds only when Y is true and X is 0.
  const bool ZExtForm = match(
      &I, m_c_ICmp(Pred, m_Value(Other), m_OneUse(m_ZExt(m_Value(Bool)))));
  if (ZExtForm && Bool->getType()->isIntOrIntVectorTy(1) &&
      Pred == ICmpInst::ICMP_ULT)
    return BinaryOperator::CreateAnd(Builder.CreateIsNull(Other), Bool);

  // sext(i1) is 0 or all-ones, so 'X u<= sext(Y)' holds when Y is true or X is 0.
  const bool SExtForm = match(
      &I, m_c_ICmp(Pred, m_Value(Other), m_OneUse(m_SExt(m_Value(Bool)))));
  if (SExtForm && Bool->getType()->isIntOrIntVectorTy(1) &&
      Pred == ICmpInst::ICMP_ULE)
    return BinaryOperator::CreateOr(Builder.CreateIsNull(Other), Bool);

  return nullptr;
}
/// For a relational integer predicate \p Pred and constant \p C, return the
/// predicate with flipped strictness together with C adjusted by one so the
/// comparison keeps its meaning, e.g. 'X u<= C' -> 'X u< C+1' and
/// 'X s>= C' -> 'X s> C-1'.
///
/// Returns None when C is neither a ConstantInt nor a fixed-width vector of
/// ConstantInt/undef elements, or when some element is already at the extreme
/// value the adjustment would wrap past. Undef/poison vector elements are
/// replaced with the first defined element before the adjustment.
llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
                                                       Constant *C) {
  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
         "Only for relational integer predicates.");

  Type *Ty = C->getType();
  const bool IsSigned = ICmpInst::isSigned(Pred);

  // ule/ugt (and sle/sgt) reach their twin by adding one to the constant;
  // ult/uge (and slt/sge) do so by subtracting one.
  const CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
  const bool WillIncrement =
      UnsignedPred == ICmpInst::ICMP_UGT || UnsignedPred == ICmpInst::ICMP_ULE;

  // An element can be adjusted only if the +1/-1 cannot wrap around.
  auto CanAdjust = [&](ConstantInt *Elt) {
    if (WillIncrement)
      return !Elt->isMaxValue(IsSigned);
    return !Elt->isMinValue(IsSigned);
  };

  Constant *FirstDefinedElt = nullptr;
  if (auto *Scalar = dyn_cast<ConstantInt>(C)) {
    if (!CanAdjust(Scalar))
      return llvm::None;
  } else {
    auto *VecTy = dyn_cast<FixedVectorType>(Ty);
    if (!VecTy)
      return llvm::None; // Anything else (e.g. a constant expression).
    for (unsigned Idx = 0, NumElts = VecTy->getNumElements(); Idx != NumElts;
         ++Idx) {
      Constant *Elt = C->getAggregateElement(Idx);
      if (!Elt)
        return llvm::None;
      if (isa<UndefValue>(Elt))
        continue;
      auto *EltInt = dyn_cast<ConstantInt>(Elt);
      if (!EltInt || !CanAdjust(EltInt))
        return llvm::None;
      if (!FirstDefinedElt)
        FirstDefinedElt = EltInt;
    }
  }

  // Flipping the predicate is not obviously safe with undefined lanes, so
  // pin them to a known-safe value first.
  if (C->containsUndefOrPoisonElement()) {
    assert(FirstDefinedElt && "Replacement constant not set");
    C = Constant::replaceUndefsWith(C, FirstDefinedElt);
  }

  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
  Constant *Delta =
      ConstantInt::get(Ty, WillIncrement ? 1 : -1, /*isSigned=*/true);
  return std::make_pair(NewPred, ConstantExpr::getAdd(C, Delta));
}
/// If \p I is a non-equality integer compare against a constant and its
/// predicate is not canonical, rewrite it to the flipped-strictness predicate
/// with the constant adjusted by one (see
/// getFlippedStrictnessPredicateAndConstant). Returns the new compare, or null
/// if no rewrite applies.
static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
  const ICmpInst::Predicate Pred = I.getPredicate();
  const bool Eligible = !ICmpInst::isEquality(Pred) &&
                        ICmpInst::isIntPredicate(Pred) &&
                        !InstCombiner::isCanonicalPredicate(Pred);
  if (!Eligible)
    return nullptr;

  auto *RHSConst = dyn_cast<Constant>(I.getOperand(1));
  if (!RHSConst)
    return nullptr;

  if (auto Flipped =
          InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred,
                                                                 RHSConst))
    return new ICmpInst(Flipped->first, I.getOperand(0), Flipped->second);
  return nullptr;
}
/// Invert a non-canonical compare predicate in place when every user of \p I
/// can absorb the inversion for free. The result is renamed with a ".not"
/// suffix and its users are adjusted. Returns &I on success, null otherwise.
CmpInst *InstCombinerImpl::canonicalizeICmpPredicate(CmpInst &I) {
  const CmpInst::Predicate Pred = I.getPredicate();

  // Skip when already canonical, or when some user cannot take the inversion.
  if (InstCombiner::isCanonicalPredicate(Pred) ||
      !InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
    return nullptr;

  I.setPredicate(CmpInst::getInversePredicate(Pred));
  I.setName(I.getName() + ".not");
  // Compensate for the inversion in every user.
  freelyInvertAllUsersOf(&I);
  return &I;
}
/// Lower a compare of two i1 (or i1-vector) values into bitwise logic.
/// In signed i1 arithmetic 'true' is -1, which explains the signed cases.
static Instruction *canonicalizeICmpBool(ICmpInst &I,
                                         InstCombiner::BuilderTy &Builder) {
  Value *A = I.getOperand(0);
  Value *B = I.getOperand(1);
  assert(A->getType()->isIntOrIntVectorTy(1) && "Bools only");
  const ICmpInst::Predicate Pred = I.getPredicate();

  // With a constant RHS, every predicate except the ones equivalent to '!A'
  // is expected to have been folded already by the simplifier.
  if (match(B, m_Zero())) {
    // A == 0, A u<= 0, A s>= 0  -->  !A
    if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_ULE ||
        Pred == ICmpInst::ICMP_SGE)
      return BinaryOperator::CreateNot(A);
    llvm_unreachable("ICmp i1 X, C not simplified as expected.");
  }
  if (match(B, m_One())) {
    // A != 1, A u< 1, A s> -1  -->  !A
    if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT ||
        Pred == ICmpInst::ICMP_SGT)
      return BinaryOperator::CreateNot(A);
    llvm_unreachable("ICmp i1 X, C not simplified as expected.");
  }

  switch (Pred) {
  default:
    llvm_unreachable("Invalid icmp instruction!");
  case ICmpInst::ICMP_EQ: // A == B  -->  ~(A ^ B)
    return BinaryOperator::CreateNot(Builder.CreateXor(A, B));
  case ICmpInst::ICMP_NE: // A != B  -->  A ^ B
    return BinaryOperator::CreateXor(A, B);
  case ICmpInst::ICMP_ULT: // A u< B  -->  ~A & B
    return BinaryOperator::CreateAnd(Builder.CreateNot(A), B);
  case ICmpInst::ICMP_UGT: // A u> B  -->  ~B & A
    return BinaryOperator::CreateAnd(Builder.CreateNot(B), A);
  case ICmpInst::ICMP_SLT: // A s< B  -->  ~B & A
    return BinaryOperator::CreateAnd(Builder.CreateNot(B), A);
  case ICmpInst::ICMP_SGT: // A s> B  -->  ~A & B
    return BinaryOperator::CreateAnd(Builder.CreateNot(A), B);
  case ICmpInst::ICMP_ULE: // A u<= B -->  ~A | B
    return BinaryOperator::CreateOr(Builder.CreateNot(A), B);
  case ICmpInst::ICMP_UGE: // A u>= B -->  ~B | A
    return BinaryOperator::CreateOr(Builder.CreateNot(B), A);
  case ICmpInst::ICMP_SLE: // A s<= B -->  ~B | A
    return BinaryOperator::CreateOr(Builder.CreateNot(B), A);
  case ICmpInst::ICMP_SGE: // A s>= B -->  ~A | B
    return BinaryOperator::CreateOr(Builder.CreateNot(A), B);
  }
}
/// Fold a compare of X against a one-use mask built from a shift amount Y into
/// a test of the bits of X at positions >= Y:
///   (1 << Y)       u<= X  -->  (X u>> Y) != 0
///   (1 << Y)       u>  X  -->  (X u>> Y) == 0
///   ((1 << Y) - 1) u<  X  -->  (X u>> Y) != 0
///   ((1 << Y) - 1) u>= X  -->  (X u>> Y) == 0
/// The low mask is accepted as '~(-1 << Y)' or '(1 << Y) + -1'. Either operand
/// order is accepted; Pred always reads as "mask Pred X".
static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
                                            InstCombiner::BuilderTy &Builder) {
  ICmpInst::Predicate Pred;
  ICmpInst::Predicate NewPred;
  Value *X, *Y;

  const bool IsPow2Form = match(
      &Cmp, m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)));
  if (IsPow2Form) {
    if (Pred == ICmpInst::ICMP_ULE)
      NewPred = ICmpInst::ICMP_NE;
    else if (Pred == ICmpInst::ICMP_UGT)
      NewPred = ICmpInst::ICMP_EQ;
    else
      return nullptr;
  } else {
    const bool IsLowMaskForm = match(
        &Cmp, m_c_ICmp(Pred,
                       m_OneUse(m_CombineOr(
                           m_Not(m_Shl(m_AllOnes(), m_Value(Y))),
                           m_Add(m_Shl(m_One(), m_Value(Y)), m_AllOnes()))),
                       m_Value(X)));
    if (!IsLowMaskForm)
      return nullptr;
    if (Pred == ICmpInst::ICMP_ULT)
      NewPred = ICmpInst::ICMP_NE;
    else if (Pred == ICmpInst::ICMP_UGE)
      NewPred = ICmpInst::ICMP_EQ;
    else
      return nullptr;
  }

  Value *HighBits = Builder.CreateLShr(X, Y, X->getName() + ".highbits");
  Constant *Zero = Constant::getNullValue(HighBits->getType());
  return CmpInst::Create(Instruction::ICmp, NewPred, HighBits, Zero);
}
/// Move a single-source shuffle on the LHS of a vector compare after the
/// compare:
///   cmp (shuffle V1, M), (shuffle V2, M)  -->  shuffle (cmp V1, V2), M
///   cmp (splat-shuffle V1), splat(C)      -->  splat-shuffle (cmp V1, splat(C))
/// In the second form the shuffle may change the vector length, so the
/// constant is rebuilt at V1's element count and the mask becomes a fully
/// defined splat mask.
static Instruction *foldVectorCmp(CmpInst &Cmp,
                                  InstCombiner::BuilderTy &Builder) {
  const CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *LHS = Cmp.getOperand(0);
  Value *RHS = Cmp.getOperand(1);

  Value *V1;
  ArrayRef<int> Mask;
  if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))))
    return nullptr;
  Type *SrcTy = V1->getType();

  // Both sides shuffled with the same mask from same-typed sources.
  Value *V2;
  const bool SameShuffle =
      match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) &&
      SrcTy == V2->getType();
  if (SameShuffle && (LHS->hasOneUse() || RHS->hasOneUse())) {
    Value *NewCmp = Builder.CreateCmp(Pred, V1, V2);
    return new ShuffleVectorInst(NewCmp, Mask);
  }

  Constant *C;
  if (!LHS->hasOneUse() || !match(RHS, m_Constant(C)))
    return nullptr;

  Constant *ScalarC = C->getSplatValue(/*AllowUndefs=*/true);
  int SplatIdx;
  if (!ScalarC || !match(Mask, m_SplatOrUndefMask(SplatIdx)))
    return nullptr;

  Constant *SrcC = ConstantVector::getSplat(
      cast<VectorType>(SrcTy)->getElementCount(), ScalarC);
  SmallVector<int, 8> SplatMask(Mask.size(), SplatIdx);
  Value *NewCmp = Builder.CreateCmp(Pred, V1, SrcC);
  return new ShuffleVectorInst(NewCmp, SplatMask);
}
/// Replace a compare of the sum produced by uadd.with.overflow(A, B) that is
/// equivalent to the intrinsic's overflow bit with that bit:
///   sum u< A (or B)            --> overflow
///   sum == 0   when A or B is 1  --> overflow
///   sum != -1  when A or B is -1 --> overflow
///   A (or B) u> sum            --> overflow
static Instruction *foldICmpOfUAddOv(ICmpInst &I) {
  const CmpInst::Predicate Pred = I.getPredicate();
  Value *Op0 = I.getOperand(0);
  Value *Op1 = I.getOperand(1);

  Value *A, *B;
  auto SumOf = m_ExtractValue<0>(
      m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B)));

  Value *Aggregate = nullptr;
  if (match(Op0, SumOf)) {
    const bool SumBelowOperand =
        Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B);
    const bool IncrementWrapsToZero =
        Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) &&
        (match(A, m_One()) || match(B, m_One()));
    const bool DecrementNotAllOnes =
        Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) &&
        (match(A, m_AllOnes()) || match(B, m_AllOnes()));
    if (SumBelowOperand || IncrementWrapsToZero || DecrementNotAllOnes)
      Aggregate = cast<ExtractValueInst>(Op0)->getAggregateOperand();
  }

  if (!Aggregate && match(Op1, SumOf) && Pred == ICmpInst::ICMP_UGT &&
      (Op0 == A || Op0 == B))
    Aggregate = cast<ExtractValueInst>(Op1)->getAggregateOperand();

  if (!Aggregate)
    return nullptr;
  return ExtractValueInst::Create(Aggregate, 1);
}
/// icmp pred (launder/strip.invariant.group P), null --> icmp pred P, null
/// Applied only to pointer compares in an address space where
/// NullPointerIsDefined is false for the enclosing function.
static Instruction *foldICmpInvariantGroup(ICmpInst &I) {
  Value *LHS = I.getOperand(0);
  Type *LHSTy = LHS->getType();
  if (!LHSTy->isPointerTy())
    return nullptr;
  if (NullPointerIsDefined(I.getParent()->getParent(),
                           LHSTy->getPointerAddressSpace()))
    return nullptr;

  Instruction *Barrier;
  const bool Matches = match(LHS, m_Instruction(Barrier)) &&
                       match(I.getOperand(1), m_Zero()) &&
                       Barrier->isLaunderOrStripInvariantGroup();
  if (!Matches)
    return nullptr;

  return ICmpInst::Create(Instruction::ICmp, I.getPredicate(),
                          Barrier->getOperand(0), I.getOperand(1));
}
/// Replace a vector-compare "reduction" with one wide scalar compare:
///   %ne = icmp ne <N x iK> %lhs, %rhs
///   %m  = bitcast <N x i1> %ne to iN
///   %r  = icmp eq/ne iN %m, 0
/// -->
///   %r  = icmp eq/ne i(N*K) (bitcast %lhs), (bitcast %rhs)
/// Applied only when N*K bits is a legal integer width for \p DL and both the
/// inner compare and the bitcast have a single use.
static Instruction *foldReductionIdiom(ICmpInst &I,
                                       InstCombiner::BuilderTy &Builder,
                                       const DataLayout &DL) {
  if (I.getType()->isVectorTy())
    return nullptr;

  ICmpInst::Predicate OuterPred, InnerPred;
  Value *LHS, *RHS;
  const bool Matched =
      match(&I, m_ICmp(OuterPred,
                       m_OneUse(m_BitCast(m_OneUse(
                           m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
                       m_Zero()));
  if (!Matched)
    return nullptr;

  auto *VecTy = dyn_cast<FixedVectorType>(LHS->getType());
  if (!VecTy || !VecTy->getElementType()->isIntegerTy())
    return nullptr;

  const unsigned EltBits = VecTy->getElementType()->getIntegerBitWidth();
  const unsigned TotalBits = VecTy->getNumElements() * EltBits;
  if (!DL.isLegalInteger(TotalBits))
    return nullptr;

  // Only 'eq/ne (bitcast (icmp ne L, R)), 0' is handled.
  if (!ICmpInst::isEquality(OuterPred) || InnerPred != ICmpInst::ICMP_NE)
    return nullptr;

  IntegerType *WideTy = Builder.getIntNTy(TotalBits);
  Value *LHSInt =
      Builder.CreateBitCast(LHS, WideTy, LHS->getName() + ".scalar");
  Value *RHSInt =
      Builder.CreateBitCast(RHS, WideTy, RHS->getName() + ".scalar");
  return ICmpInst::Create(Instruction::ICmp, OuterPred, LHSInt, RHSInt,
                          I.getName());
}
/// Visit an integer or pointer 'icmp' and try to simplify or canonicalize it.
///
/// The operands are first ordered by complexity, then simplifyICmpInst is
/// consulted, and then a fixed sequence of folds is tried. The first fold that
/// produces a result wins, so the order of the calls below matters. Returns
/// the replacement instruction, &I when the only change was swapping the
/// operands, or null when nothing changed.
Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
  bool Changed = false;
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  unsigned Op0Cplxity = getComplexity(Op0);
  unsigned Op1Cplxity = getComplexity(Op1);

  // Put the more complex operand first (e.g. constants end up as Op1).
  if (Op0Cplxity < Op1Cplxity ||
      (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) {
    I.swapOperands();
    std::swap(Op0, Op1);
    Changed = true;
  }

  if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q))
    return replaceInstUsesWith(I, V);

  // Negation does not change whether a value is zero:
  //   select(Cond, -V, V) != 0 --> V != 0
  //   select(Cond, V, -V) != 0 --> V != 0
  if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) {
    Value *Cond, *SelectTrue, *SelectFalse;
    if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue),
                            m_Value(SelectFalse)))) {
      if (Value *V = dyn_castNegVal(SelectTrue)) {
        if (V == SelectFalse)
          return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
      } else if (Value *V = dyn_castNegVal(SelectFalse)) {
        if (V == SelectTrue)
          return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
      }
    }
  }

  if (Op0->getType()->isIntOrIntVectorTy(1))
    if (Instruction *Res = canonicalizeICmpBool(I, Builder))
      return Res;

  if (Instruction *Res = canonicalizeCmpWithConstant(I))
    return Res;

  if (Instruction *Res = canonicalizeICmpPredicate(I))
    return Res;

  if (Instruction *Res = foldICmpWithConstant(I))
    return Res;

  if (Instruction *Res = foldICmpWithDominatingICmp(I))
    return Res;

  if (Instruction *Res = foldICmpUsingBoolRange(I, Builder))
    return Res;

  if (Instruction *Res = foldICmpUsingKnownBits(I))
    return Res;

  // If the only user is a select that matchSelectPattern recognizes (such as
  // a min/max idiom), stop here and keep the compare in its current form.
  if (I.hasOneUse())
    if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
      Value *A, *B;
      SelectPatternResult SPR = matchSelectPattern(SI, A, B);
      if (SPR.Flavor != SPF_UNKNOWN)
        return nullptr;
    }

  if (Instruction *Res = foldICmpWithZero(I))
    return Res;

  // Sign-bit tests written as unsigned compares:
  //   X u> SignedMax --> X s< 0
  //   X u< SignedMin --> X s> -1
  ICmpInst::Predicate Pred = I.getPredicate();
  const APInt *C;
  if (match(Op1, m_APInt(C))) {
    if (Pred == ICmpInst::ICMP_UGT && C->isMaxSignedValue()) {
      Constant *Zero = Constant::getNullValue(Op0->getType());
      return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Zero);
    }
    if (Pred == ICmpInst::ICMP_ULT && C->isMinSignedValue()) {
      Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
      return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
    }
  }

  if (Instruction *Res = foldICmpBinOp(I, Q))
    return Res;

  if (Instruction *Res = foldICmpInstWithConstant(I))
    return Res;

  if (Instruction *New = foldSignBitTest(I))
    return New;

  if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
    return Res;

  // icmp GEP, P  and  icmp P, GEP
  if (auto *GEP = dyn_cast<GEPOperator>(Op0))
    if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I))
      return NI;
  if (auto *GEP = dyn_cast<GEPOperator>(Op1))
    if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I))
      return NI;

  // icmp select, V  and  icmp V, select
  if (auto *SI = dyn_cast<SelectInst>(Op0))
    if (Instruction *NI = foldSelectICmp(I.getPredicate(), SI, Op1, I))
      return NI;
  if (auto *SI = dyn_cast<SelectInst>(Op1))
    if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I))
      return NI;

  // Equality compares of pointers whose underlying object is an alloca.
  if (Op0->getType()->isPointerTy() && I.isEquality()) {
    assert(Op1->getType()->isPointerTy() &&
           "Comparing pointer with non-pointer?");
    if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
      if (Instruction *New = foldAllocaCmp(I, Alloca))
        return New;
    if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
      if (Instruction *New = foldAllocaCmp(I, Alloca))
        return New;
  }

  if (Instruction *Res = foldICmpBitCast(I))
    return Res;

  if (Instruction *R = foldICmpWithCastOp(I))
    return R;

  if (Instruction *Res = foldICmpWithMinMax(I))
    return Res;

  {
    Value *A, *B;
    // With A a power of two:
    //   (A & ~B) == 0 --> (A & B) != 0
    //   (A & ~B) != 0 --> (A & B) == 0
    // The trivial isEquality() test runs first so the recursive
    // isKnownToBeAPowerOfTwo query is only made for equality compares.
    if (I.isEquality() && match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false, 0, &I))
      return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B),
                          Op1);

    // ~X pred ~Y --> Y pred X
    // ~X pred C  --> X swapped(pred) ~C
    if (match(Op0, m_Not(m_Value(A)))) {
      if (match(Op1, m_Not(m_Value(B))))
        return new ICmpInst(I.getPredicate(), B, A);

      const APInt *C;
      if (match(Op1, m_APInt(C)))
        return new ICmpInst(I.getSwappedPredicate(), A,
                            ConstantInt::get(Op1->getType(), ~(*C)));
    }

    // Unsigned add-overflow check: when OptimizeOverflowCheck can resolve it,
    // replace both the add and this compare. The matched AddI is used only if
    // it is an actual 'add' instruction.
    Instruction *AddI = nullptr;
    if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B),
                                     m_Instruction(AddI))) &&
        isa<IntegerType>(A->getType())) {
      Value *Result;
      Constant *Overflow;
      if (AddI->getOpcode() == Instruction::Add &&
          OptimizeOverflowCheck(Instruction::Add, /*IsSigned=*/false, A, B,
                                *AddI, Result, Overflow)) {
        replaceInstUsesWith(*AddI, Result);
        eraseInstFromFunction(*AddI);
        return replaceInstUsesWith(I, Overflow);
      }
    }

    // Compare involving (zext a) * (zext b): handled by processUMulZExtIdiom.
    if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
      if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
        return R;
    }
    if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
      if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this))
        return R;
    }
  }

  if (Instruction *Res = foldICmpEquality(I))
    return Res;

  if (Instruction *Res = foldICmpOfUAddOv(I))
    return Res;

  // icmp eq (extractvalue (cmpxchg P, Cmp, New), 0), Cmp
  //   --> extractvalue (cmpxchg P, Cmp, New), 1
  // This applies only to a strong cmpxchg. A weak one may fail even when the
  // loaded value equals the expected value.
  if (I.getPredicate() == ICmpInst::ICMP_EQ)
    if (auto *EVI = dyn_cast<ExtractValueInst>(Op0))
      if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand()))
        if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 &&
            !ACXI->isWeak())
          return ExtractValueInst::Create(ACXI, 1);

  {
    Value *X;
    const APInt *C;
    // icmp (X + C), X
    if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X)
      return foldICmpAddOpConst(X, *C, I.getPredicate());
    // icmp X, (X + C)
    if (match(Op1, m_Add(m_Value(X), m_APInt(C))) && Op0 == X)
      return foldICmpAddOpConst(X, *C, I.getSwappedPredicate());
  }

  if (Instruction *Res = foldICmpWithHighBitMask(I, Builder))
    return Res;

  if (I.getType()->isVectorTy())
    if (Instruction *Res = foldVectorCmp(I, Builder))
      return Res;

  if (Instruction *Res = foldICmpInvariantGroup(I))
    return Res;

  if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
    return Res;

  return Changed ? &I : nullptr;
}
/// Try to fold 'fcmp pred (sitofp/uitofp X), C', where C is a floating-point
/// constant, into either a constant true/false or an integer compare of X
/// against C converted to X's integer type.
///
/// \param I    the fcmp being visited; its operand 0 is \p LHSI.
/// \param LHSI the SIToFP/UIToFP instruction feeding the compare.
/// \param RHSC the constant right-hand operand.
/// \returns the replacement, or null if RHSC is not a ConstantFP, the FP
/// type's mantissa width is reported as -1, or rounding in the int-to-FP
/// conversion could change the result of the comparison.
Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
Instruction *LHSI,
Constant *RHSC) {
if (!isa<ConstantFP>(RHSC)) return nullptr;
const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF();
// A mantissa width of -1 gives no usable precision bound, so give up.
int MantissaWidth = LHSI->getType()->getFPMantissaWidth();
if (MantissaWidth == -1) return nullptr;
IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType());
bool LHSUnsigned = isa<UIToFPInst>(LHSI);
if (I.isEquality()) {
FCmpInst::Predicate P = I.getPredicate();
bool IsExact = false;
APSInt RHSCvt(IntTy->getBitWidth(), LHSUnsigned);
RHS.convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact);
// If RHS has a fractional part, an integer converted to FP can never equal
// it, so (in)equality folds to a constant.
if (!IsExact) {
APFloat RHSRoundInt(RHS);
RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven);
if (RHS != RHSRoundInt) {
if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ)
return replaceInstUsesWith(I, Builder.getFalse());
assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE);
return replaceInstUsesWith(I, Builder.getTrue());
}
}
}
// If X is wider than the FP mantissa, the conversion may round. Continue
// only when RHS's exponent shows that the rounding cannot affect the result.
unsigned InputSize = IntTy->getScalarSizeInBits();
if ((int)InputSize > MantissaWidth) {
int Exp = ilogb(RHS);
if (Exp == APFloat::IEK_Inf) {
// RHS is infinite: bail out if a large integer could itself convert to
// infinity (the FP type's largest exponent is below the integer's range).
int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics()));
if (MaxExponent < (int)InputSize - !LHSUnsigned)
return nullptr;
} else {
// Bail out if |RHS| lies in the magnitude range where converted integers
// are no longer exact.
if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned)
return nullptr;
}
}
assert(!RHS.isNaN() && "NaN comparison not already folded!");
// Map the FP predicate to the integer one. Ordered and unordered forms map
// identically: RHS is not NaN (asserted above) and an int-to-FP result is
// never NaN, which is also why ord/uno fold to constants.
ICmpInst::Predicate Pred;
switch (I.getPredicate()) {
default: llvm_unreachable("Unexpected predicate!");
case FCmpInst::FCMP_UEQ:
case FCmpInst::FCMP_OEQ:
Pred = ICmpInst::ICMP_EQ;
break;
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_OGT:
Pred = LHSUnsigned ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_SGT;
break;
case FCmpInst::FCMP_UGE:
case FCmpInst::FCMP_OGE:
Pred = LHSUnsigned ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_SGE;
break;
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_OLT:
Pred = LHSUnsigned ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_SLT;
break;
case FCmpInst::FCMP_ULE:
case FCmpInst::FCMP_OLE:
Pred = LHSUnsigned ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
break;
case FCmpInst::FCMP_UNE:
case FCmpInst::FCMP_ONE:
Pred = ICmpInst::ICMP_NE;
break;
case FCmpInst::FCMP_ORD:
return replaceInstUsesWith(I, Builder.getTrue());
case FCmpInst::FCMP_UNO:
return replaceInstUsesWith(I, Builder.getFalse());
}
// If RHS is above the largest value of X's integer type (this includes
// +inf), the result depends only on the predicate.
unsigned IntWidth = IntTy->getScalarSizeInBits();
if (!LHSUnsigned) {
APFloat SMax(RHS.getSemantics());
SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
APFloat::rmNearestTiesToEven);
if (SMax < RHS) { if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT ||
Pred == ICmpInst::ICMP_SLE)
return replaceInstUsesWith(I, Builder.getTrue());
return replaceInstUsesWith(I, Builder.getFalse());
}
} else {
APFloat UMax(RHS.getSemantics());
UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
APFloat::rmNearestTiesToEven);
if (UMax < RHS) { if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT ||
Pred == ICmpInst::ICMP_ULE)
return replaceInstUsesWith(I, Builder.getTrue());
return replaceInstUsesWith(I, Builder.getFalse());
}
}
// Likewise if RHS is below the smallest value of X's integer type.
if (!LHSUnsigned) {
APFloat SMin(RHS.getSemantics());
SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true,
APFloat::rmNearestTiesToEven);
if (SMin > RHS) { if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT ||
Pred == ICmpInst::ICMP_SGE)
return replaceInstUsesWith(I, Builder.getTrue());
return replaceInstUsesWith(I, Builder.getFalse());
}
} else {
APFloat UMin(RHS.getSemantics());
UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false,
APFloat::rmNearestTiesToEven);
if (UMin > RHS) { if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
Pred == ICmpInst::ICMP_UGE)
return replaceInstUsesWith(I, Builder.getTrue());
return replaceInstUsesWith(I, Builder.getFalse());
}
}
// RHS is now within the integer range. Convert it with fptosi/fptoui, which
// truncate toward zero, and check whether it round-trips. If it does not,
// RHS was fractional and the predicate is adjusted. Zero is skipped because
// -0.0 is not fractional.
Constant *RHSInt = LHSUnsigned
? ConstantExpr::getFPToUI(RHSC, IntTy)
: ConstantExpr::getFPToSI(RHSC, IntTy);
if (!RHS.isZero()) {
bool Equal = LHSUnsigned
? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC
: ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC;
if (!Equal) {
// RHSInt lies just below a positive fractional RHS and just above a
// negative one, e.g. (float)x s< 4.4 --> x s<= 4 and
// (float)x s<= -4.4 --> x s< -4.
switch (Pred) {
default: llvm_unreachable("Unexpected integer comparison!");
case ICmpInst::ICMP_NE: return replaceInstUsesWith(I, Builder.getTrue());
case ICmpInst::ICMP_EQ: return replaceInstUsesWith(I, Builder.getFalse());
case ICmpInst::ICMP_ULE:
if (RHS.isNegative())
return replaceInstUsesWith(I, Builder.getFalse());
break;
case ICmpInst::ICMP_SLE:
if (RHS.isNegative())
Pred = ICmpInst::ICMP_SLT;
break;
case ICmpInst::ICMP_ULT:
if (RHS.isNegative())
return replaceInstUsesWith(I, Builder.getFalse());
Pred = ICmpInst::ICMP_ULE;
break;
case ICmpInst::ICMP_SLT:
if (!RHS.isNegative())
Pred = ICmpInst::ICMP_SLE;
break;
case ICmpInst::ICMP_UGT:
if (RHS.isNegative())
return replaceInstUsesWith(I, Builder.getTrue());
break;
case ICmpInst::ICMP_SGT:
if (RHS.isNegative())
Pred = ICmpInst::ICMP_SGE;
break;
case ICmpInst::ICMP_UGE:
if (RHS.isNegative())
return replaceInstUsesWith(I, Builder.getTrue());
Pred = ICmpInst::ICMP_UGT;
break;
case ICmpInst::ICMP_SGE:
if (!RHS.isNegative())
Pred = ICmpInst::ICMP_SGT;
break;
}
}
}
// Emit the equivalent integer comparison on the original integer operand.
return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt);
}
/// For an ordered relational compare of 'C / X' against +/-0.0, where C is a
/// non-zero FP constant and both the division and the compare carry 'ninf',
/// the sign of the quotient is determined by X:
///   (C / X) pred 0.0 --> X pred 0.0           if C is positive
///   (C / X) pred 0.0 --> X swapped(pred) 0.0  if C is negative
static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
                                              Constant *RHSC) {
  FCmpInst::Predicate Pred = I.getPredicate();
  switch (Pred) {
  case FCmpInst::FCMP_OGT:
  case FCmpInst::FCMP_OLT:
  case FCmpInst::FCMP_OGE:
  case FCmpInst::FCMP_OLE:
    break;
  default:
    return nullptr;
  }

  if (!match(RHSC, m_AnyZeroFP()) || !LHSI->hasNoInfs() || !I.hasNoInfs())
    return nullptr;

  const APFloat *Dividend;
  if (!match(LHSI->getOperand(0), m_APFloat(Dividend)) || Dividend->isZero())
    return nullptr;

  if (Dividend->isNegative())
    Pred = I.getSwappedPredicate();
  return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I);
}
/// Fold 'fcmp pred fabs(X), +0.0' by looking through the fabs, adjusting the
/// predicate where the sign matters. The compare is updated in place; returns
/// the result of replacing operand 0 with X, or null if nothing applies.
static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
  Value *X;
  if (!match(I.getOperand(0), m_FAbs(m_Value(X))) ||
      !match(I.getOperand(1), m_PosZeroFP()))
    return nullptr;

  FCmpInst::Predicate NewPred;
  switch (I.getPredicate()) {
  case FCmpInst::FCMP_UGE: // fabs(X) u>= 0.0 is always true
  case FCmpInst::FCMP_OLT: // fabs(X) <   0.0 is always false
    llvm_unreachable("fcmp should have simplified");
  case FCmpInst::FCMP_OGT: // fabs(X) > 0.0 --> X != 0.0
    NewPred = FCmpInst::FCMP_ONE;
    break;
  case FCmpInst::FCMP_UGT: // fabs(X) u> 0.0 --> X u!= 0.0
    NewPred = FCmpInst::FCMP_UNE;
    break;
  case FCmpInst::FCMP_OLE: // fabs(X) <= 0.0 --> X == 0.0
    NewPred = FCmpInst::FCMP_OEQ;
    break;
  case FCmpInst::FCMP_ULE: // fabs(X) u<= 0.0 --> X u== 0.0
    NewPred = FCmpInst::FCMP_UEQ;
    break;
  case FCmpInst::FCMP_OGE: // fabs(X) >= 0.0 --> !isnan(X)
    assert(!I.hasNoNaNs() && "fcmp should have simplified");
    NewPred = FCmpInst::FCMP_ORD;
    break;
  case FCmpInst::FCMP_ULT: // fabs(X) u< 0.0 --> isnan(X)
    assert(!I.hasNoNaNs() && "fcmp should have simplified");
    NewPred = FCmpInst::FCMP_UNO;
    break;
  case FCmpInst::FCMP_OEQ:
  case FCmpInst::FCMP_UEQ:
  case FCmpInst::FCMP_ONE:
  case FCmpInst::FCMP_UNE:
  case FCmpInst::FCMP_ORD:
  case FCmpInst::FCMP_UNO:
    // Sign-insensitive against zero: keep the predicate, drop the fabs.
    NewPred = I.getPredicate();
    break;
  default:
    return nullptr;
  }

  I.setPredicate(NewPred);
  return IC.replaceOperand(I, 0, X);
}
/// fcmp pred X, -X --> fcmp pred X, 0.0  (either operand order is accepted)
static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
  Value *X = I.getOperand(0);
  Value *NegCandidate = I.getOperand(1);
  CmpInst::Predicate Pred = I.getPredicate();

  // Put the fneg on the right, unless both operands are fnegs.
  const bool LHSIsNeg = match(X, m_FNeg(m_Value()));
  const bool RHSIsNeg = match(NegCandidate, m_FNeg(m_Value()));
  if (LHSIsNeg && !RHSIsNeg) {
    std::swap(X, NegCandidate);
    Pred = I.getSwappedPredicate();
  }

  if (!match(NegCandidate, m_FNeg(m_Specific(X))))
    return nullptr;

  Constant *PosZero = ConstantFP::getNullValue(X->getType());
  return new FCmpInst(Pred, X, PosZero, "", &I);
}
/// Visit a floating-point comparison and try to simplify or canonicalize it.
///
/// Operands are ordered by complexity, simplifyFCmpInst is consulted, and then
/// a fixed sequence of folds is tried. The first fold that applies wins.
/// Returns the replacement instruction, &I when I was modified in place (or
/// only its operands were swapped), or null when nothing changed.
Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
bool Changed = false;
// Put the more complex operand first.
if (getComplexity(I.getOperand(0)) < getComplexity(I.getOperand(1))) {
I.swapOperands();
Changed = true;
}
const CmpInst::Predicate Pred = I.getPredicate();
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = simplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
Type *OpType = Op0->getType();
assert(OpType == Op1->getType() && "fcmp with different-typed operands?");
// 'fcmp pred X, X' depends only on whether X is NaN. Predicates that are
// true exactly for NaN become 'uno X, 0.0'; those true exactly for non-NaN
// become 'ord X, 0.0'.
if (Op0 == Op1) {
switch (Pred) {
default: break;
case FCmpInst::FCMP_UNO: case FCmpInst::FCMP_ULT: case FCmpInst::FCMP_UGT: case FCmpInst::FCMP_UNE: I.setPredicate(FCmpInst::FCMP_UNO);
I.setOperand(1, Constant::getNullValue(OpType));
return &I;
case FCmpInst::FCMP_ORD: case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_OLE: I.setPredicate(FCmpInst::FCMP_ORD);
I.setOperand(1, Constant::getNullValue(OpType));
return &I;
}
}
// ord/uno only test for NaN, so an operand known never to be NaN can be
// replaced by +0.0.
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI))
return replaceOperand(I, 0, ConstantFP::getNullValue(OpType));
if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI))
return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
}
// fcmp pred (fneg X), (fneg Y) --> fcmp swapped(pred) X, Y
Value *X, *Y;
if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
if (Instruction *R = foldFCmpFNegCommonOp(I))
return R;
// If the only user is a select recognized by matchSelectPattern (such as a
// min/max idiom), stop folding and keep the compare as is.
if (I.hasOneUse())
if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
Value *A, *B;
SelectPatternResult SPR = matchSelectPattern(SI, A, B);
if (SPR.Flavor != SPF_UNKNOWN)
return nullptr;
}
// fcmp ignores the sign of zero: canonicalize a -0.0 operand to +0.0.
if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
// Folds for an instruction compared against a constant.
Instruction *LHSI;
Constant *RHSC;
if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) {
switch (LHSI->getOpcode()) {
case Instruction::PHI:
// Only fold into a PHI that lives in the same block as the compare.
if (LHSI->getParent() == I.getParent())
if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
return NV;
break;
case Instruction::SIToFP:
case Instruction::UIToFP:
if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC))
return NV;
break;
case Instruction::FDiv:
if (Instruction *NV = foldFCmpReciprocalAndZero(I, LHSI, RHSC))
return NV;
break;
case Instruction::Load:
// A load through a GEP whose base operand is a global variable.
if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
cast<LoadInst>(LHSI), GEP, GV, I))
return Res;
break;
}
}
if (Instruction *R = foldFabsWithFcmpZero(I, *this))
return R;
// fcmp pred (fneg X), C --> fcmp swapped(pred) X, -C
if (match(Op0, m_FNeg(m_Value(X)))) {
Constant *C;
if (match(Op1, m_Constant(C))) {
Constant *NegC = ConstantExpr::getFNeg(C);
return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
}
}
if (match(Op0, m_FPExt(m_Value(X)))) {
// fcmp (fpext X), (fpext Y) --> fcmp X, Y when X and Y share a type.
if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType())
return new FCmpInst(Pred, X, Y, "", &I);
const APFloat *C;
if (match(Op1, m_APFloat(C))) {
const fltSemantics &FPSem =
X->getType()->getScalarType()->getFltSemantics();
bool Lossy;
APFloat TruncC = *C;
TruncC.convert(FPSem, APFloat::rmNearestTiesToEven, &Lossy);
// C is not exactly representable in X's type, so X can never equal it.
// Equality-style predicates reduce to a constant or a NaN test on X.
if (Lossy) {
switch (Pred) {
case FCmpInst::FCMP_OEQ:
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
case FCmpInst::FCMP_ONE:
return new FCmpInst(FCmpInst::FCMP_ORD, X,
ConstantFP::getNullValue(X->getType()));
case FCmpInst::FCMP_UEQ:
return new FCmpInst(FCmpInst::FCMP_UNO, X,
ConstantFP::getNullValue(X->getType()));
case FCmpInst::FCMP_UNE:
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
default:
break;
}
}
// fcmp (fpext X), C --> fcmp X, (fptrunc C), provided the truncation is
// exact and the truncated value is not a denormal (zero is allowed).
APFloat Fabs = TruncC;
Fabs.clearSign();
if (!Lossy &&
(!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) {
Constant *NewC = ConstantFP::get(X->getType(), TruncC);
return new FCmpInst(Pred, X, NewC, "", &I);
}
}
}
// copysign(C, X) olt 0.0 --> (bitcast X to int) s< 0, where C is a one-use
// constant that is neither zero nor NaN. Only the OLT predicate is handled.
const APFloat *C;
if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C),
m_Value(X)))) &&
match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) {
Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits());
if (auto *VecTy = dyn_cast<VectorType>(OpType))
IntType = VectorType::get(IntType, VecTy->getElementCount());
if (Pred == FCmpInst::FCMP_OLT) {
Value *IntX = Builder.CreateBitCast(X, IntType);
return new ICmpInst(ICmpInst::ICMP_SLT, IntX,
ConstantInt::getNullValue(IntType));
}
}
if (I.getType()->isVectorTy())
if (Instruction *Res = foldVectorCmp(I, Builder))
return Res;
return Changed ? &I : nullptr;
}