#ifndef LLVM_TRANSFORMS_INSTCOMBINE_INSTCOMBINER_H
#define LLVM_TRANSFORMS_INSTCOMBINE_INSTCOMBINER_H
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include <cassert>
#define DEBUG_TYPE "instcombine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
namespace llvm {
class AAResults;
class AssumptionCache;
class ProfileSummaryInfo;
class TargetLibraryInfo;
class TargetTransformInfo;
/// Shared base of the instruction combiner.
///
/// Holds references to the analyses handed to the constructor, the IR builder
/// and the instruction worklist. It also provides static pattern helpers and
/// thin wrappers over ValueTracking queries. Erasing instructions and the
/// demanded-bits / demanded-elements simplifications are pure virtual, so a
/// derived class must implement them.
class LLVM_LIBRARY_VISIBILITY InstCombiner {
  /// Private, because it is declared before the first access specifier;
  /// unlike the other analyses it has no protected access.
  /// NOTE(review): presumably used by the out-of-line target* hooks declared
  /// below; confirm against their definitions.
  TargetTransformInfo &TTI;

public:
  /// Defaults to 0 and is not read or written anywhere in this header.
  /// NOTE(review): presumably a size limit for array-related combines;
  /// confirm where it is assigned.
  uint64_t MaxArraySizeForCombine = 0;

  /// IRBuilder that folds through TargetFolder and inserts via an
  /// IRBuilderCallbackInserter.
  using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
  BuilderTy &Builder;

protected:
  /// Worklist that the insertion/replacement helpers below push onto.
  InstructionWorklist &Worklist;
  /// NOTE(review): presumably "optimize for size"; confirm at the caller.
  const bool MinimizeSize;
  AAResults *AA;
  AssumptionCache &AC;
  TargetLibraryInfo &TLI;
  DominatorTree &DT;
  const DataLayout &DL;
  /// Built in the constructor from the DL, TLI, DT and AC arguments.
  const SimplifyQuery SQ;
  OptimizationRemarkEmitter &ORE;
  BlockFrequencyInfo *BFI;
  ProfileSummaryInfo *PSI;
  LoopInfo *LI;
  /// Starts false and is not written anywhere in this header.
  bool MadeIRChange = false;

public:
  /// Binds each reference/pointer member to the matching argument and builds
  /// SQ from DL, TLI, DT and AC. The initializer list follows member
  /// declaration order.
  InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder,
               bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
               TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
               DominatorTree &DT, OptimizationRemarkEmitter &ORE,
               BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI,
               const DataLayout &DL, LoopInfo *LI)
      : TTI(TTI), Builder(Builder), Worklist(Worklist),
        MinimizeSize(MinimizeSize), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL),
        SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {}

  virtual ~InstCombiner() = default;

  /// If \p V is a BitCastInst, returns its source operand; otherwise returns
  /// \p V unchanged. When \p OneUseOnly is set, the bitcast is only looked
  /// through if it has exactly one use.
  static Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) {
    if (auto *BitCast = dyn_cast<BitCastInst>(V))
      if (!OneUseOnly || BitCast->hasOneUse())
        return BitCast->getOperand(0);
    return V;
  }

  /// Coarse complexity rank of \p V (higher = more complex):
  ///   5 - any other instruction
  ///   4 - a cast, or an instruction matching neg / not / fneg
  ///   3 - a function argument
  ///   2 - any other non-constant value
  ///   1 - a constant that is not an UndefValue
  ///   0 - an UndefValue
  static unsigned getComplexity(Value *V) {
    if (isa<Instruction>(V)) {
      if (isa<CastInst>(V) || match(V, m_Neg(PatternMatch::m_Value())) ||
          match(V, m_Not(PatternMatch::m_Value())) ||
          match(V, m_FNeg(PatternMatch::m_Value())))
        return 4;
      return 5;
    }
    if (isa<Argument>(V))
      return 3;
    return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2;
  }

  /// Returns false for icmp ne/ule/sle/uge/sge and for fcmp one/ole/oge.
  /// Returns true for every other predicate.
  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
    switch (Pred) {
    case CmpInst::ICMP_NE:
    case CmpInst::ICMP_ULE:
    case CmpInst::ICMP_SLE:
    case CmpInst::ICMP_UGE:
    case CmpInst::ICMP_SGE:
    case CmpInst::FCMP_ONE:
    case CmpInst::FCMP_OLE:
    case CmpInst::FCMP_OGE:
      return false;
    default:
      return true;
    }
  }

  /// Returns true if `icmp Pred X, RHS` holds exactly when X's sign bit has a
  /// fixed value. On return, \p TrueIfSigned says whether the compare is true
  /// when the sign bit is set (true) or clear (false).
  ///
  /// \p TrueIfSigned is written for each of the eight predicates handled
  /// below, even when the function then returns false. It is left untouched
  /// for any other predicate.
  static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
                             bool &TrueIfSigned) {
    switch (Pred) {
    case ICmpInst::ICMP_SLT: // X s< 0
      TrueIfSigned = true;
      return RHS.isZero();
    case ICmpInst::ICMP_SLE: // X s<= -1
      TrueIfSigned = true;
      return RHS.isAllOnes();
    case ICmpInst::ICMP_SGT: // X s> -1
      TrueIfSigned = false;
      return RHS.isAllOnes();
    case ICmpInst::ICMP_SGE: // X s>= 0
      TrueIfSigned = false;
      return RHS.isZero();
    case ICmpInst::ICMP_UGT: // X u> SMAX  <=>  sign bit set
      TrueIfSigned = true;
      return RHS.isMaxSignedValue();
    case ICmpInst::ICMP_UGE: // X u>= SMIN  <=>  sign bit set
      TrueIfSigned = true;
      return RHS.isMinSignedValue();
    case ICmpInst::ICMP_ULT: // X u< SMIN  <=>  sign bit clear
      TrueIfSigned = false;
      return RHS.isMinSignedValue();
    case ICmpInst::ICMP_ULE: // X u<= SMAX  <=>  sign bit clear
      TrueIfSigned = false;
      return RHS.isMaxSignedValue();
    default:
      return false;
    }
  }

  /// Returns the constant expression `C + 1`, with 1 built in C's type.
  static Constant *AddOne(Constant *C) {
    return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
  }

  /// Returns the constant expression `C - 1`, with 1 built in C's type.
  static Constant *SubOne(Constant *C) {
    return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
  }

  /// Defined out of line; the definition is not visible in this header.
  /// NOTE(review): judging by the name, presumably swaps the strictness of
  /// \p Pred and adjusts \p C to compensate, returning None when that is not
  /// possible; confirm against the definition.
  static llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
  getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
                                           Constant *C);

  /// Returns true if \p SI matches PatternMatch's m_LogicalAnd or
  /// m_LogicalOr. canFreelyInvertAllUsersOf uses this to reject select users.
  static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
    return match(&SI, PatternMatch::m_LogicalAnd(PatternMatch::m_Value(),
                                                 PatternMatch::m_Value())) ||
           match(&SI, PatternMatch::m_LogicalOr(PatternMatch::m_Value(),
                                                PatternMatch::m_Value()));
  }

  /// Returns whether \p V counts as free to bitwise-invert.
  ///
  /// A `not` and any integral constant always qualify. The following qualify
  /// only when \p WillInvertAllUses is true:
  ///   - compares
  ///   - `A + ImmC`
  ///   - `ImmC - A`
  ///   - selects whose two value operands are both `not`
  ///   - min/max of two `not`s
  static bool isFreeToInvert(Value *V, bool WillInvertAllUses) {
    // ~(~X) == X.
    if (match(V, m_Not(PatternMatch::m_Value())))
      return true;

    if (match(V, PatternMatch::m_AnyIntegralConstant()))
      return true;

    if (isa<CmpInst>(V))
      return WillInvertAllUses;

    // ~(A + C) == (-1 - C) - A, because ~X == -1 - X.
    if (match(V, m_Add(PatternMatch::m_Value(), PatternMatch::m_ImmConstant())))
      return WillInvertAllUses;

    // ~(C - A) == A + (-1 - C).
    if (match(V, m_Sub(PatternMatch::m_ImmConstant(), PatternMatch::m_Value())))
      return WillInvertAllUses;

    if (match(V,
              m_Select(PatternMatch::m_Value(), m_Not(PatternMatch::m_Value()),
                       m_Not(PatternMatch::m_Value()))))
      return WillInvertAllUses;

    if (match(V, m_MaxOrMin(m_Not(PatternMatch::m_Value()),
                            m_Not(PatternMatch::m_Value()))))
      return WillInvertAllUses;

    return false;
  }

  /// Returns true if every use of \p V, except uses by \p IgnoredUser, is one
  /// of the following:
  ///   - the condition operand (operand 0) of a select that is not a logical
  ///     and/or
  ///   - a branch; it is asserted to use V as operand 0
  ///   - an xor that matches `not`
  /// Every user is cast<> to Instruction, so non-instruction users are not
  /// expected here.
  static bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) {
    for (Use &U : V->uses()) {
      if (U.getUser() == IgnoredUser)
        continue;

      auto *I = cast<Instruction>(U.getUser());
      switch (I->getOpcode()) {
      case Instruction::Select:
        // Only the condition operand qualifies.
        if (U.getOperandNo() != 0)
          return false;
        if (shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(I)))
          return false;
        break;
      case Instruction::Br:
        assert(U.getOperandNo() == 0 && "Must be branching on that value.");
        break;
      case Instruction::Xor:
        if (!match(I, m_Not(PatternMatch::m_Value())))
          return false;
        break;
      default:
        return false;
      }
    }
    return true;
  }

  /// Returns a copy of the fixed-width vector constant \p In in which every
  /// UndefValue element is replaced by a per-opcode substitute. \p In must be
  /// a FixedVectorType constant (cast<> requires it).
  ///
  /// The substitute is the binop identity for \p Opcode on the side given by
  /// \p IsRHSConstant, when ConstantExpr::getBinOpIdentity provides one.
  /// Otherwise:
  ///   - RHS srem/urem use 1; RHS frem uses 1.0.
  ///   - For the listed LHS opcodes, the element type's null value is used.
  /// Any other opcode without an identity is unreachable.
  static Constant *
  getSafeVectorConstantForBinop(BinaryOperator::BinaryOps Opcode, Constant *In,
                                bool IsRHSConstant) {
    auto *InVTy = cast<FixedVectorType>(In->getType());

    Type *EltTy = InVTy->getElementType();
    auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant);
    if (!SafeC) {
      if (IsRHSConstant) {
        switch (Opcode) {
        case Instruction::SRem: // X % 1 == 0
        case Instruction::URem: // X %u 1 == 0
          SafeC = ConstantInt::get(EltTy, 1);
          break;
        case Instruction::FRem: // X % 1.0
          SafeC = ConstantFP::get(EltTy, 1.0);
          break;
        default:
          llvm_unreachable(
              "Only rem opcodes have no identity constant for RHS");
        }
      } else {
        switch (Opcode) {
        case Instruction::Shl:  // 0 << X == 0
        case Instruction::LShr: // 0 >>u X == 0
        case Instruction::AShr: // 0 >>s X == 0
        case Instruction::SDiv: // 0 / X == 0
        case Instruction::UDiv: // 0 /u X == 0
        case Instruction::SRem: // 0 % X == 0
        case Instruction::URem: // 0 %u X == 0
        case Instruction::Sub:  // 0 - X
        case Instruction::FSub: // 0.0 - X
        case Instruction::FDiv: // 0.0 / X
        case Instruction::FRem: // 0.0 % X
          SafeC = Constant::getNullValue(EltTy);
          break;
        default:
          llvm_unreachable("Expected to find identity constant for opcode");
        }
      }
    }
    assert(SafeC && "Must have safe constant for binop");
    unsigned NumElts = InVTy->getNumElements();
    SmallVector<Constant *, 16> Out(NumElts);
    for (unsigned i = 0; i != NumElts; ++i) {
      Constant *C = In->getAggregateElement(i);
      Out[i] = isa<UndefValue>(C) ? SafeC : C;
    }
    return ConstantVector::get(Out);
  }

  /// Pushes \p I onto the worklist.
  void addToWorklist(Instruction *I) { Worklist.push(I); }

  // Accessors for the analyses stored by the constructor.
  AssumptionCache &getAssumptionCache() const { return AC; }
  TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; }
  DominatorTree &getDominatorTree() const { return DT; }
  const DataLayout &getDataLayout() const { return DL; }
  const SimplifyQuery &getSimplifyQuery() const { return SQ; }
  OptimizationRemarkEmitter &getOptimizationRemarkEmitter() const {
    return ORE;
  }
  BlockFrequencyInfo *getBlockFrequencyInfo() const { return BFI; }
  ProfileSummaryInfo *getProfileSummaryInfo() const { return PSI; }
  LoopInfo *getLoopInfo() const { return LI; }

  // Target-specific intrinsic hooks. They are defined out of line (not
  // visible here).
  // NOTE(review): presumably they forward to TTI; confirm against the
  // definitions.
  Optional<Instruction *> targetInstCombineIntrinsic(IntrinsicInst &II);
  Optional<Value *>
  targetSimplifyDemandedUseBitsIntrinsic(IntrinsicInst &II, APInt DemandedMask,
                                         KnownBits &Known,
                                         bool &KnownBitsComputed);
  Optional<Value *> targetSimplifyDemandedVectorEltsIntrinsic(
      IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts,
      APInt &UndefElts2, APInt &UndefElts3,
      std::function<void(Instruction *, unsigned, APInt, APInt &)>
          SimplifyAndSetOp);

  /// Inserts \p New, which must not yet have a parent block, into \p Old's
  /// block immediately before \p Old. Then pushes \p New onto the worklist
  /// and returns it.
  Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) {
    assert(New && !New->getParent() &&
           "New instruction already inserted into a basic block!");
    BasicBlock *BB = Old.getParent();
    BB->getInstList().insert(Old.getIterator(), New);
    Worklist.push(New);
    return New;
  }

  /// Same as InsertNewInstBefore, but first copies \p Old's debug location
  /// onto \p New.
  Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) {
    New->setDebugLoc(Old.getDebugLoc());
    return InsertNewInstBefore(New, Old);
  }

  /// Replaces all uses of \p I with \p V and returns \p I.
  ///
  /// Returns nullptr and changes nothing if \p I has no uses. Otherwise the
  /// users of \p I are pushed onto the worklist before the replacement. A
  /// request to replace \p I with itself uses poison of \p I's type instead.
  Instruction *replaceInstUsesWith(Instruction &I, Value *V) {
    if (I.use_empty())
      return nullptr;

    Worklist.pushUsersToWorkList(I);

    if (&I == V)
      V = PoisonValue::get(I.getType());

    LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n"
                      << " with " << *V << '\n');

    I.replaceAllUsesWith(V);
    return &I;
  }

  /// Sets operand \p OpNum of \p I to \p V and returns \p I. The previous
  /// operand value is handed to Worklist.addValue first.
  Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) {
    Worklist.addValue(I.getOperand(OpNum));
    I.setOperand(OpNum, V);
    return &I;
  }

  /// Points \p U at \p NewValue. The value \p U held before is handed to
  /// Worklist.addValue first.
  void replaceUse(Use &U, Value *NewValue) {
    Worklist.addValue(U);
    U = NewValue;
  }

  /// Implemented by the derived class.
  virtual Instruction *eraseInstFromFunction(Instruction &I) = 0;

  // ValueTracking wrappers. Each forwards to the llvm:: function of the same
  // name, supplying this combiner's DL, AC and DT.

  void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
                        const Instruction *CxtI) const {
    llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT);
  }

  KnownBits computeKnownBits(const Value *V, unsigned Depth,
                             const Instruction *CxtI) const {
    return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT);
  }

  /// Marked const for consistency with the other ValueTracking wrappers. It
  /// only reads members, so it can be called through a const InstCombiner.
  bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
                              unsigned Depth = 0,
                              const Instruction *CxtI = nullptr) const {
    return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT);
  }

  bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0,
                         const Instruction *CxtI = nullptr) const {
    return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT);
  }

  unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0,
                              const Instruction *CxtI = nullptr) const {
    return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT);
  }

  unsigned ComputeMaxSignificantBits(const Value *Op, unsigned Depth = 0,
                                     const Instruction *CxtI = nullptr) const {
    return llvm::ComputeMaxSignificantBits(Op, DL, Depth, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
                                               const Value *RHS,
                                               const Instruction *CxtI) const {
    return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
                                             const Instruction *CxtI) const {
    return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
  }

  /// Implemented by the derived class.
  virtual bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                    const APInt &DemandedMask, KnownBits &Known,
                                    unsigned Depth = 0) = 0;

  /// Implemented by the derived class.
  virtual Value *
  SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                             unsigned Depth = 0,
                             bool AllowMultipleUsers = false) = 0;
};
}
#undef DEBUG_TYPE
#endif