#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "scalar-evolution"
STATISTIC(NumTripCountsComputed,
"Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
"Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
"Number of loops with trip counts computed by force");
#ifdef EXPENSIVE_CHECKS
bool llvm::VerifySCEV = true;
#else
bool llvm::VerifySCEV = false;
#endif
static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
cl::desc("Maximum number of iterations SCEV will "
"symbolically execute a constant "
"derived loop"),
cl::init(100));
static cl::opt<bool, true> VerifySCEVOpt(
"verify-scev", cl::Hidden, cl::location(VerifySCEV),
cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
"verify-scev-strict", cl::Hidden,
cl::desc("Enable stricter verification with -verify-scev is passed"));
static cl::opt<bool>
VerifySCEVMap("verify-scev-maps", cl::Hidden,
cl::desc("Verify no dangling value in ScalarEvolution's "
"ExprValueMap (slow)"));
static cl::opt<bool> VerifyIR(
"scev-verify-ir", cl::Hidden,
cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
cl::init(false));
static cl::opt<unsigned> MulOpsInlineThreshold(
"scev-mulops-inline-threshold", cl::Hidden,
cl::desc("Threshold for inlining multiplication operands into a SCEV"),
cl::init(32));
static cl::opt<unsigned> AddOpsInlineThreshold(
"scev-addops-inline-threshold", cl::Hidden,
cl::desc("Threshold for inlining addition operands into a SCEV"),
cl::init(500));
static cl::opt<unsigned> MaxSCEVCompareDepth(
"scalar-evolution-max-scev-compare-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
cl::init(32));
static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
"scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
cl::init(2));
static cl::opt<unsigned> MaxValueCompareDepth(
"scalar-evolution-max-value-compare-depth", cl::Hidden,
cl::desc("Maximum depth of recursive value complexity comparisons"),
cl::init(2));
static cl::opt<unsigned>
MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
cl::desc("Maximum depth of recursive arithmetics"),
cl::init(32));
static cl::opt<unsigned> MaxConstantEvolvingDepth(
"scalar-evolution-max-constant-evolving-depth", cl::Hidden,
cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
static cl::opt<unsigned>
MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
cl::init(8));
static cl::opt<unsigned>
MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
cl::desc("Max coefficients in AddRec during evolving"),
cl::init(8));
static cl::opt<unsigned>
HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
cl::desc("Size of the expression which is considered huge"),
cl::init(4096));
static cl::opt<bool>
ClassifyExpressions("scalar-evolution-classify-expressions",
cl::Hidden, cl::init(true),
cl::desc("When printing analysis, include information on every instruction"));
static cl::opt<bool> UseExpensiveRangeSharpening(
"scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
cl::init(false),
cl::desc("Use more powerful methods of sharpening expression ranges. May "
"be costly in terms of compile time"));
static cl::opt<unsigned> MaxPhiSCCAnalysisSize(
"scalar-evolution-max-scc-analysis-depth", cl::Hidden,
cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
"Phi strongly connected components"),
cl::init(8));
static cl::opt<bool>
EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
cl::desc("Handle <= and >= in finite loops"),
cl::init(true));
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
print(dbgs());
dbgs() << '\n';
}
#endif
void SCEV::print(raw_ostream &OS) const {
switch (getSCEVType()) {
case scConstant:
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
return;
case scPtrToInt: {
const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
const SCEV *Op = PtrToInt->getOperand();
OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
<< *PtrToInt->getType() << ")";
return;
}
case scTruncate: {
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
const SCEV *Op = Trunc->getOperand();
OS << "(trunc " << *Op->getType() << " " << *Op << " to "
<< *Trunc->getType() << ")";
return;
}
case scZeroExtend: {
const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
const SCEV *Op = ZExt->getOperand();
OS << "(zext " << *Op->getType() << " " << *Op << " to "
<< *ZExt->getType() << ")";
return;
}
case scSignExtend: {
const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
const SCEV *Op = SExt->getOperand();
OS << "(sext " << *Op->getType() << " " << *Op << " to "
<< *SExt->getType() << ")";
return;
}
case scAddRecExpr: {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
OS << "{" << *AR->getOperand(0);
for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
OS << ",+," << *AR->getOperand(i);
OS << "}<";
if (AR->hasNoUnsignedWrap())
OS << "nuw><";
if (AR->hasNoSignedWrap())
OS << "nsw><";
if (AR->hasNoSelfWrap() &&
!AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
OS << "nw><";
AR->getLoop()->getHeader()->printAsOperand(OS, false);
OS << ">";
return;
}
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
case scSequentialUMinExpr: {
const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
const char *OpStr = nullptr;
switch (NAry->getSCEVType()) {
case scAddExpr: OpStr = " + "; break;
case scMulExpr: OpStr = " * "; break;
case scUMaxExpr: OpStr = " umax "; break;
case scSMaxExpr: OpStr = " smax "; break;
case scUMinExpr:
OpStr = " umin ";
break;
case scSMinExpr:
OpStr = " smin ";
break;
case scSequentialUMinExpr:
OpStr = " umin_seq ";
break;
default:
llvm_unreachable("There are no other nary expression types.");
}
OS << "(";
ListSeparator LS(OpStr);
for (const SCEV *Op : NAry->operands())
OS << LS << *Op;
OS << ")";
switch (NAry->getSCEVType()) {
case scAddExpr:
case scMulExpr:
if (NAry->hasNoUnsignedWrap())
OS << "<nuw>";
if (NAry->hasNoSignedWrap())
OS << "<nsw>";
break;
default:
break;
}
return;
}
case scUDivExpr: {
const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
return;
}
case scUnknown: {
const SCEVUnknown *U = cast<SCEVUnknown>(this);
Type *AllocTy;
if (U->isSizeOf(AllocTy)) {
OS << "sizeof(" << *AllocTy << ")";
return;
}
if (U->isAlignOf(AllocTy)) {
OS << "alignof(" << *AllocTy << ")";
return;
}
Type *CTy;
Constant *FieldNo;
if (U->isOffsetOf(CTy, FieldNo)) {
OS << "offsetof(" << *CTy << ", ";
FieldNo->printAsOperand(OS, false);
OS << ")";
return;
}
U->getValue()->printAsOperand(OS, false);
return;
}
case scCouldNotCompute:
OS << "***COULDNOTCOMPUTE***";
return;
}
llvm_unreachable("Unknown SCEV kind!");
}
Type *SCEV::getType() const {
switch (getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(this)->getType();
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
return cast<SCEVCastExpr>(this)->getType();
case scAddRecExpr:
return cast<SCEVAddRecExpr>(this)->getType();
case scMulExpr:
return cast<SCEVMulExpr>(this)->getType();
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
return cast<SCEVMinMaxExpr>(this)->getType();
case scSequentialUMinExpr:
return cast<SCEVSequentialMinMaxExpr>(this)->getType();
case scAddExpr:
return cast<SCEVAddExpr>(this)->getType();
case scUDivExpr:
return cast<SCEVUDivExpr>(this)->getType();
case scUnknown:
return cast<SCEVUnknown>(this)->getType();
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}
bool SCEV::isZero() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isZero();
return false;
}
bool SCEV::isOne() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isOne();
return false;
}
bool SCEV::isAllOnesValue() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isMinusOne();
return false;
}
bool SCEV::isNonConstantNegative() const {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
if (!Mul) return false;
const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
if (!SC) return false;
return SC->getAPInt().isNegative();
}
SCEVCouldNotCompute::SCEVCouldNotCompute() :
SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
bool SCEVCouldNotCompute::classof(const SCEV *S) {
return S->getSCEVType() == scCouldNotCompute;
}
const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
FoldingSetNodeID ID;
ID.AddInteger(scConstant);
ID.AddPointer(V);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
UniqueSCEVs.InsertNode(S, IP);
return S;
}
const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
return getConstant(ConstantInt::get(getContext(), Val));
}
const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
return getConstant(ConstantInt::get(ITy, V, isSigned));
}
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
const SCEV *op, Type *ty)
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
Operands[0] = op;
}
SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
Type *ITy)
: SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
"Must be a non-bit-width-changing pointer-to-integer cast!");
}
SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
SCEVTypes SCEVTy, const SCEV *op,
Type *ty)
: SCEVCastExpr(ID, SCEVTy, op, ty) {}
SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
Type *ty)
: SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate non-integer value!");
}
SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot zero extend non-integer value!");
}
SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot sign extend non-integer value!");
}
void SCEVUnknown::deleted() {
SE->forgetMemoizedResults(this);
SE->UniqueSCEVs.RemoveNode(this);
setValPtr(nullptr);
}
void SCEVUnknown::allUsesReplacedWith(Value *New) {
SE->forgetMemoizedResults(this);
SE->UniqueSCEVs.RemoveNode(this);
setValPtr(New);
}
bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
if (VCE->getOpcode() == Instruction::PtrToInt)
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
if (CE->getOpcode() == Instruction::GetElementPtr &&
CE->getOperand(0)->isNullValue() &&
CE->getNumOperands() == 2)
if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
if (CI->isOne()) {
AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
return true;
}
return false;
}
bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
if (VCE->getOpcode() == Instruction::PtrToInt)
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
if (CE->getOpcode() == Instruction::GetElementPtr &&
CE->getOperand(0)->isNullValue()) {
Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
if (StructType *STy = dyn_cast<StructType>(Ty))
if (!STy->isPacked() &&
CE->getNumOperands() == 3 &&
CE->getOperand(1)->isNullValue()) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
if (CI->isOne() &&
STy->getNumElements() == 2 &&
STy->getElementType(0)->isIntegerTy(1)) {
AllocTy = STy->getElementType(1);
return true;
}
}
}
return false;
}
bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
if (VCE->getOpcode() == Instruction::PtrToInt)
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
if (CE->getOpcode() == Instruction::GetElementPtr &&
CE->getNumOperands() == 3 &&
CE->getOperand(0)->isNullValue() &&
CE->getOperand(1)->isNullValue()) {
Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
if (Ty->isStructTy() || Ty->isArrayTy()) {
CTy = Ty;
FieldNo = CE->getOperand(2);
return true;
}
}
return false;
}
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, Value *LV, Value *RV,
unsigned Depth) {
if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
return 0;
bool LIsPointer = LV->getType()->isPointerTy(),
RIsPointer = RV->getType()->isPointerTy();
if (LIsPointer != RIsPointer)
return (int)LIsPointer - (int)RIsPointer;
unsigned LID = LV->getValueID(), RID = RV->getValueID();
if (LID != RID)
return (int)LID - (int)RID;
if (const auto *LA = dyn_cast<Argument>(LV)) {
const auto *RA = cast<Argument>(RV);
unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
return (int)LArgNo - (int)RArgNo;
}
if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
const auto *RGV = cast<GlobalValue>(RV);
const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
auto LT = GV->getLinkage();
return !(GlobalValue::isPrivateLinkage(LT) ||
GlobalValue::isInternalLinkage(LT));
};
if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
return LGV->getName().compare(RGV->getName());
}
if (const auto *LInst = dyn_cast<Instruction>(LV)) {
const auto *RInst = cast<Instruction>(RV);
const BasicBlock *LParent = LInst->getParent(),
*RParent = RInst->getParent();
if (LParent != RParent) {
unsigned LDepth = LI->getLoopDepth(LParent),
RDepth = LI->getLoopDepth(RParent);
if (LDepth != RDepth)
return (int)LDepth - (int)RDepth;
}
unsigned LNumOps = LInst->getNumOperands(),
RNumOps = RInst->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
for (unsigned Idx : seq(0u, LNumOps)) {
int Result =
CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
RInst->getOperand(Idx), Depth + 1);
if (Result != 0)
return Result;
}
}
EqCacheValue.unionSets(LV, RV);
return 0;
}
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, const SCEV *LHS,
const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
if (LHS == RHS)
return 0;
SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
if (LType != RType)
return (int)LType - (int)RType;
if (EqCacheSCEV.isEquivalent(LHS, RHS))
return 0;
if (Depth > MaxSCEVCompareDepth)
return None;
switch (LType) {
case scUnknown: {
const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
RU->getValue(), Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scConstant: {
const SCEVConstant *LC = cast<SCEVConstant>(LHS);
const SCEVConstant *RC = cast<SCEVConstant>(RHS);
const APInt &LA = LC->getAPInt();
const APInt &RA = RC->getAPInt();
unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
if (LBitWidth != RBitWidth)
return (int)LBitWidth - (int)RBitWidth;
return LA.ult(RA) ? -1 : 1;
}
case scAddRecExpr: {
const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
if (LLoop != RLoop) {
const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
assert(LHead != RHead && "Two loops share the same header?");
if (DT.dominates(LHead, RHead))
return 1;
else
assert(DT.dominates(RHead, LHead) &&
"No dominance between recurrences used by one SCEV?");
return -1;
}
unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
for (unsigned i = 0; i != LNumOps; ++i) {
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LA->getOperand(i), RA->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
case scAddExpr:
case scMulExpr:
case scSMaxExpr:
case scUMaxExpr:
case scSMinExpr:
case scUMinExpr:
case scSequentialUMinExpr: {
const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
for (unsigned i = 0; i != LNumOps; ++i) {
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(i), RC->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
case scUDivExpr: {
const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
RC->getLHS(), DT, Depth + 1);
if (X != 0)
return X;
X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
RC->getRHS(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend: {
const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
auto X =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
RC->getOperand(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
LoopInfo *LI, DominatorTree &DT) {
if (Ops.size() < 2) return;
EquivalenceClasses<const SCEV *> EqCacheSCEV;
EquivalenceClasses<const Value *> EqCacheValue;
auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
auto Complexity =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
return Complexity && *Complexity < 0;
};
if (Ops.size() == 2) {
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
if (IsLessComplex(RHS, LHS))
std::swap(LHS, RHS);
return;
}
llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
return IsLessComplex(LHS, RHS);
});
for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
const SCEV *S = Ops[i];
unsigned Complexity = S->getSCEVType();
for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
if (Ops[j] == S) { std::swap(Ops[i+1], Ops[j]);
++i; if (i == e-2) return; }
}
}
}
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
return any_of(Ops, [](const SCEV *S) {
return S->getExpressionSize() >= HugeExprThreshold;
});
}
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
ScalarEvolution &SE,
Type *ResultTy) {
if (K == 1)
return SE.getTruncateOrZeroExtend(It, ResultTy);
if (K > 1000)
return SE.getCouldNotCompute();
unsigned W = SE.getTypeSizeInBits(ResultTy);
APInt OddFactorial(W, 1);
unsigned T = 1;
for (unsigned i = 3; i <= K; ++i) {
APInt Mult(W, i);
unsigned TwoFactors = Mult.countTrailingZeros();
T += TwoFactors;
Mult.lshrInPlace(TwoFactors);
OddFactorial *= Mult;
}
unsigned CalculationBits = W + T;
APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
APInt Mod = APInt::getSignedMinValue(W+1);
APInt MultiplyFactor = OddFactorial.zext(W+1);
MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
MultiplyFactor = MultiplyFactor.trunc(W);
IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
CalculationBits);
const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
for (unsigned i = 1; i != K; ++i) {
const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
Dividend = SE.getMulExpr(Dividend,
SE.getTruncateOrZeroExtend(S, CalculationTy));
}
const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
return SE.getMulExpr(SE.getConstant(MultiplyFactor),
SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
ScalarEvolution &SE) const {
return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
}
const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
const SCEV *It, ScalarEvolution &SE) {
assert(Operands.size() > 0);
const SCEV *Result = Operands[0];
for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
if (isa<SCEVCouldNotCompute>(Coeff))
return Coeff;
Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
}
return Result;
}
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
unsigned Depth) {
assert(Depth <= 1 &&
"getLosslessPtrToIntExpr() should self-recurse at most once.");
if (!Op->getType()->isPointerTy())
return Op;
FoldingSetNodeID ID;
ID.AddInteger(scPtrToInt);
ID.AddPointer(Op);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
if (getDataLayout().isNonIntegralPointerType(Op->getType()))
return getCouldNotCompute();
Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
getDataLayout().getTypeSizeInBits(IntPtrTy))
return getCouldNotCompute();
if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
if (isa<ConstantPointerNull>(U->getValue()))
return getZero(IntPtrTy);
SCEV *S = new (SCEVAllocator)
SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
"non-SCEVUnknown's.");
class SCEVPtrToIntSinkingRewriter
: public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
public:
SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
SCEVPtrToIntSinkingRewriter Rewriter(SE);
return Rewriter.visit(Scev);
}
const SCEV *visit(const SCEV *S) {
Type *STy = S->getType();
if (!STy->isPointerTy())
return S;
return Base::visit(S);
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
bool Changed = false;
for (const auto *Op : Expr->operands()) {
Operands.push_back(visit(Op));
Changed |= Op != Operands.back();
}
return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
}
const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
bool Changed = false;
for (const auto *Op : Expr->operands()) {
Operands.push_back(visit(Op));
Changed |= Op != Operands.back();
}
return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
assert(Expr->getType()->isPointerTy() &&
"Should only reach pointer-typed SCEVUnknown's.");
return SE.getLosslessPtrToIntExpr(Expr, 1);
}
};
const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
assert(IntOp->getType()->isIntegerTy() &&
"We must have succeeded in sinking the cast, "
"and ending up with an integer-typed expression!");
return IntOp;
}
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
assert(Ty->isIntegerTy() && "Target type must be an integer type!");
const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
if (isa<SCEVCouldNotCompute>(IntOp))
return IntOp;
return getTruncateOrZeroExtend(IntOp, Ty);
}
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
unsigned Depth) {
assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
"This is not a truncating conversion!");
assert(isSCEVable(Ty) &&
"This is not a conversion to a SCEVable type!");
assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
Ty = getEffectiveSCEVType(Ty);
FoldingSetNodeID ID;
ID.AddInteger(scTruncate);
ID.AddPointer(Op);
ID.AddPointer(Ty);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
return getConstant(
cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
if (Depth > MaxCastDepth) {
SCEV *S =
new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
auto *CommOp = cast<SCEVCommutativeExpr>(Op);
SmallVector<const SCEV *, 4> Operands;
unsigned numTruncs = 0;
for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
++i) {
const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
isa<SCEVTruncateExpr>(S))
numTruncs++;
Operands.push_back(S);
}
if (numTruncs < 2) {
if (isa<SCEVAddExpr>(Op))
return getAddExpr(Operands);
else if (isa<SCEVMulExpr>(Op))
return getMulExpr(Operands);
else
llvm_unreachable("Unexpected SCEV type for Op.");
}
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
SmallVector<const SCEV *, 4> Operands;
for (const SCEV *Op : AddRec->operands())
Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
}
uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
if (MinTrailingZeros >= getTypeSizeInBits(Ty))
return getZero(Ty);
SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
if (SE->isKnownPositive(Step)) {
*Pred = ICmpInst::ICMP_SLT;
return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
SE->getSignedRangeMax(Step));
}
if (SE->isKnownNegative(Step)) {
*Pred = ICmpInst::ICMP_SGT;
return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
SE->getSignedRangeMin(Step));
}
return nullptr;
}
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
*Pred = ICmpInst::ICMP_ULT;
return SE->getConstant(APInt::getMinValue(BitWidth) -
SE->getUnsignedRangeMax(Step));
}
namespace {
struct ExtendOpTraitsBase {
typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
unsigned);
};
template <typename ExtendOp> struct ExtendOpTraits {
};
template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
static const GetExtendExprTy GetExtendExpr;
static const SCEV *getOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
return getSignedOverflowLimitForStep(Step, Pred, SE);
}
};
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
static const GetExtendExprTy GetExtendExpr;
static const SCEV *getOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
return getUnsignedOverflowLimitForStep(Step, Pred, SE);
}
};
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
}
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
ScalarEvolution *SE, unsigned Depth) {
auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
const Loop *L = AR->getLoop();
const SCEV *Start = AR->getStart();
const SCEV *Step = AR->getStepRecurrence(*SE);
const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
if (!SA)
return nullptr;
SmallVector<const SCEV *, 4> DiffOps;
for (const SCEV *Op : SA->operands())
if (Op != Step)
DiffOps.push_back(Op);
if (DiffOps.size() == SA->getNumOperands())
return nullptr;
auto PreStartFlags =
ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
const SCEV *BECount = SE->getBackedgeTakenCount(L);
if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
!isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
return PreStart;
unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
const SCEV *OperandExtendedStart =
SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
(SE->*GetExtendExpr)(Step, WideTy, Depth));
if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
if (PreAR && AR->getNoWrapFlags(WrapType)) {
SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
}
return PreStart;
}
ICmpInst::Predicate Pred;
const SCEV *OverflowLimit =
ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
if (OverflowLimit &&
SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
return PreStart;
return nullptr;
}
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
ScalarEvolution *SE,
unsigned Depth) {
auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
if (!PreStart)
return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
Depth),
(SE->*GetExtendExpr)(PreStart, Ty, Depth));
}
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
const SCEV *Step,
const Loop *L) {
auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
if (!StartC)
return false;
APInt StartAI = StartC->getAPInt();
for (unsigned Delta : {-2, -1, 1, 2}) {
const SCEV *PreStart = getConstant(StartAI - Delta);
FoldingSetNodeID ID;
ID.AddInteger(scAddRecExpr);
ID.AddPointer(PreStart);
ID.AddPointer(Step);
ID.AddPointer(L);
void *IP = nullptr;
const auto *PreAR =
static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
if (PreAR && PreAR->getNoWrapFlags(WrapType)) { const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
DeltaS, &Pred, this);
if (Limit && isKnownPredicate(Pred, PreAR, Limit)) return true;
}
}
return false;
}
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
const SCEVConstant *ConstantTerm,
const SCEVAddExpr *WholeAddExpr) {
const APInt &C = ConstantTerm->getAPInt();
const unsigned BitWidth = C.getBitWidth();
uint32_t TZ = BitWidth;
for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
if (TZ) {
return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
}
return APInt(BitWidth, 0);
}
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
const APInt &ConstantStart,
const SCEV *Step) {
const unsigned BitWidth = ConstantStart.getBitWidth();
const uint32_t TZ = SE.GetMinTrailingZeros(Step);
if (TZ)
return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
: ConstantStart;
return APInt(BitWidth, 0);
}
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
"This is not an extending conversion!");
assert(isSCEVable(Ty) &&
"This is not a conversion to a SCEVable type!");
assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
Ty = getEffectiveSCEVType(Ty);
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
return getConstant(
cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
FoldingSetNodeID ID;
ID.AddInteger(scZeroExtend);
ID.AddPointer(Op);
ID.AddPointer(Ty);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
if (Depth > MaxCastDepth) {
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
const SCEV *X = ST->getOperand();
ConstantRange CR = getUnsignedRange(X);
unsigned TruncBits = getTypeSizeInBits(ST->getType());
unsigned NewBits = getTypeSizeInBits(Ty);
if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
CR.zextOrTrunc(NewBits)))
return getTruncateOrZeroExtend(X, Ty, Depth);
}
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
if (AR->isAffine()) {
const SCEV *Start = AR->getStart();
const SCEV *Step = AR->getStepRecurrence(*this);
unsigned BitWidth = getTypeSizeInBits(AR->getType());
const Loop *L = AR->getLoop();
if (!AR->hasNoUnsignedWrap()) {
auto NewFlags = proveNoWrapViaConstantRanges(AR);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
}
if (AR->hasNoUnsignedWrap()) {
Start =
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
Step = getZeroExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
const SCEV *CastedMaxBECount =
getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
CastedMaxBECount, MaxBECount->getType(), Depth);
if (MaxBECount == RecastedMaxBECount) {
Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
SCEV::FlagAnyWrap, Depth + 1);
const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
SCEV::FlagAnyWrap,
Depth + 1),
WideTy, Depth + 1);
const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
const SCEV *WideMaxBECount =
getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
const SCEV *OperandExtendedAdd =
getAddExpr(WideStart,
getMulExpr(WideMaxBECount,
getZeroExtendExpr(Step, WideTy, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1);
if (ZAdd == OperandExtendedAdd) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
Depth + 1);
Step = getZeroExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
OperandExtendedAdd =
getAddExpr(WideStart,
getMulExpr(WideMaxBECount,
getSignExtendExpr(Step, WideTy, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1);
if (ZAdd == OperandExtendedAdd) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
}
}
if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
!AC.assumptions().empty()) {
auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
if (AR->hasNoUnsignedWrap()) {
Start =
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
Step = getZeroExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
if (isKnownNegative(Step)) {
const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
getSignedRangeMin(Step));
if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
}
}
if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
const APInt &C = SC->getAPInt();
const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
if (D != 0) {
const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
const SCEV *SResidual =
getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
return getAddExpr(SZExtD, SZExtR,
(SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
Depth + 1);
}
}
if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
Start =
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
Step = getZeroExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
}
{
const SCEV *LHS;
const SCEV *RHS;
if (matchURem(Op, LHS, RHS))
return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
getZeroExtendExpr(RHS, Ty, Depth + 1));
}
if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
if (SA->hasNoUnsignedWrap()) {
SmallVector<const SCEV *, 4> Ops;
for (const auto *Op : SA->operands())
Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
}
if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
if (D != 0) {
const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
const SCEV *SResidual =
getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
return getAddExpr(SZExtD, SZExtR,
(SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
Depth + 1);
}
}
}
if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
if (SM->hasNoUnsignedWrap()) {
SmallVector<const SCEV *, 4> Ops;
for (const auto *Op : SM->operands())
Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
}
if (SM->getNumOperands() == 2)
if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
if (MulLHS->getAPInt().isPowerOf2())
if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
MulLHS->getAPInt().logBase2();
Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
return getMulExpr(
getZeroExtendExpr(MulLHS, Ty),
getZeroExtendExpr(
getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
SCEV::FlagNUW, Depth + 1);
}
}
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
const SCEV *
ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
"This is not an extending conversion!");
assert(isSCEVable(Ty) &&
"This is not a conversion to a SCEVable type!");
assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
Ty = getEffectiveSCEVType(Ty);
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
return getConstant(
cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
FoldingSetNodeID ID;
ID.AddInteger(scSignExtend);
ID.AddPointer(Op);
ID.AddPointer(Ty);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
if (Depth > MaxCastDepth) {
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Op);
return S;
}
if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
const SCEV *X = ST->getOperand();
ConstantRange CR = getSignedRange(X);
unsigned TruncBits = getTypeSizeInBits(ST->getType());
unsigned NewBits = getTypeSizeInBits(Ty);
if (CR.truncate(TruncBits).signExtend(NewBits).contains(
CR.sextOrTrunc(NewBits)))
return getTruncateOrSignExtend(X, Ty, Depth);
}
if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
if (SA->hasNoSignedWrap()) {
SmallVector<const SCEV *, 4> Ops;
for (const auto *Op : SA->operands())
Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
}
if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
if (D != 0) {
const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
const SCEV *SResidual =
getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
return getAddExpr(SSExtD, SSExtR,
(SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
Depth + 1);
}
}
}
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
if (AR->isAffine()) {
const SCEV *Start = AR->getStart();
const SCEV *Step = AR->getStepRecurrence(*this);
unsigned BitWidth = getTypeSizeInBits(AR->getType());
const Loop *L = AR->getLoop();
if (!AR->hasNoSignedWrap()) {
auto NewFlags = proveNoWrapViaConstantRanges(AR);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
}
if (AR->hasNoSignedWrap()) {
Start =
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
}
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
const SCEV *CastedMaxBECount =
getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
CastedMaxBECount, MaxBECount->getType(), Depth);
if (MaxBECount == RecastedMaxBECount) {
Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
SCEV::FlagAnyWrap, Depth + 1);
const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
SCEV::FlagAnyWrap,
Depth + 1),
WideTy, Depth + 1);
const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
const SCEV *WideMaxBECount =
getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
const SCEV *OperandExtendedAdd =
getAddExpr(WideStart,
getMulExpr(WideMaxBECount,
getSignExtendExpr(Step, WideTy, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1);
if (SAdd == OperandExtendedAdd) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
OperandExtendedAdd =
getAddExpr(WideStart,
getMulExpr(WideMaxBECount,
getZeroExtendExpr(Step, WideTy, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1);
if (SAdd == OperandExtendedAdd) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
Depth + 1);
Step = getZeroExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
}
}
auto NewFlags = proveNoSignedWrapViaInduction(AR);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
if (AR->hasNoSignedWrap()) {
Start =
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
const APInt &C = SC->getAPInt();
const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
if (D != 0) {
const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
const SCEV *SResidual =
getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
return getAddExpr(SSExtD, SSExtR,
(SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
Depth + 1);
}
}
if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
Start =
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
Step = getSignExtendExpr(Step, Ty, Depth + 1);
return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
}
}
if (isKnownNonNegative(Op))
return getZeroExtendExpr(Op, Ty, Depth + 1);
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, { Op });
return S;
}
const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
Type *Ty) {
switch (Kind) {
case scTruncate:
return getTruncateExpr(Op, Ty);
case scZeroExtend:
return getZeroExtendExpr(Op, Ty);
case scSignExtend:
return getSignExtendExpr(Op, Ty);
case scPtrToInt:
return getPtrToIntExpr(Op, Ty);
default:
llvm_unreachable("Not a SCEV cast expression!");
}
}
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
Type *Ty) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
"This is not an extending conversion!");
assert(isSCEVable(Ty) &&
"This is not a conversion to a SCEVable type!");
Ty = getEffectiveSCEVType(Ty);
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
if (SC->getAPInt().isNegative())
return getSignExtendExpr(Op, Ty);
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
const SCEV *NewOp = T->getOperand();
if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
return getAnyExtendExpr(NewOp, Ty);
return getTruncateOrNoop(NewOp, Ty);
}
const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
if (!isa<SCEVZeroExtendExpr>(ZExt))
return ZExt;
const SCEV *SExt = getSignExtendExpr(Op, Ty);
if (!isa<SCEVSignExtendExpr>(SExt))
return SExt;
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
SmallVector<const SCEV *, 4> Ops;
for (const SCEV *Op : AR->operands())
Ops.push_back(getAnyExtendExpr(Op, Ty));
return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
}
if (isa<SCEVSMaxExpr>(Op))
return SExt;
return ZExt;
}
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
SmallVectorImpl<const SCEV *> &NewOps,
APInt &AccumulatedConstant,
const SCEV *const *Ops, size_t NumOperands,
const APInt &Scale,
ScalarEvolution &SE) {
bool Interesting = false;
unsigned i = 0;
while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
++i;
if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
Interesting = true;
AccumulatedConstant += Scale * C->getAPInt();
}
for (; i != NumOperands; ++i) {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
APInt NewScale =
Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
Interesting |=
CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
Add->op_begin(), Add->getNumOperands(),
NewScale, SE);
} else {
SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
const SCEV *Key = SE.getMulExpr(MulOps);
auto Pair = M.insert({Key, NewScale});
if (Pair.second) {
NewOps.push_back(Pair.first->first);
} else {
Pair.first->second += NewScale;
Interesting = true;
}
}
} else {
std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
M.insert({Ops[i], Scale});
if (Pair.second) {
NewOps.push_back(Pair.first->first);
} else {
Pair.first->second += Scale;
Interesting = true;
}
}
}
return Interesting;
}
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
const SCEV *LHS, const SCEV *RHS) {
const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
SCEV::NoWrapFlags, unsigned);
switch (BinOp) {
default:
llvm_unreachable("Unsupported binary op");
case Instruction::Add:
Operation = &ScalarEvolution::getAddExpr;
break;
case Instruction::Sub:
Operation = &ScalarEvolution::getMinusSCEV;
break;
case Instruction::Mul:
Operation = &ScalarEvolution::getMulExpr;
break;
}
const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
Signed ? &ScalarEvolution::getSignExtendExpr
: &ScalarEvolution::getZeroExtendExpr;
auto *NarrowTy = cast<IntegerType>(LHS->getType());
auto *WideTy =
IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
const SCEV *A = (this->*Extension)(
(this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
return A == B;
}
Optional<SCEV::NoWrapFlags>
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
const OverflowingBinaryOperator *OBO) {
if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
return None;
SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
if (OBO->hasNoUnsignedWrap())
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
if (OBO->hasNoSignedWrap())
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
bool Deduced = false;
if (OBO->getOpcode() != Instruction::Add &&
OBO->getOpcode() != Instruction::Sub &&
OBO->getOpcode() != Instruction::Mul)
return None;
const SCEV *LHS = getSCEV(OBO->getOperand(0));
const SCEV *RHS = getSCEV(OBO->getOperand(1));
if (!OBO->hasNoUnsignedWrap() &&
willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
false, LHS, RHS)) {
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
Deduced = true;
}
if (!OBO->hasNoSignedWrap() &&
willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
true, LHS, RHS)) {
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
Deduced = true;
}
if (Deduced)
return Flags;
return None;
}
static SCEV::NoWrapFlags
StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
const ArrayRef<const SCEV *> Ops,
SCEV::NoWrapFlags Flags) {
using namespace std::placeholders;
using OBO = OverflowingBinaryOperator;
bool CanAnalyze =
Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
(void)CanAnalyze;
assert(CanAnalyze && "don't call from other places!");
int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
SCEV::NoWrapFlags SignOrUnsignWrap =
ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
auto IsKnownNonNegative = [&](const SCEV *S) {
return SE->isKnownNonNegative(S);
};
if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
Flags =
ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
if (SignOrUnsignWrap != SignOrUnsignMask &&
(Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
isa<SCEVConstant>(Ops[0])) {
auto Opcode = [&] {
switch (Type) {
case scAddExpr:
return Instruction::Add;
case scMulExpr:
return Instruction::Mul;
default:
llvm_unreachable("Unexpected SCEV op.");
}
}();
const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
Opcode, C, OBO::NoSignedWrap);
if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
}
if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
Opcode, C, OBO::NoUnsignedWrap);
if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
}
}
if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
!ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
Ops.size() == 2) {
if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
if (UDiv->getOperand(1) == Ops[1])
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
if (UDiv->getOperand(1) == Ops[0])
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
}
return Flags;
}
bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
}
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags OrigFlags,
unsigned Depth) {
assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty add!");
if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
"SCEVAddExpr operand types don't match!");
unsigned NumPtrs = count_if(
Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
assert(NumPtrs <= 1 && "add has at most one pointer operand");
#endif
GroupByComplexity(Ops, &LI, DT);
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
if (Ops.size() == 2) return Ops[0];
Ops.erase(Ops.begin()+1); LHSC = cast<SCEVConstant>(Ops[0]);
}
if (LHSC->getValue()->isZero()) {
Ops.erase(Ops.begin());
--Idx;
}
if (Ops.size() == 1) return Ops[0];
}
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
};
if (Depth > MaxArithDepth || hasHugeExpression(Ops))
return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
Add->setNoWrapFlags(ComputeFlags(Ops));
return S;
}
Type *Ty = Ops[0]->getType();
bool FoundMatch = false;
for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
if (Ops[i] == Ops[i+1]) { unsigned Count = 2;
while (i+Count != e && Ops[i+Count] == Ops[i])
++Count;
const SCEV *Scale = getConstant(Ty, Count);
const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
if (Ops.size() == Count)
return Mul;
Ops[i] = Mul;
Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
--i; e -= Count - 1;
FoundMatch = true;
}
if (FoundMatch)
return getAddExpr(Ops, OrigFlags, Depth + 1);
auto FindTruncSrcType = [&]() -> Type * {
if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
return T->getOperand()->getType();
if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
return T->getOperand()->getType();
}
return nullptr;
};
if (auto *SrcType = FindTruncSrcType()) {
SmallVector<const SCEV *, 8> LargeOps;
bool Ok = true;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
if (T->getOperand()->getType() != SrcType) {
Ok = false;
break;
}
LargeOps.push_back(T->getOperand());
} else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
LargeOps.push_back(getAnyExtendExpr(C, SrcType));
} else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
SmallVector<const SCEV *, 8> LargeMulOps;
for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
if (const SCEVTruncateExpr *T =
dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
if (T->getOperand()->getType() != SrcType) {
Ok = false;
break;
}
LargeMulOps.push_back(T->getOperand());
} else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
} else {
Ok = false;
break;
}
}
if (Ok)
LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
} else {
Ok = false;
break;
}
}
if (Ok) {
const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
return getTruncateExpr(Fold, Ty);
}
}
if (Ops.size() == 2) {
const SCEV *A = Ops[0];
const SCEV *B = Ops[1];
auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
auto *C = dyn_cast<SCEVConstant>(A);
if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
auto C2 = C->getAPInt();
SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
APInt ConstAdd = C1 + C2;
auto AddFlags = AddExpr->getNoWrapFlags();
if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
ConstAdd.ule(C1)) {
PreservedFlags =
ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
}
if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
ConstAdd.abs().ule(C1.abs())) {
PreservedFlags =
ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
}
if (PreservedFlags != SCEV::FlagAnyWrap) {
SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
NewOps[0] = getConstant(ConstAdd);
return getAddExpr(NewOps, PreservedFlags);
}
}
}
if (Ops.size() == 2) {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
if (Mul && Mul->getNumOperands() == 2 &&
Mul->getOperand(0)->isAllOnesValue()) {
const SCEV *X;
const SCEV *Y;
if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
return getMulExpr(Y, getUDivExpr(X, Y));
}
}
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
++Idx;
if (Idx < Ops.size()) {
bool DeletedAdd = false;
SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
if (Ops.size() > AddOpsInlineThreshold ||
Add->getNumOperands() > AddOpsInlineThreshold)
break;
Ops.erase(Ops.begin()+Idx);
Ops.append(Add->op_begin(), Add->op_end());
DeletedAdd = true;
CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
}
if (DeletedAdd)
return getAddExpr(Ops, CommonFlags, Depth + 1);
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
++Idx;
if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
uint64_t BitWidth = getTypeSizeInBits(Ty);
DenseMap<const SCEV *, APInt> M;
SmallVector<const SCEV *, 8> NewOps;
APInt AccumulatedConstant(BitWidth, 0);
if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
Ops.data(), Ops.size(),
APInt(BitWidth, 1), *this)) {
struct APIntCompare {
bool operator()(const APInt &LHS, const APInt &RHS) const {
return LHS.ult(RHS);
}
};
std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
for (const SCEV *NewOp : NewOps)
MulOpLists[M.find(NewOp)->second].push_back(NewOp);
Ops.clear();
if (AccumulatedConstant != 0)
Ops.push_back(getConstant(AccumulatedConstant));
for (auto &MulOp : MulOpLists) {
if (MulOp.first == 1) {
Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
} else if (MulOp.first != 0) {
Ops.push_back(getMulExpr(
getConstant(MulOp.first),
getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1));
}
}
if (Ops.empty())
return getZero(Ty);
if (Ops.size() == 1)
return Ops[0];
return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
}
for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
if (isa<SCEVConstant>(MulOpSCEV))
continue;
for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
if (MulOpSCEV == Ops[AddOp]) {
const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
if (Mul->getNumOperands() != 2) {
SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
Mul->op_begin()+MulOp);
MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
}
SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
SCEV::FlagAnyWrap, Depth + 1);
if (Ops.size() == 2) return OuterMul;
if (AddOp < Idx) {
Ops.erase(Ops.begin()+AddOp);
Ops.erase(Ops.begin()+Idx-1);
} else {
Ops.erase(Ops.begin()+Idx);
Ops.erase(Ops.begin()+AddOp-1);
}
Ops.push_back(OuterMul);
return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
for (unsigned OtherMulIdx = Idx+1;
OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
++OtherMulIdx) {
const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
OMulOp != e; ++OMulOp)
if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
if (Mul->getNumOperands() != 2) {
SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
Mul->op_begin()+MulOp);
MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
}
const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
if (OtherMul->getNumOperands() != 2) {
SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
OtherMul->op_begin()+OMulOp);
MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
}
SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
const SCEV *InnerMulSum =
getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
SCEV::FlagAnyWrap, Depth + 1);
if (Ops.size() == 2) return OuterMul;
Ops.erase(Ops.begin()+Idx);
Ops.erase(Ops.begin()+OtherMulIdx-1);
Ops.push_back(OuterMul);
return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
}
}
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
++Idx;
for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
SmallVector<const SCEV *, 8> LIOps;
const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
const Loop *AddRecLoop = AddRec->getLoop();
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
LIOps.push_back(Ops[i]);
Ops.erase(Ops.begin()+i);
--i; --e;
}
if (!LIOps.empty()) {
LIOps.push_back(AddRec);
SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
LIOps.pop_back();
LIOps.push_back(AddRec->getStart());
SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
SCEV::NoWrapFlags AddFlags = Flags;
if (AddFlags != SCEV::FlagAnyWrap) {
auto *DefI = getDefiningScopeBound(LIOps);
auto *ReachI = &*AddRecLoop->getHeader()->begin();
if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
AddFlags = SCEV::FlagAnyWrap;
}
AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
if (Ops.size() == 1) return NewRec;
for (unsigned i = 0;; ++i)
if (Ops[i] == AddRec) {
Ops[i] = NewRec;
break;
}
return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
for (unsigned OtherIdx = Idx+1;
OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
assert(DT.dominates(
cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
AddRec->getLoop()->getHeader()) &&
"AddRecExprs are not sorted in reverse dominance order?");
if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
if (OtherAddRec->getLoop() == AddRecLoop) {
for (unsigned i = 0, e = OtherAddRec->getNumOperands();
i != e; ++i) {
if (i >= AddRecOps.size()) {
AddRecOps.append(OtherAddRec->op_begin()+i,
OtherAddRec->op_end());
break;
}
SmallVector<const SCEV *, 2> TwoOps = {
AddRecOps[i], OtherAddRec->getOperand(i)};
AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
}
Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
}
}
Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
}
}
return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
}
const SCEV *
ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
SCEV::NoWrapFlags Flags) {
FoldingSetNodeID ID;
ID.AddInteger(scAddExpr);
for (const SCEV *Op : Ops)
ID.AddPointer(Op);
void *IP = nullptr;
SCEVAddExpr *S =
static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
if (!S) {
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
S = new (SCEVAllocator)
SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Ops);
}
S->setNoWrapFlags(Flags);
return S;
}
const SCEV *
ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
const Loop *L, SCEV::NoWrapFlags Flags) {
FoldingSetNodeID ID;
ID.AddInteger(scAddRecExpr);
for (const SCEV *Op : Ops)
ID.AddPointer(Op);
ID.AddPointer(L);
void *IP = nullptr;
SCEVAddRecExpr *S =
static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
if (!S) {
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
S = new (SCEVAllocator)
SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
UniqueSCEVs.InsertNode(S, IP);
LoopUsers[L].push_back(S);
registerUser(S, Ops);
}
setNoWrapFlags(S, Flags);
return S;
}
const SCEV *
ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
SCEV::NoWrapFlags Flags) {
FoldingSetNodeID ID;
ID.AddInteger(scMulExpr);
for (const SCEV *Op : Ops)
ID.AddPointer(Op);
void *IP = nullptr;
SCEVMulExpr *S =
static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
if (!S) {
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Ops);
}
S->setNoWrapFlags(Flags);
return S;
}
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
uint64_t k = i*j;
if (j > 1 && k / j != i) Overflow = true;
return k;
}
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
if (n == 0 || n == k) return 1;
if (k > n) return 0;
if (k > n/2)
k = n-k;
uint64_t r = 1;
for (uint64_t i = 1; i <= k; ++i) {
r = umul_ov(r, n-(i-1), Overflow);
r /= i;
}
return r;
}
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
struct FindConstantInAddMulChain {
bool FoundConstant = false;
bool follow(const SCEV *S) {
FoundConstant |= isa<SCEVConstant>(S);
return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
}
bool isDone() const {
return FoundConstant;
}
};
FindConstantInAddMulChain F;
SCEVTraversal<FindConstantInAddMulChain> ST(F);
ST.visitAll(StartExpr);
return F.FoundConstant;
}
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags OrigFlags,
unsigned Depth) {
assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty mul!");
if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
Type *ETy = Ops[0]->getType();
assert(!ETy->isPointerTy());
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
assert(Ops[i]->getType() == ETy &&
"SCEVMulExpr operand types don't match!");
#endif
GroupByComplexity(Ops, &LI, DT);
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
if (Ops.size() == 2) return Ops[0];
Ops.erase(Ops.begin()+1); LHSC = cast<SCEVConstant>(Ops[0]);
}
if (LHSC->getValue()->isZero())
return LHSC;
if (LHSC->getValue()->isOne()) {
Ops.erase(Ops.begin());
--Idx;
}
if (Ops.size() == 1)
return Ops[0];
}
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
};
if (Depth > MaxArithDepth || hasHugeExpression(Ops))
return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
Mul->setNoWrapFlags(ComputeFlags(Ops));
return S;
}
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
if (Ops.size() == 2) {
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
SCEV::FlagAnyWrap, Depth + 1);
const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
SCEV::FlagAnyWrap, Depth + 1);
return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
}
if (Ops[0]->isAllOnesValue()) {
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
SmallVector<const SCEV *, 4> NewOps;
bool AnyFolded = false;
for (const SCEV *AddOp : Add->operands()) {
const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
Depth + 1);
if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
NewOps.push_back(Mul);
}
if (AnyFolded)
return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
} else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
SmallVector<const SCEV *, 4> Operands;
for (const SCEV *AddRecOp : AddRec->operands())
Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
Depth + 1));
return getAddRecExpr(Operands, AddRec->getLoop(),
AddRec->getNoWrapFlags(SCEV::FlagNW));
}
}
}
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
++Idx;
if (Idx < Ops.size()) {
bool DeletedMul = false;
while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
if (Ops.size() > MulOpsInlineThreshold)
break;
Ops.erase(Ops.begin()+Idx);
Ops.append(Mul->op_begin(), Mul->op_end());
DeletedMul = true;
}
if (DeletedMul)
return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
++Idx;
for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
SmallVector<const SCEV *, 8> LIOps;
const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
const Loop *AddRecLoop = AddRec->getLoop();
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
LIOps.push_back(Ops[i]);
Ops.erase(Ops.begin()+i);
--i; --e;
}
if (!LIOps.empty()) {
SmallVector<const SCEV *, 4> NewOps;
NewOps.reserve(AddRec->getNumOperands());
const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
SCEV::FlagAnyWrap, Depth + 1));
SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
const SCEV *NewRec = getAddRecExpr(
NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
if (Ops.size() == 1) return NewRec;
for (unsigned i = 0;; ++i)
if (Ops[i] == AddRec) {
Ops[i] = NewRec;
break;
}
return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
bool OpsModified = false;
for (unsigned OtherIdx = Idx+1;
OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
const SCEVAddRecExpr *OtherAddRec =
dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
continue;
if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
continue;
bool Overflow = false;
Type *Ty = AddRec->getType();
bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
SmallVector<const SCEV*, 7> AddRecOps;
for (int x = 0, xe = AddRec->getNumOperands() +
OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
SmallVector <const SCEV *, 7> SumOps;
for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
z < ze && !Overflow; ++z) {
uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
uint64_t Coeff;
if (LargerThan64Bits)
Coeff = umul_ov(Coeff1, Coeff2, Overflow);
else
Coeff = Coeff1*Coeff2;
const SCEV *CoeffTerm = getConstant(Ty, Coeff);
const SCEV *Term1 = AddRec->getOperand(y-z);
const SCEV *Term2 = OtherAddRec->getOperand(z);
SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
SCEV::FlagAnyWrap, Depth + 1));
}
}
if (SumOps.empty())
SumOps.push_back(getZero(Ty));
AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
}
if (!Overflow) {
const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
SCEV::FlagAnyWrap);
if (Ops.size() == 2) return NewAddRec;
Ops[Idx] = NewAddRec;
Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
OpsModified = true;
AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
if (!AddRec)
break;
}
}
if (OpsModified)
return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
}
return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
const SCEV *RHS) {
assert(getEffectiveSCEVType(LHS->getType()) ==
getEffectiveSCEVType(RHS->getType()) &&
"SCEVURemExpr operand types don't match!");
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne())
return getZero(LHS->getType());
if (RHSC->getAPInt().isPowerOf2()) {
Type *FullTy = LHS->getType();
Type *TruncTy =
IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
}
}
const SCEV *UDiv = getUDivExpr(LHS, RHS);
const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
}
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
const SCEV *RHS) {
assert(!LHS->getType()->isPointerTy() &&
"SCEVUDivExpr operand can't be pointer!");
assert(LHS->getType() == RHS->getType() &&
"SCEVUDivExpr operand types don't match!");
FoldingSetNodeID ID;
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
if (LHSC->getValue()->isZero())
return LHS;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne())
return LHS; if (!RHSC->getValue()->isZero()) {
Type *Ty = LHS->getType();
unsigned LZ = RHSC->getAPInt().countLeadingZeros();
unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
if (!RHSC->getAPInt().isPowerOf2())
++MaxShiftAmt;
IntegerType *ExtTy =
IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
if (const SCEVConstant *Step =
dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
const APInt &StepInt = Step->getAPInt();
const APInt &DivInt = RHSC->getAPInt();
if (!StepInt.urem(DivInt) &&
getZeroExtendExpr(AR, ExtTy) ==
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getZeroExtendExpr(Step, ExtTy),
AR->getLoop(), SCEV::FlagAnyWrap)) {
SmallVector<const SCEV *, 4> Operands;
for (const SCEV *Op : AR->operands())
Operands.push_back(getUDivExpr(Op, RHS));
return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
}
const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
if (StartC && !DivInt.urem(StepInt) &&
getZeroExtendExpr(AR, ExtTy) ==
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getZeroExtendExpr(Step, ExtTy),
AR->getLoop(), SCEV::FlagAnyWrap)) {
const APInt &StartInt = StartC->getAPInt();
const APInt &StartRem = StartInt.urem(StepInt);
if (StartRem != 0) {
const SCEV *NewLHS =
getAddRecExpr(getConstant(StartInt - StartRem), Step,
AR->getLoop(), SCEV::FlagNW);
if (LHS != NewLHS) {
LHS = NewLHS;
ID.clear();
ID.AddInteger(scUDivExpr);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}
}
}
}
if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
SmallVector<const SCEV *, 4> Operands;
for (const SCEV *Op : M->operands())
Operands.push_back(getZeroExtendExpr(Op, ExtTy));
if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
const SCEV *Op = M->getOperand(i);
const SCEV *Div = getUDivExpr(Op, RHSC);
if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
Operands = SmallVector<const SCEV *, 4>(M->operands());
Operands[i] = Div;
return getMulExpr(Operands);
}
}
}
if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
if (auto *DivisorConstant =
dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
bool Overflow = false;
APInt NewRHS =
DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
if (Overflow) {
return getConstant(RHSC->getType(), 0, false);
}
return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
}
}
if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
SmallVector<const SCEV *, 4> Operands;
for (const SCEV *Op : A->operands())
Operands.push_back(getZeroExtendExpr(Op, ExtTy));
if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
Operands.clear();
for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
if (isa<SCEVUDivExpr>(Op) ||
getMulExpr(Op, RHS) != A->getOperand(i))
break;
Operands.push_back(Op);
}
if (Operands.size() == A->getNumOperands())
return getAddExpr(Operands);
}
}
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
}
}
IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
LHS, RHS);
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, {LHS, RHS});
return S;
}
APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
APInt A = C1->getAPInt().abs();
APInt B = C2->getAPInt().abs();
uint32_t ABW = A.getBitWidth();
uint32_t BBW = B.getBitWidth();
if (ABW > BBW)
B = B.zext(ABW);
else if (ABW < BBW)
A = A.zext(BBW);
return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
}
const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
const SCEV *RHS) {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
if (!Mul || !Mul->hasNoUnsignedWrap())
return getUDivExpr(LHS, RHS);
if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
if (LHSCst == RHSCst) {
SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
return getMulExpr(Operands);
}
APInt Factor = gcd(LHSCst, RHSCst);
if (!Factor.isIntN(1)) {
LHSCst =
cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
RHSCst =
cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
SmallVector<const SCEV *, 2> Operands;
Operands.push_back(LHSCst);
Operands.append(Mul->op_begin() + 1, Mul->op_end());
LHS = getMulExpr(Operands);
RHS = RHSCst;
Mul = dyn_cast<SCEVMulExpr>(LHS);
if (!Mul)
return getUDivExactExpr(LHS, RHS);
}
}
}
for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
if (Mul->getOperand(i) == RHS) {
SmallVector<const SCEV *, 2> Operands;
Operands.append(Mul->op_begin(), Mul->op_begin() + i);
Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
return getMulExpr(Operands);
}
}
return getUDivExpr(LHS, RHS);
}
const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
const Loop *L,
SCEV::NoWrapFlags Flags) {
SmallVector<const SCEV *, 4> Operands;
Operands.push_back(Start);
if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
if (StepChrec->getLoop() == L) {
Operands.append(StepChrec->op_begin(), StepChrec->op_end());
return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
}
Operands.push_back(Step);
return getAddRecExpr(Operands, L, Flags);
}
const SCEV *
ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
const Loop *L, SCEV::NoWrapFlags Flags) {
if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
"SCEVAddRecExpr operand types don't match!");
assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
}
for (unsigned i = 0, e = Operands.size(); i != e; ++i)
assert(isLoopInvariant(Operands[i], L) &&
"SCEVAddRecExpr operand is not loop-invariant!");
#endif
if (Operands.back()->isZero()) {
Operands.pop_back();
return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); }
Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
const Loop *NestedLoop = NestedAR->getLoop();
if (L->contains(NestedLoop)
? (L->getLoopDepth() < NestedLoop->getLoopDepth())
: (!NestedLoop->contains(L) &&
DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
Operands[0] = NestedAR->getStart();
bool AllInvariant = all_of(
Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
if (AllInvariant) {
SCEV::NoWrapFlags OuterFlags =
maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
return isLoopInvariant(Op, NestedLoop);
});
if (AllInvariant) {
SCEV::NoWrapFlags InnerFlags =
maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
}
}
Operands[0] = NestedAR;
}
}
return getOrCreateAddRecExpr(Operands, L, Flags);
}
const SCEV *
ScalarEvolution::getGEPExpr(GEPOperator *GEP,
const SmallVectorImpl<const SCEV *> &IndexExprs) {
const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
const bool AssumeInBoundsFlags = [&]() {
if (!GEP->isInBounds())
return false;
auto *GEPI = dyn_cast<Instruction>(GEP);
return GEPI && isSCEVExprNeverPoison(GEPI);
}();
SCEV::NoWrapFlags OffsetWrap =
AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
Type *CurTy = GEP->getType();
bool FirstIter = true;
SmallVector<const SCEV *, 4> Offsets;
for (const SCEV *IndexExpr : IndexExprs) {
if (StructType *STy = dyn_cast<StructType>(CurTy)) {
ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
unsigned FieldNo = Index->getZExtValue();
const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
Offsets.push_back(FieldOffset);
CurTy = STy->getTypeAtIndex(Index);
} else {
if (FirstIter) {
assert(isa<PointerType>(CurTy) &&
"The first index of a GEP indexes a pointer");
CurTy = GEP->getSourceElementType();
FirstIter = false;
} else {
CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
}
const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
Offsets.push_back(LocalOffset);
}
}
if (Offsets.empty())
return BaseExpr;
const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
? SCEV::FlagNUW : SCEV::FlagAnyWrap;
auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
assert(BaseExpr->getType() == GEPExpr->getType() &&
"GEP should not change type mid-flight.");
return GEPExpr;
}
SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
ArrayRef<const SCEV *> Ops) {
FoldingSetNodeID ID;
ID.AddInteger(SCEVType);
for (const SCEV *Op : Ops)
ID.AddPointer(Op);
void *IP = nullptr;
return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
}
const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
}
const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Ops) {
assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
"Operand types don't match!");
assert(Ops[0]->getType()->isPointerTy() ==
Ops[i]->getType()->isPointerTy() &&
"min/max should be consistently pointerish");
}
#endif
bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
GroupByComplexity(Ops, &LI, DT);
if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
return S;
}
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
if (Kind == scSMaxExpr)
return APIntOps::smax(LHS, RHS);
else if (Kind == scSMinExpr)
return APIntOps::smin(LHS, RHS);
else if (Kind == scUMaxExpr)
return APIntOps::umax(LHS, RHS);
else if (Kind == scUMinExpr)
return APIntOps::umin(LHS, RHS);
llvm_unreachable("Unknown SCEV min/max opcode");
};
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
ConstantInt *Fold = ConstantInt::get(
getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); if (Ops.size() == 1) return Ops[0];
LHSC = cast<SCEVConstant>(Ops[0]);
}
bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
if (IsMax ? IsMinV : IsMaxV) {
Ops.erase(Ops.begin());
--Idx;
} else if (IsMax ? IsMaxV : IsMinV) {
return LHSC;
}
if (Ops.size() == 1) return Ops[0];
}
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
++Idx;
if (Idx < Ops.size()) {
bool DeletedAny = false;
while (Ops[Idx]->getSCEVType() == Kind) {
const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
Ops.erase(Ops.begin()+Idx);
Ops.append(SMME->op_begin(), SMME->op_end());
DeletedAny = true;
}
if (DeletedAny)
return getMinMaxExpr(Kind, Ops);
}
llvm::CmpInst::Predicate GEPred =
IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
llvm::CmpInst::Predicate LEPred =
IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
if (Ops[i] == Ops[i + 1] ||
isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
--i;
--e;
} else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
Ops[i + 1])) {
Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
--i;
--e;
}
}
if (Ops.size() == 1) return Ops[0];
assert(!Ops.empty() && "Reduced smax down to nothing!");
FoldingSetNodeID ID;
ID.AddInteger(Kind);
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
ID.AddPointer(Ops[i]);
void *IP = nullptr;
const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
if (ExistingSCEV)
return ExistingSCEV;
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
SCEV *S = new (SCEVAllocator)
SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Ops);
return S;
}
namespace {
class SCEVSequentialMinMaxDeduplicatingVisitor final
: public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
Optional<const SCEV *>> {
using RetVal = Optional<const SCEV *>;
using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
ScalarEvolution &SE;
const SCEVTypes RootKind; const SCEVTypes NonSequentialRootKind; SmallPtrSet<const SCEV *, 16> SeenOps;
bool canRecurseInto(SCEVTypes Kind) const {
return RootKind == Kind || NonSequentialRootKind == Kind;
};
RetVal visitAnyMinMaxExpr(const SCEV *S) {
assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
"Only for min/max expressions.");
SCEVTypes Kind = S->getSCEVType();
if (!canRecurseInto(Kind))
return S;
auto *NAry = cast<SCEVNAryExpr>(S);
SmallVector<const SCEV *> NewOps;
bool Changed =
visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
if (!Changed)
return S;
if (NewOps.empty())
return None;
return isa<SCEVSequentialMinMaxExpr>(S)
? SE.getSequentialMinMaxExpr(Kind, NewOps)
: SE.getMinMaxExpr(Kind, NewOps);
}
RetVal visit(const SCEV *S) {
if (!SeenOps.insert(S).second)
return None;
return Base::visit(S);
}
public:
SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
SCEVTypes RootKind)
: SE(SE), RootKind(RootKind),
NonSequentialRootKind(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
RootKind)) {}
bool visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
SmallVectorImpl<const SCEV *> &NewOps) {
bool Changed = false;
SmallVector<const SCEV *> Ops;
Ops.reserve(OrigOps.size());
for (const SCEV *Op : OrigOps) {
RetVal NewOp = visit(Op);
if (NewOp != Op)
Changed = true;
if (NewOp)
Ops.emplace_back(*NewOp);
}
if (Changed)
NewOps = std::move(Ops);
return Changed;
}
RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
return visitAnyMinMaxExpr(Expr);
}
RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
};
}
static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
struct SCEVPoisonCollector {
bool LookThroughSeq;
SmallPtrSet<const SCEV *, 4> MaybePoison;
SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {}
bool follow(const SCEV *S) {
if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S))
return false;
if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
if (!isGuaranteedNotToBePoison(SU->getValue()))
MaybePoison.insert(S);
}
return true;
}
bool isDone() const { return false; }
};
SCEVPoisonCollector PC1( true);
visitAll(AssumedPoison, PC1);
if (PC1.MaybePoison.empty())
return true;
SCEVPoisonCollector PC2( false);
visitAll(S, PC2);
return all_of(PC1.MaybePoison,
[&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
}
const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Ops) {
assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
"Not a SCEVSequentialMinMaxExpr!");
assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
if (Ops.size() == 1)
return Ops[0];
#ifndef NDEBUG
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
"Operand types don't match!");
assert(Ops[0]->getType()->isPointerTy() ==
Ops[i]->getType()->isPointerTy() &&
"min/max should be consistently pointerish");
}
#endif
if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
return S;
{
SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
bool Changed = Deduplicator.visit(Kind, Ops, Ops);
if (Changed)
return getSequentialMinMaxExpr(Kind, Ops);
}
{
unsigned Idx = 0;
bool DeletedAny = false;
while (Idx < Ops.size()) {
if (Ops[Idx]->getSCEVType() != Kind) {
++Idx;
continue;
}
const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
Ops.erase(Ops.begin() + Idx);
Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
DeletedAny = true;
}
if (DeletedAny)
return getSequentialMinMaxExpr(Kind, Ops);
}
const SCEV *SaturationPoint;
ICmpInst::Predicate Pred;
switch (Kind) {
case scSequentialUMinExpr:
SaturationPoint = getZero(Ops[0]->getType());
Pred = ICmpInst::ICMP_ULE;
break;
default:
llvm_unreachable("Not a sequential min/max type.");
}
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
if (::impliesPoison(Ops[i], Ops[i - 1]) ||
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
SaturationPoint)) {
SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
Ops[i - 1] = getMinMaxExpr(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
SeqOps);
Ops.erase(Ops.begin() + i);
return getSequentialMinMaxExpr(Kind, Ops);
}
if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
Ops.erase(Ops.begin() + i);
return getSequentialMinMaxExpr(Kind, Ops);
}
}
FoldingSetNodeID ID;
ID.AddInteger(Kind);
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
ID.AddPointer(Ops[i]);
void *IP = nullptr;
const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
if (ExistingSCEV)
return ExistingSCEV;
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
SCEV *S = new (SCEVAllocator)
SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
registerUser(S, Ops);
return S;
}
const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getSMaxExpr(Ops);
}
const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
return getMinMaxExpr(scSMaxExpr, Ops);
}
const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getUMaxExpr(Ops);
}
const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
return getMinMaxExpr(scUMaxExpr, Ops);
}
const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
const SCEV *RHS) {
SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
return getSMinExpr(Ops);
}
const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
return getMinMaxExpr(scSMinExpr, Ops);
}
const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
bool Sequential) {
SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
return getUMinExpr(Ops, Sequential);
}
const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
bool Sequential) {
return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
: getMinMaxExpr(scUMinExpr, Ops);
}
const SCEV *
ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
ScalableVectorType *ScalableTy) {
Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
Constant *One = ConstantInt::get(IntTy, 1);
Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
}
const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
}
const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
}
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
StructType *STy,
unsigned FieldNo) {
return getConstant(
IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
}
const SCEV *ScalarEvolution::getUnknown(Value *V) {
FoldingSetNodeID ID;
ID.AddInteger(scUnknown);
ID.AddPointer(V);
void *IP = nullptr;
if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
assert(cast<SCEVUnknown>(S)->getValue() == V &&
"Stale SCEVUnknown in uniquing map!");
return S;
}
SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
FirstUnknown);
FirstUnknown = cast<SCEVUnknown>(S);
UniqueSCEVs.InsertNode(S, IP);
return S;
}
bool ScalarEvolution::isSCEVable(Type *Ty) const {
return Ty->isIntOrPtrTy();
}
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
if (Ty->isPointerTy())
return getDataLayout().getIndexTypeSizeInBits(Ty);
return getDataLayout().getTypeSizeInBits(Ty);
}
Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
if (Ty->isIntegerTy())
return Ty;
assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
return getDataLayout().getIndexType(Ty);
}
Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
}
bool ScalarEvolution::instructionCouldExistWitthOperands(const SCEV *A,
const SCEV *B) {
bool PreciseA, PreciseB;
auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
if (!PreciseA || !PreciseB)
return false;
return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
DT.dominates(ScopeB, ScopeA);
}
const SCEV *ScalarEvolution::getCouldNotCompute() {
return CouldNotCompute.get();
}
bool ScalarEvolution::checkValidity(const SCEV *S) const {
bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
auto *SU = dyn_cast<SCEVUnknown>(S);
return SU && SU->getValue() == nullptr;
});
return !ContainsNulls;
}
bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
HasRecMapType::iterator I = HasRecMap.find(S);
if (I != HasRecMap.end())
return I->second;
bool FoundAddRec =
SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
HasRecMap.insert({S, FoundAddRec});
return FoundAddRec;
}
ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
if (SI == ExprValueMap.end())
return None;
#ifndef NDEBUG
if (VerifySCEVMap) {
for (Value *V : SI->second)
assert(ValueExprMap.count(V));
}
#endif
return SI->second.getArrayRef();
}
void ScalarEvolution::eraseValueFromMap(Value *V) {
ValueExprMapType::iterator I = ValueExprMap.find_as(V);
if (I != ValueExprMap.end()) {
auto EVIt = ExprValueMap.find(I->second);
bool Removed = EVIt->second.remove(V);
(void) Removed;
assert(Removed && "Value not in ExprValueMap?");
ValueExprMap.erase(I);
}
}
void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
auto It = ValueExprMap.find_as(V);
if (It == ValueExprMap.end()) {
ValueExprMap.insert({SCEVCallbackVH(V, this), S});
ExprValueMap[S].insert(V);
}
}
const SCEV *ScalarEvolution::getSCEV(Value *V) {
assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
if (const SCEV *S = getExistingSCEV(V))
return S;
return createSCEVIter(V);
}
const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
ValueExprMapType::iterator I = ValueExprMap.find_as(V);
if (I != ValueExprMap.end()) {
const SCEV *S = I->second;
assert(checkValidity(S) &&
"existing SCEV has not been properly invalidated");
return S;
}
return nullptr;
}
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
SCEV::NoWrapFlags Flags) {
if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
Type *Ty = V->getType();
Ty = getEffectiveSCEVType(Ty);
return getMulExpr(V, getMinusOne(Ty), Flags);
}
static const SCEV *MatchNotExpr(const SCEV *Expr) {
const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
if (!Add || Add->getNumOperands() != 2 ||
!Add->getOperand(0)->isAllOnesValue())
return nullptr;
const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
if (!AddRHS || AddRHS->getNumOperands() != 2 ||
!AddRHS->getOperand(0)->isAllOnesValue())
return nullptr;
return AddRHS->getOperand(1);
}
const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
assert(!V->getType()->isPointerTy() && "Can't negate pointer");
if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
SmallVector<const SCEV *, 2> MatchedOperands;
for (const SCEV *Operand : MME->operands()) {
const SCEV *Matched = MatchNotExpr(Operand);
if (!Matched)
return (const SCEV *)nullptr;
MatchedOperands.push_back(Matched);
}
return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
MatchedOperands);
};
if (const SCEV *Replaced = MatchMinMaxNegation(MME))
return Replaced;
}
Type *Ty = V->getType();
Ty = getEffectiveSCEVType(Ty);
return getMinusSCEV(getMinusOne(Ty), V);
}
const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
assert(P->getType()->isPointerTy());
if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
SmallVector<const SCEV *> Ops{AddRec->operands()};
Ops[0] = removePointerBase(Ops[0]);
return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
}
if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
SmallVector<const SCEV *> Ops{Add->operands()};
const SCEV **PtrOp = nullptr;
for (const SCEV *&AddOp : Ops) {
if (AddOp->getType()->isPointerTy()) {
assert(!PtrOp && "Cannot have multiple pointer ops");
PtrOp = &AddOp;
}
}
*PtrOp = removePointerBase(*PtrOp);
return getAddExpr(Ops);
}
return getZero(P->getType());
}
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
SCEV::NoWrapFlags Flags,
unsigned Depth) {
if (LHS == RHS)
return getZero(LHS->getType());
if (RHS->getType()->isPointerTy()) {
if (!LHS->getType()->isPointerTy() ||
getPointerBase(LHS) != getPointerBase(RHS))
return getCouldNotCompute();
LHS = removePointerBase(LHS);
RHS = removePointerBase(RHS);
}
auto AddFlags = SCEV::FlagAnyWrap;
const bool RHSIsNotMinSigned =
!getSignedRangeMin(RHS).isMinSignedValue();
if (hasFlags(Flags, SCEV::FlagNSW)) {
if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
AddFlags = SCEV::FlagNSW;
}
}
auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
}
const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
unsigned Depth) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or zero extend with non-integer arguments!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
return getTruncateExpr(V, Ty, Depth);
return getZeroExtendExpr(V, Ty, Depth);
}
const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
unsigned Depth) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or zero extend with non-integer arguments!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
return getTruncateExpr(V, Ty, Depth);
return getSignExtendExpr(V, Ty, Depth);
}
const SCEV *
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or zero extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrZeroExtend cannot truncate!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; return getZeroExtendExpr(V, Ty);
}
const SCEV *
ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or sign extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrSignExtend cannot truncate!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; return getSignExtendExpr(V, Ty);
}
const SCEV *
ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or any extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrAnyExtend cannot truncate!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; return getAnyExtendExpr(V, Ty);
}
const SCEV *
ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or noop with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
"getTruncateOrNoop cannot extend!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; return getTruncateExpr(V, Ty);
}
const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
const SCEV *RHS) {
const SCEV *PromotedLHS = LHS;
const SCEV *PromotedRHS = RHS;
if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
else
PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
return getUMaxExpr(PromotedLHS, PromotedRHS);
}
const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
const SCEV *RHS,
bool Sequential) {
SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
return getUMinFromMismatchedTypes(Ops, Sequential);
}
const SCEV *
ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
bool Sequential) {
assert(!Ops.empty() && "At least one operand must be!");
if (Ops.size() == 1)
return Ops[0];
Type *MaxType = nullptr;
for (const auto *S : Ops)
if (MaxType)
MaxType = getWiderType(MaxType, S->getType());
else
MaxType = S->getType();
assert(MaxType && "Failed to find maximum type!");
SmallVector<const SCEV *, 2> PromotedOps;
for (const auto *S : Ops)
PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
return getUMinExpr(PromotedOps, Sequential);
}
const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
if (!V->getType()->isPointerTy())
return V;
while (true) {
if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
V = AddRec->getStart();
} else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
const SCEV *PtrOp = nullptr;
for (const SCEV *AddOp : Add->operands()) {
if (AddOp->getType()->isPointerTy()) {
assert(!PtrOp && "Cannot have multiple pointer ops");
PtrOp = AddOp;
}
}
assert(PtrOp && "Must have pointer op");
V = PtrOp;
} else return V;
}
}
static void PushDefUseChildren(Instruction *I,
SmallVectorImpl<Instruction *> &Worklist,
SmallPtrSetImpl<Instruction *> &Visited) {
for (User *U : I->users()) {
auto *UserInsn = cast<Instruction>(U);
if (Visited.insert(UserInsn).second)
Worklist.push_back(UserInsn);
}
}
namespace {
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
bool IgnoreOtherLoops = true) {
SCEVInitRewriter Rewriter(L, SE);
const SCEV *Result = Rewriter.visit(S);
if (Rewriter.hasSeenLoopVariantSCEVUnknown())
return SE.getCouldNotCompute();
return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
? SE.getCouldNotCompute()
: Result;
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (!SE.isLoopInvariant(Expr, L))
SeenLoopVariantSCEVUnknown = true;
return Expr;
}
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
if (Expr->getLoop() == L)
return Expr->getStart();
SeenOtherLoops = true;
return Expr;
}
bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
bool hasSeenOtherLoops() { return SeenOtherLoops; }
private:
explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
: SCEVRewriteVisitor(SE), L(L) {}
const Loop *L;
bool SeenLoopVariantSCEVUnknown = false;
bool SeenOtherLoops = false;
};
class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
SCEVPostIncRewriter Rewriter(L, SE);
const SCEV *Result = Rewriter.visit(S);
return Rewriter.hasSeenLoopVariantSCEVUnknown()
? SE.getCouldNotCompute()
: Result;
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (!SE.isLoopInvariant(Expr, L))
SeenLoopVariantSCEVUnknown = true;
return Expr;
}
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
if (Expr->getLoop() == L)
return Expr->getPostIncExpr(SE);
SeenOtherLoops = true;
return Expr;
}
bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
bool hasSeenOtherLoops() { return SeenOtherLoops; }
private:
explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
: SCEVRewriteVisitor(SE), L(L) {}
const Loop *L;
bool SeenLoopVariantSCEVUnknown = false;
bool SeenOtherLoops = false;
};
class SCEVBackedgeConditionFolder
: public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L,
ScalarEvolution &SE) {
bool IsPosBECond = false;
Value *BECond = nullptr;
if (BasicBlock *Latch = L->getLoopLatch()) {
BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
if (BI && BI->isConditional()) {
assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
"Both outgoing branches should not target same header!");
BECond = BI->getCondition();
IsPosBECond = BI->getSuccessor(0) == L->getHeader();
} else {
return S;
}
}
SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
return Rewriter.visit(S);
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
const SCEV *Result = Expr;
bool InvariantF = SE.isLoopInvariant(Expr, L);
if (!InvariantF) {
Instruction *I = cast<Instruction>(Expr->getValue());
switch (I->getOpcode()) {
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
Optional<const SCEV *> Res =
compareWithBackedgeCondition(SI->getCondition());
if (Res) {
bool IsOne = cast<SCEVConstant>(Res.value())->getValue()->isOne();
Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
}
break;
}
default: {
Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
if (Res)
Result = Res.value();
break;
}
}
}
return Result;
}
private:
explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
bool IsPosBECond, ScalarEvolution &SE)
: SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
IsPositiveBECond(IsPosBECond) {}
Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
const Loop *L;
Value *BackedgeCond = nullptr;
bool IsPositiveBECond;
};
Optional<const SCEV *>
SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
if (BackedgeCond == IC)
return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
: SE.getZero(Type::getInt1Ty(SE.getContext()));
return None;
}
class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L,
ScalarEvolution &SE) {
SCEVShiftRewriter Rewriter(L, SE);
const SCEV *Result = Rewriter.visit(S);
return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (!SE.isLoopInvariant(Expr, L))
Valid = false;
return Expr;
}
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
if (Expr->getLoop() == L && Expr->isAffine())
return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
Valid = false;
return Expr;
}
bool isValid() { return Valid; }
private:
explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
: SCEVRewriteVisitor(SE), L(L) {}
const Loop *L;
bool Valid = true;
};
}
SCEV::NoWrapFlags
ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
if (!AR->isAffine())
return SCEV::FlagAnyWrap;
using OBO = OverflowingBinaryOperator;
SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
if (!AR->hasNoSignedWrap()) {
ConstantRange AddRecRange = getSignedRange(AR);
ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
Instruction::Add, IncRange, OBO::NoSignedWrap);
if (NSWRegion.contains(AddRecRange))
Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
}
if (!AR->hasNoUnsignedWrap()) {
ConstantRange AddRecRange = getUnsignedRange(AR);
ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
Instruction::Add, IncRange, OBO::NoUnsignedWrap);
if (NUWRegion.contains(AddRecRange))
Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
}
return Result;
}
SCEV::NoWrapFlags
ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
if (AR->hasNoSignedWrap())
return Result;
if (!AR->isAffine())
return Result;
const SCEV *Step = AR->getStepRecurrence(*this);
const Loop *L = AR->getLoop();
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
AC.assumptions().empty())
return Result;
ICmpInst::Predicate Pred;
const SCEV *OverflowLimit =
getSignedOverflowLimitForStep(Step, &Pred, this);
if (OverflowLimit &&
(isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
Result = setFlags(Result, SCEV::FlagNSW);
}
return Result;
}
SCEV::NoWrapFlags
ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
if (AR->hasNoUnsignedWrap())
return Result;
if (!AR->isAffine())
return Result;
const SCEV *Step = AR->getStepRecurrence(*this);
unsigned BitWidth = getTypeSizeInBits(AR->getType());
const Loop *L = AR->getLoop();
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
AC.assumptions().empty())
return Result;
if (isKnownPositive(Step)) {
const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
getUnsignedRangeMax(Step));
if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
Result = setFlags(Result, SCEV::FlagNUW);
}
}
return Result;
}
namespace {
struct BinaryOp {
unsigned Opcode;
Value *LHS;
Value *RHS;
bool IsNSW = false;
bool IsNUW = false;
Operator *Op = nullptr;
explicit BinaryOp(Operator *Op)
: Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
Op(Op) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
IsNSW = OBO->hasNoSignedWrap();
IsNUW = OBO->hasNoUnsignedWrap();
}
}
explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
bool IsNUW = false)
: Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};
}
static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
auto *Op = dyn_cast<Operator>(V);
if (!Op)
return None;
switch (Op->getOpcode()) {
case Instruction::Add:
case Instruction::Sub:
case Instruction::Mul:
case Instruction::UDiv:
case Instruction::URem:
case Instruction::And:
case Instruction::Or:
case Instruction::AShr:
case Instruction::Shl:
return BinaryOp(Op);
case Instruction::Xor:
if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
if (RHSC->getValue().isSignMask())
return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
if (V->getType()->isIntegerTy(1))
return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
return BinaryOp(Op);
case Instruction::LShr:
if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
if (SA->getValue().ult(BitWidth)) {
Constant *X =
ConstantInt::get(SA->getContext(),
APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
}
}
return BinaryOp(Op);
case Instruction::ExtractValue: {
auto *EVI = cast<ExtractValueInst>(Op);
if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
break;
auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
if (!WO)
break;
Instruction::BinaryOps BinOp = WO->getBinaryOp();
bool Signed = WO->isSigned();
if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
Signed, !Signed);
}
default:
break;
}
if (auto *II = dyn_cast<IntrinsicInst>(V))
if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
return None;
}
static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
bool &Signed, ScalarEvolution &SE) {
if (Op == SymbolicPHI)
return nullptr;
unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
if (SourceBits != NewBits)
return nullptr;
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
if (!SExt && !ZExt)
return nullptr;
const SCEVTruncateExpr *Trunc =
SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
: dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
if (!Trunc)
return nullptr;
const SCEV *X = Trunc->getOperand();
if (X != SymbolicPHI)
return nullptr;
Signed = SExt != nullptr;
return Trunc->getType();
}
static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
if (!PN->getType()->isIntegerTy())
return nullptr;
const Loop *L = LI.getLoopFor(PN->getParent());
if (!L || L->getHeader() != PN->getParent())
return nullptr;
return L;
}
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
SmallVector<const SCEVPredicate *, 3> Predicates;
auto *PN = cast<PHINode>(SymbolicPHI->getValue());
const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
assert(L && "Expecting an integer loop header phi");
Value *BEValueV = nullptr, *StartValueV = nullptr;
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
Value *V = PN->getIncomingValue(i);
if (L->contains(PN->getIncomingBlock(i))) {
if (!BEValueV) {
BEValueV = V;
} else if (BEValueV != V) {
BEValueV = nullptr;
break;
}
} else if (!StartValueV) {
StartValueV = V;
} else if (StartValueV != V) {
StartValueV = nullptr;
break;
}
}
if (!BEValueV || !StartValueV)
return None;
const SCEV *BEValue = getSCEV(BEValueV);
const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
if (!Add)
return None;
unsigned FoundIndex = Add->getNumOperands();
Type *TruncTy = nullptr;
bool Signed;
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
if ((TruncTy =
isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
if (FoundIndex == e) {
FoundIndex = i;
break;
}
if (FoundIndex == Add->getNumOperands())
return None;
SmallVector<const SCEV *, 8> Ops;
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
if (i != FoundIndex)
Ops.push_back(Add->getOperand(i));
const SCEV *Accum = getAddExpr(Ops);
if (!isLoopInvariant(Accum, L))
return None;
const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV =
getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
Signed ? SCEVWrapPredicate::IncrementNSSW
: SCEVWrapPredicate::IncrementNUSW;
const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
Predicates.push_back(AddRecPred);
}
auto getExtendedExpr = [&](const SCEV *Expr,
bool CreateSignExtend) -> const SCEV * {
assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
const SCEV *ExtendedExpr =
CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
: getZeroExtendExpr(TruncatedExpr, Expr->getType());
return ExtendedExpr;
};
auto PredIsKnownFalse = [&](const SCEV *Expr,
const SCEV *ExtendedExpr) -> bool {
return Expr != ExtendedExpr &&
isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
};
const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
if (PredIsKnownFalse(StartVal, StartExtended)) {
LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
return None;
}
const SCEV *AccumExtended = getExtendedExpr(Accum, true);
if (PredIsKnownFalse(Accum, AccumExtended)) {
LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
return None;
}
auto AppendPredicate = [&](const SCEV *Expr,
const SCEV *ExtendedExpr) -> void {
if (Expr != ExtendedExpr &&
!isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
Predicates.push_back(Pred);
}
};
AppendPredicate(StartVal, StartExtended);
AppendPredicate(Accum, AccumExtended);
auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
std::make_pair(NewAR, Predicates);
PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
return PredRewrite;
}
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
auto *PN = cast<PHINode>(SymbolicPHI->getValue());
const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
if (!L)
return None;
auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
if (I != PredicatedSCEVRewrites.end()) {
std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
I->second;
if (Rewrite.first == SymbolicPHI)
return None;
assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
assert(!(Rewrite.second).empty() && "Expected to find Predicates");
return Rewrite;
}
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
if (!Rewrite) {
SmallVector<const SCEVPredicate *, 3> Predicates;
PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
return None;
}
return Rewrite;
}
bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
if (AR1 == AR2)
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
return false;
return true;
};
if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
!areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
return false;
return true;
}
const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
Value *BEValueV,
Value *StartValueV) {
const Loop *L = LI.getLoopFor(PN->getParent());
assert(L && L->getHeader() == PN->getParent());
assert(BEValueV && StartValueV);
auto BO = MatchBinaryOp(BEValueV, DT);
if (!BO)
return nullptr;
if (BO->Opcode != Instruction::Add)
return nullptr;
const SCEV *Accum = nullptr;
if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
Accum = getSCEV(BO->RHS);
else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
Accum = getSCEV(BO->LHS);
if (!Accum)
return nullptr;
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
if (BO->IsNUW)
Flags = setFlags(Flags, SCEV::FlagNUW);
if (BO->IsNSW)
Flags = setFlags(Flags, SCEV::FlagNSW);
const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
insertValueToMap(PN, PHISCEV);
if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
assert(isLoopInvariant(Accum, L) &&
"Accum is defined outside L, but is not invariant?");
if (isAddRecNeverPoison(BEInst, L))
(void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
}
return PHISCEV;
}
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const Loop *L = LI.getLoopFor(PN->getParent());
if (!L || L->getHeader() != PN->getParent())
return nullptr;
Value *BEValueV = nullptr, *StartValueV = nullptr;
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
Value *V = PN->getIncomingValue(i);
if (L->contains(PN->getIncomingBlock(i))) {
if (!BEValueV) {
BEValueV = V;
} else if (BEValueV != V) {
BEValueV = nullptr;
break;
}
} else if (!StartValueV) {
StartValueV = V;
} else if (StartValueV != V) {
StartValueV = nullptr;
break;
}
}
if (!BEValueV || !StartValueV)
return nullptr;
assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
"PHI node already processed?");
if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
return S;
const SCEV *SymbolicName = getUnknown(PN);
insertValueToMap(PN, SymbolicName);
const SCEV *BEValue = getSCEV(BEValueV);
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
unsigned FoundIndex = Add->getNumOperands();
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
if (Add->getOperand(i) == SymbolicName)
if (FoundIndex == e) {
FoundIndex = i;
break;
}
if (FoundIndex != Add->getNumOperands()) {
SmallVector<const SCEV *, 8> Ops;
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
if (i != FoundIndex)
Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
L, *this));
const SCEV *Accum = getAddExpr(Ops);
if (isLoopInvariant(Accum, L) ||
(isa<SCEVAddRecExpr>(Accum) &&
cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
if (auto BO = MatchBinaryOp(BEValueV, DT)) {
if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
if (BO->IsNUW)
Flags = setFlags(Flags, SCEV::FlagNUW);
if (BO->IsNSW)
Flags = setFlags(Flags, SCEV::FlagNSW);
}
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
Flags = setFlags(Flags, SCEV::FlagNW);
const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
Flags = setFlags(Flags, SCEV::FlagNUW);
}
}
const SCEV *StartVal = getSCEV(StartValueV);
const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
forgetMemoizedResults(SymbolicName);
insertValueToMap(PN, PHISCEV);
if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
(void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
return PHISCEV;
}
}
} else {
const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
if (Shifted != getCouldNotCompute() &&
Start != getCouldNotCompute()) {
const SCEV *StartVal = getSCEV(StartValueV);
if (Start == StartVal) {
forgetMemoizedResults(SymbolicName);
insertValueToMap(PN, Shifted);
return Shifted;
}
}
}
eraseValueFromMap(PN);
return nullptr;
}
static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
BasicBlock *BB) {
struct CheckAvailable {
bool TraversalDone = false;
bool Available = true;
const Loop *L = nullptr; BasicBlock *BB = nullptr;
DominatorTree &DT;
CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
: L(L), BB(BB), DT(DT) {}
bool setUnavailable() {
TraversalDone = true;
Available = false;
return false;
}
bool follow(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
case scSequentialUMinExpr:
return true;
case scAddRecExpr: {
const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
if (L && (ARLoop == L || ARLoop->contains(L)))
return true;
return setUnavailable();
}
case scUnknown: {
const auto *SU = cast<SCEVUnknown>(S);
Value *V = SU->getValue();
if (isa<Argument>(V))
return false;
if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
return false;
return setUnavailable();
}
case scUDivExpr:
case scCouldNotCompute:
return setUnavailable();
}
llvm_unreachable("Unknown SCEV kind!");
}
bool isDone() { return TraversalDone; }
};
CheckAvailable CA(L, BB, DT);
SCEVTraversal<CheckAvailable> ST(CA);
ST.visitAll(S);
return CA.Available;
}
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
Value *&C, Value *&LHS, Value *&RHS) {
C = BI->getCondition();
BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
if (!LeftEdge.isSingleEdge())
return false;
assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
Use &LeftUse = Merge->getOperandUse(0);
Use &RightUse = Merge->getOperandUse(1);
if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
LHS = LeftUse;
RHS = RightUse;
return true;
}
if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
LHS = RightUse;
RHS = LeftUse;
return true;
}
return false;
}
const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
auto IsReachable =
[&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
const Loop *L = LI.getLoopFor(PN->getParent());
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
return nullptr;
BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
assert(IDom && "At least the entry block should dominate PN");
auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
if (BI && BI->isConditional() &&
BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
}
return nullptr;
}
const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
if (const SCEV *S = createAddRecFromPHI(PN))
return S;
if (const SCEV *S = createNodeFromSelectLikePHI(PN))
return S;
if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
if (LI.replacementPreservesLCSSAForm(PN, V))
return getSCEV(V);
return getUnknown(PN);
}
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
SCEVTypes RootKind) {
struct FindClosure {
const SCEV *OperandToFind;
const SCEVTypes RootKind; const SCEVTypes NonSequentialRootKind;
bool Found = false;
bool canRecurseInto(SCEVTypes Kind) const {
return RootKind == Kind || NonSequentialRootKind == Kind ||
scZeroExtend == Kind;
};
FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
: OperandToFind(OperandToFind), RootKind(RootKind),
NonSequentialRootKind(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
RootKind)) {}
bool follow(const SCEV *S) {
Found = S == OperandToFind;
return !isDone() && canRecurseInto(S->getSCEVType());
}
bool isDone() const { return Found; }
};
FindClosure FC(OperandToFind, RootKind);
visitAll(Root, FC);
return FC.Found;
}
const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(
Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) {
auto *ICI = Cond;
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
switch (ICI->getPredicate()) {
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
bool Signed = ICI->isSigned();
const SCEV *LA = getSCEV(TrueVal);
const SCEV *RA = getSCEV(FalseVal);
const SCEV *LS = getSCEV(LHS);
const SCEV *RS = getSCEV(RHS);
if (LA->getType()->isPointerTy()) {
if (LA == LS && RA == RS)
return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
if (LA == RS && RA == LS)
return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
}
auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
if (Op->getType()->isPointerTy()) {
Op = getLosslessPtrToIntExpr(Op);
if (isa<SCEVCouldNotCompute>(Op))
return Op;
}
if (Signed)
Op = getNoopOrSignExtend(Op, I->getType());
else
Op = getNoopOrZeroExtend(Op, I->getType());
return Op;
};
LS = CoerceOperand(LS);
RS = CoerceOperand(RS);
if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
break;
const SCEV *LDiff = getMinusSCEV(LA, LS);
const SCEV *RDiff = getMinusSCEV(RA, RS);
if (LDiff == RDiff)
return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
LDiff);
LDiff = getMinusSCEV(LA, RS);
RDiff = getMinusSCEV(RA, LS);
if (LDiff == RDiff)
return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
LDiff);
}
break;
case ICmpInst::ICMP_NE:
std::swap(TrueVal, FalseVal);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_EQ:
if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
const SCEV *TrueValExpr = getSCEV(TrueVal); const SCEV *FalseValExpr = getSCEV(FalseVal); const SCEV *Y = getMinusSCEV(FalseValExpr, X); const SCEV *C = getMinusSCEV(TrueValExpr, Y); if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
return getAddExpr(getUMaxExpr(X, C), Y);
}
if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
const SCEV *X = getSCEV(LHS);
while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
X = ZExt->getOperand();
if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(I->getType())) {
const SCEV *FalseValExpr = getSCEV(FalseVal);
if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
return getUMinExpr(getNoopOrZeroExtend(X, I->getType()), FalseValExpr,
true);
}
}
break;
default:
break;
}
return getUnknown(I);
}
static Optional<const SCEV *>
createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
const SCEV *TrueExpr, const SCEV *FalseExpr) {
assert(CondExpr->getType()->isIntegerTy(1) &&
TrueExpr->getType() == FalseExpr->getType() &&
TrueExpr->getType()->isIntegerTy(1) &&
"Unexpected operands of a select.");
if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
return None;
const SCEV *X, *C;
if (isa<SCEVConstant>(TrueExpr)) {
CondExpr = SE->getNotSCEV(CondExpr);
X = FalseExpr;
C = TrueExpr;
} else {
X = TrueExpr;
C = FalseExpr;
}
return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
true));
}
static Optional<const SCEV *> createNodeForSelectViaUMinSeq(ScalarEvolution *SE,
Value *Cond,
Value *TrueVal,
Value *FalseVal) {
if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
return None;
const auto *SECond = SE->getSCEV(Cond);
const auto *SETrue = SE->getSCEV(TrueVal);
const auto *SEFalse = SE->getSCEV(FalseVal);
return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
}
const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
assert(TrueVal->getType() == FalseVal->getType() &&
V->getType() == TrueVal->getType() &&
"Types of select hands and of the result must match.");
if (!V->getType()->isIntegerTy(1))
return getUnknown(V);
if (Optional<const SCEV *> S =
createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
return *S;
return getUnknown(V);
}
const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
Value *TrueVal,
Value *FalseVal) {
if (auto *CI = dyn_cast<ConstantInt>(Cond))
return getSCEV(CI->isOne() ? TrueVal : FalseVal);
if (auto *I = dyn_cast<Instruction>(V)) {
if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
const SCEV *S = createNodeForSelectOrPHIInstWithICmpInstCond(
I, ICI, TrueVal, FalseVal);
if (!isa<SCEVUnknown>(S))
return S;
}
}
return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
}
const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
assert(GEP->getSourceElementType()->isSized() &&
"GEP source element type must be sized");
SmallVector<const SCEV *, 4> IndexExprs;
for (Value *Index : GEP->indices())
IndexExprs.push_back(getSCEV(Index));
return getGEPExpr(GEP, IndexExprs);
}
uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
return C->getAPInt().countTrailingZeros();
if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
return GetMinTrailingZeros(I->getOperand());
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
return std::min(GetMinTrailingZeros(T->getOperand()),
(uint32_t)getTypeSizeInBits(T->getType()));
if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
return OpRes == getTypeSizeInBits(E->getOperand()->getType())
? getTypeSizeInBits(E->getType())
: OpRes;
}
if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
return OpRes == getTypeSizeInBits(E->getOperand()->getType())
? getTypeSizeInBits(E->getType())
: OpRes;
}
if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
return MinOpRes;
}
if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
uint32_t BitWidth = getTypeSizeInBits(M->getType());
for (unsigned i = 1, e = M->getNumOperands();
SumOpRes != BitWidth && i != e; ++i)
SumOpRes =
std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
return SumOpRes;
}
if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
return MinOpRes;
}
if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
return MinOpRes;
}
if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
return MinOpRes;
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
return Known.countMinTrailingZeros();
}
return 0;
}
uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
auto I = MinTrailingZerosCache.find(S);
if (I != MinTrailingZerosCache.end())
return I->second;
uint32_t Result = GetMinTrailingZerosImpl(S);
auto InsertPair = MinTrailingZerosCache.insert({S, Result});
assert(InsertPair.second && "Should insert a new key");
return InsertPair.first->second;
}
static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
if (Instruction *I = dyn_cast<Instruction>(V))
if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
return getConstantRangeFromMetadata(*MD);
return None;
}
void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
SCEV::NoWrapFlags Flags) {
if (AddRec->getNoWrapFlags(Flags) != Flags) {
AddRec->setNoWrapFlags(Flags);
UnsignedRanges.erase(AddRec);
SignedRanges.erase(AddRec);
}
}
ConstantRange ScalarEvolution::
getRangeForUnknownRecurrence(const SCEVUnknown *U) {
const DataLayout &DL = getDataLayout();
unsigned BitWidth = getTypeSizeInBits(U->getType());
const ConstantRange FullSet(BitWidth, true);
auto *P = dyn_cast<PHINode>(U->getValue());
if (!P)
return FullSet;
for (auto *Pred : predecessors(P->getParent()))
if (!DT.isReachableFromEntry(Pred))
return FullSet;
BinaryOperator *BO;
Value *Start, *Step;
if (!matchSimpleRecurrence(P, BO, Start, Step))
return FullSet;
auto *L = LI.getLoopFor(P->getParent());
assert(L && L->getHeader() == P->getParent());
if (!L->contains(BO->getParent()))
return FullSet;
switch (BO->getOpcode()) {
default:
return FullSet;
case Instruction::AShr:
case Instruction::LShr:
case Instruction::Shl:
break;
};
if (BO->getOperand(0) != P)
return FullSet;
unsigned TC = getSmallConstantMaxTripCount(L);
if (!TC || TC >= BitWidth)
return FullSet;
auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
assert(KnownStart.getBitWidth() == BitWidth &&
KnownStep.getBitWidth() == BitWidth);
auto MaxShiftAmt = KnownStep.getMaxValue();
APInt TCAP(BitWidth, TC-1);
bool Overflow = false;
auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
if (Overflow)
return FullSet;
switch (BO->getOpcode()) {
default:
llvm_unreachable("filtered out above");
case Instruction::AShr: {
auto KnownEnd = KnownBits::ashr(KnownStart,
KnownBits::makeConstant(TotalShift));
if (KnownStart.isNonNegative())
return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
KnownStart.getMaxValue() + 1);
if (KnownStart.isNegative())
return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
KnownEnd.getMaxValue() + 1);
break;
}
case Instruction::LShr: {
auto KnownEnd = KnownBits::lshr(KnownStart,
KnownBits::makeConstant(TotalShift));
return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
KnownStart.getMaxValue() + 1);
}
case Instruction::Shl: {
auto KnownEnd = KnownBits::shl(KnownStart,
KnownBits::makeConstant(TotalShift));
if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
return ConstantRange(KnownStart.getMinValue(),
KnownEnd.getMaxValue() + 1);
break;
}
};
return FullSet;
}
const ConstantRange &
ScalarEvolution::getRangeRef(const SCEV *S,
ScalarEvolution::RangeSignHint SignHint) {
DenseMap<const SCEV *, ConstantRange> &Cache =
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
: SignedRanges;
ConstantRange::PreferredRangeType RangeType =
SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
? ConstantRange::Unsigned : ConstantRange::Signed;
DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
if (I != Cache.end())
return I->second;
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
return setRange(C, SignHint, ConstantRange(C->getAPInt()));
unsigned BitWidth = getTypeSizeInBits(S->getType());
ConstantRange ConservativeResult(BitWidth, true);
using OBO = OverflowingBinaryOperator;
uint32_t TZ = GetMinTrailingZeros(S);
if (TZ != 0) {
if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
ConservativeResult =
ConstantRange(APInt::getMinValue(BitWidth),
APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
else
ConservativeResult = ConstantRange(
APInt::getSignedMinValue(BitWidth),
APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
}
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
unsigned WrapType = OBO::AnyWrap;
if (Add->hasNoSignedWrap())
WrapType |= OBO::NoSignedWrap;
if (Add->hasNoUnsignedWrap())
WrapType |= OBO::NoUnsignedWrap;
for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
WrapType, RangeType);
return setRange(Add, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
return setRange(Mul, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}
if (isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) {
Intrinsic::ID ID;
switch (S->getSCEVType()) {
case scUMaxExpr:
ID = Intrinsic::umax;
break;
case scSMaxExpr:
ID = Intrinsic::smax;
break;
case scUMinExpr:
case scSequentialUMinExpr:
ID = Intrinsic::umin;
break;
case scSMinExpr:
ID = Intrinsic::smin;
break;
default:
llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
}
const auto *NAry = cast<SCEVNAryExpr>(S);
ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
return setRange(S, SignHint,
ConservativeResult.intersectWith(X, RangeType));
}
if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
return setRange(UDiv, SignHint,
ConservativeResult.intersectWith(X.udiv(Y), RangeType));
}
if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
return setRange(ZExt, SignHint,
ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
RangeType));
}
if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
return setRange(SExt, SignHint,
ConservativeResult.intersectWith(X.signExtend(BitWidth),
RangeType));
}
if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
return setRange(PtrToInt, SignHint, X);
}
if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
return setRange(Trunc, SignHint,
ConservativeResult.intersectWith(X.truncate(BitWidth),
RangeType));
}
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
if (AddRec->hasNoUnsignedWrap()) {
APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
if (!UnsignedMinValue.isZero())
ConservativeResult = ConservativeResult.intersectWith(
ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
}
if (AddRec->hasNoSignedWrap()) {
bool AllNonNeg = true;
bool AllNonPos = true;
for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
if (!isKnownNonNegative(AddRec->getOperand(i)))
AllNonNeg = false;
if (!isKnownNonPositive(AddRec->getOperand(i)))
AllNonPos = false;
}
if (AllNonNeg)
ConservativeResult = ConservativeResult.intersectWith(
ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
APInt::getSignedMinValue(BitWidth)),
RangeType);
else if (AllNonPos)
ConservativeResult = ConservativeResult.intersectWith(
ConstantRange::getNonEmpty(
APInt::getSignedMinValue(BitWidth),
getSignedRangeMax(AddRec->getStart()) + 1),
RangeType);
}
if (AddRec->isAffine()) {
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
auto RangeFromAffine = getRangeForAffineAR(
AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
BitWidth);
ConservativeResult =
ConservativeResult.intersectWith(RangeFromAffine, RangeType);
auto RangeFromFactoring = getRangeViaFactoring(
AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
BitWidth);
ConservativeResult =
ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
}
if (UseExpensiveRangeSharpening) {
const SCEV *SymbolicMaxBECount =
getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
AddRec->hasNoSelfWrap()) {
auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
AddRec, SymbolicMaxBECount, BitWidth, SignHint);
ConservativeResult =
ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
}
}
}
return setRange(AddRec, SignHint, std::move(ConservativeResult));
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
if (MDRange)
ConservativeResult =
ConservativeResult.intersectWith(MDRange.value(), RangeType);
auto CR = getRangeForUnknownRecurrence(U);
ConservativeResult = ConservativeResult.intersectWith(CR);
const DataLayout &DL = getDataLayout();
KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
if (Known.getBitWidth() != BitWidth)
Known = Known.zextOrTrunc(BitWidth);
unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
if (U->getType()->isPointerTy()) {
unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
int ptrIdxDiff = ptrSize - BitWidth;
if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
NS -= ptrIdxDiff;
}
if (NS > 1) {
if (!Known.Zero.getHiBits(NS).isZero())
Known.Zero.setHighBits(NS);
if (!Known.One.getHiBits(NS).isZero())
Known.One.setHighBits(NS);
}
if (Known.getMinValue() != Known.getMaxValue() + 1)
ConservativeResult = ConservativeResult.intersectWith(
ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
RangeType);
if (NS > 1)
ConservativeResult = ConservativeResult.intersectWith(
ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
RangeType);
if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
if (PendingPhiRanges.insert(Phi).second) {
ConstantRange RangeFromOps(BitWidth, false);
for (const auto &Op : Phi->operands()) {
auto OpRange = getRangeRef(getSCEV(Op), SignHint);
RangeFromOps = RangeFromOps.unionWith(OpRange);
if (RangeFromOps.isFullSet())
break;
}
ConservativeResult =
ConservativeResult.intersectWith(RangeFromOps, RangeType);
bool Erased = PendingPhiRanges.erase(Phi);
assert(Erased && "Failed to erase Phi properly?");
(void) Erased;
}
}
if (const auto *II = dyn_cast<IntrinsicInst>(U->getValue()))
if (II->getIntrinsicID() == Intrinsic::vscale) {
ConstantRange Disallowed = APInt::getZero(BitWidth);
ConservativeResult = ConservativeResult.difference(Disallowed);
}
return setRange(U, SignHint, std::move(ConservativeResult));
}
return setRange(S, SignHint, std::move(ConservativeResult));
}
static ConstantRange getRangeForAffineARHelper(APInt Step,
const ConstantRange &StartRange,
const APInt &MaxBECount,
unsigned BitWidth, bool Signed) {
if (Step == 0 || MaxBECount == 0)
return StartRange;
if (StartRange.isFullSet())
return ConstantRange::getFull(BitWidth);
bool Descending = Signed && Step.isNegative();
if (Signed)
Step = Step.abs();
if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
return ConstantRange::getFull(BitWidth);
APInt Offset = Step * MaxBECount;
APInt StartLower = StartRange.getLower();
APInt StartUpper = StartRange.getUpper() - 1;
APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
: (StartUpper + std::move(Offset));
if (StartRange.contains(MovedBoundary))
return ConstantRange::getFull(BitWidth);
APInt NewLower =
Descending ? std::move(MovedBoundary) : std::move(StartLower);
APInt NewUpper =
Descending ? std::move(StartUpper) : std::move(MovedBoundary);
NewUpper += 1;
return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
}
ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
const SCEV *Step,
const SCEV *MaxBECount,
unsigned BitWidth) {
assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
"Precondition!");
MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
ConstantRange StartSRange = getSignedRange(Start);
ConstantRange StepSRange = getSignedRange(Step);
ConstantRange SR =
getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
MaxBECountValue, BitWidth, true);
SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
StartSRange, MaxBECountValue,
BitWidth, true));
ConstantRange UR = getRangeForAffineARHelper(
getUnsignedRangeMax(Step), getUnsignedRange(Start),
MaxBECountValue, BitWidth, false);
return SR.intersectWith(UR, ConstantRange::Smallest);
}
ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
ScalarEvolution::RangeSignHint SignHint) {
assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
assert(AddRec->hasNoSelfWrap() &&
"This only works for non-self-wrapping AddRecs!");
const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
const SCEV *Step = AddRec->getStepRecurrence(*this);
if (!isa<SCEVConstant>(Step))
return ConstantRange::getFull(BitWidth);
if (getTypeSizeInBits(MaxBECount->getType()) >
getTypeSizeInBits(AddRec->getType()))
return ConstantRange::getFull(BitWidth);
MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
const SCEV *RangeWidth = getMinusOne(AddRec->getType());
const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
MaxItersWithoutWrap))
return ConstantRange::getFull(BitWidth);
ICmpInst::Predicate LEPred =
IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
ICmpInst::Predicate GEPred =
IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
const SCEV *Start = AddRec->getStart();
ConstantRange StartRange = getRangeRef(Start, SignHint);
ConstantRange EndRange = getRangeRef(End, SignHint);
ConstantRange RangeBetween = StartRange.unionWith(EndRange);
if (RangeBetween.isFullSet())
return RangeBetween;
bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
: RangeBetween.isWrappedSet();
if (IsWrappedSet)
return ConstantRange::getFull(BitWidth);
if (isKnownPositive(Step) &&
isKnownPredicateViaConstantRanges(LEPred, Start, End))
return RangeBetween;
else if (isKnownNegative(Step) &&
isKnownPredicateViaConstantRanges(GEPred, Start, End))
return RangeBetween;
return ConstantRange::getFull(BitWidth);
}
ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
const SCEV *Step,
const SCEV *MaxBECount,
unsigned BitWidth) {
struct SelectPattern {
Value *Condition = nullptr;
APInt TrueValue;
APInt FalseValue;
explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
const SCEV *S) {
Optional<unsigned> CastOp;
APInt Offset(BitWidth, 0);
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
"Should be!");
if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
return;
Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
S = SA->getOperand(1);
}
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
CastOp = SCast->getSCEVType();
S = SCast->getOperand();
}
using namespace llvm::PatternMatch;
auto *SU = dyn_cast<SCEVUnknown>(S);
const APInt *TrueVal, *FalseVal;
if (!SU ||
!match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
m_APInt(FalseVal)))) {
Condition = nullptr;
return;
}
TrueValue = *TrueVal;
FalseValue = *FalseVal;
if (CastOp)
switch (*CastOp) {
default:
llvm_unreachable("Unknown SCEV cast type!");
case scTruncate:
TrueValue = TrueValue.trunc(BitWidth);
FalseValue = FalseValue.trunc(BitWidth);
break;
case scZeroExtend:
TrueValue = TrueValue.zext(BitWidth);
FalseValue = FalseValue.zext(BitWidth);
break;
case scSignExtend:
TrueValue = TrueValue.sext(BitWidth);
FalseValue = FalseValue.sext(BitWidth);
break;
}
TrueValue += Offset;
FalseValue += Offset;
}
bool isRecognized() { return Condition != nullptr; }
};
SelectPattern StartPattern(*this, BitWidth, Start);
if (!StartPattern.isRecognized())
return ConstantRange::getFull(BitWidth);
SelectPattern StepPattern(*this, BitWidth, Step);
if (!StepPattern.isRecognized())
return ConstantRange::getFull(BitWidth);
if (StartPattern.Condition != StepPattern.Condition) {
return ConstantRange::getFull(BitWidth);
}
const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
ConstantRange TrueRange =
this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
ConstantRange FalseRange =
this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
return TrueRange.unionWith(FalseRange);
}
SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
const BinaryOperator *BinOp = cast<BinaryOperator>(V);
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
if (BinOp->hasNoUnsignedWrap())
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
if (BinOp->hasNoSignedWrap())
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
if (Flags == SCEV::FlagAnyWrap)
return SCEV::FlagAnyWrap;
return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
}
const Instruction *
ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
return &*AddRec->getLoop()->getHeader()->begin();
if (auto *U = dyn_cast<SCEVUnknown>(S))
if (auto *I = dyn_cast<Instruction>(U->getValue()))
return I;
return nullptr;
}
static void collectUniqueOps(const SCEV *S,
SmallVectorImpl<const SCEV *> &Ops) {
SmallPtrSet<const SCEV *, 4> Unique;
auto InsertUnique = [&](const SCEV *S) {
if (Unique.insert(S).second)
Ops.push_back(S);
};
if (auto *S2 = dyn_cast<SCEVCastExpr>(S))
for (const auto *Op : S2->operands())
InsertUnique(Op);
else if (auto *S2 = dyn_cast<SCEVNAryExpr>(S))
for (const auto *Op : S2->operands())
InsertUnique(Op);
else if (auto *S2 = dyn_cast<SCEVUDivExpr>(S))
for (const auto *Op : S2->operands())
InsertUnique(Op);
}
const Instruction *
ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
bool &Precise) {
Precise = true;
SmallSet<const SCEV *, 16> Visited;
SmallVector<const SCEV *> Worklist;
auto pushOp = [&](const SCEV *S) {
if (!Visited.insert(S).second)
return;
if (Visited.size() > 30) {
Precise = false;
return;
}
Worklist.push_back(S);
};
for (const auto *S : Ops)
pushOp(S);
const Instruction *Bound = nullptr;
while (!Worklist.empty()) {
auto *S = Worklist.pop_back_val();
if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
if (!Bound || DT.dominates(Bound, DefI))
Bound = DefI;
} else {
SmallVector<const SCEV *, 4> Ops;
collectUniqueOps(S, Ops);
for (const auto *Op : Ops)
pushOp(Op);
}
}
return Bound ? Bound : &*F.getEntryBlock().begin();
}
const Instruction *
ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
bool Discard;
return getDefiningScopeBound(Ops, Discard);
}
bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
const Instruction *B) {
if (A->getParent() == B->getParent() &&
isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
B->getIterator()))
return true;
auto *BLoop = LI.getLoopFor(B->getParent());
if (BLoop && BLoop->getHeader() == B->getParent() &&
BLoop->getLoopPreheader() == A->getParent() &&
isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
A->getParent()->end()) &&
isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
B->getIterator()))
return true;
return false;
}
bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
if (!programUndefinedIfPoison(I))
return false;
SmallVector<const SCEV *> SCEVOps;
for (const Use &Op : I->operands()) {
if (isSCEVable(Op->getType()))
SCEVOps.push_back(getSCEV(Op));
}
auto *DefI = getDefiningScopeBound(SCEVOps);
return isGuaranteedToTransferExecutionTo(DefI, I);
}
bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
if (isSCEVExprNeverPoison(I))
return true;
auto *ExitingBB = L->getExitingBlock();
auto *LatchBB = L->getLoopLatch();
if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
return false;
SmallPtrSet<const Instruction *, 16> Pushed;
SmallVector<const Instruction *, 8> PoisonStack;
Pushed.insert(I);
PoisonStack.push_back(I);
bool LatchControlDependentOnPoison = false;
while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
const Instruction *Poison = PoisonStack.pop_back_val();
for (const auto *PoisonUser : Poison->users()) {
if (propagatesPoison(cast<Operator>(PoisonUser))) {
if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
PoisonStack.push_back(cast<Instruction>(PoisonUser));
} else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
assert(BI->isConditional() && "Only possibility!");
if (BI->getParent() == LatchBB) {
LatchControlDependentOnPoison = true;
break;
}
}
}
}
return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
}
ScalarEvolution::LoopProperties
ScalarEvolution::getLoopProperties(const Loop *L) {
using LoopProperties = ScalarEvolution::LoopProperties;
auto Itr = LoopPropertiesCache.find(L);
if (Itr == LoopPropertiesCache.end()) {
auto HasSideEffects = [](Instruction *I) {
if (auto *SI = dyn_cast<StoreInst>(I))
return !SI->isSimple();
return I->mayThrow() || I->mayWriteToMemory();
};
LoopProperties LP = { true,
true};
for (auto *BB : L->getBlocks())
for (auto &I : *BB) {
if (!isGuaranteedToTransferExecutionToSuccessor(&I))
LP.HasNoAbnormalExits = false;
if (HasSideEffects(&I))
LP.HasNoSideEffects = false;
if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
break; }
auto InsertPair = LoopPropertiesCache.insert({L, LP});
assert(InsertPair.second && "We just checked!");
Itr = InsertPair.first;
}
return Itr->second;
}
bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
}
const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
using PointerTy = PointerIntPair<Value *, 1, bool>;
SmallVector<PointerTy> Stack;
Stack.emplace_back(V, true);
Stack.emplace_back(V, false);
while (!Stack.empty()) {
auto E = Stack.pop_back_val();
Value *CurV = E.getPointer();
if (getExistingSCEV(CurV))
continue;
SmallVector<Value *> Ops;
const SCEV *CreatedSCEV = nullptr;
if (E.getInt()) {
CreatedSCEV = createSCEV(CurV);
} else {
CreatedSCEV = getOperandsToCreate(CurV, Ops);
}
if (CreatedSCEV) {
insertValueToMap(CurV, CreatedSCEV);
} else {
Stack.emplace_back(CurV, true);
for (Value *Op : Ops)
Stack.emplace_back(Op, false);
}
}
return getExistingSCEV(V);
}
const SCEV *
ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
if (!isSCEVable(V->getType()))
return getUnknown(V);
if (Instruction *I = dyn_cast<Instruction>(V)) {
if (!DT.isReachableFromEntry(I->getParent()))
return getUnknown(PoisonValue::get(V->getType()));
} else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI);
else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
if (!GA->isInterposable()) {
Ops.push_back(GA->getAliasee());
return nullptr;
}
return getUnknown(V);
} else if (!isa<ConstantExpr>(V))
return getUnknown(V);
Operator *U = cast<Operator>(V);
if (auto BO = MatchBinaryOp(U, DT)) {
bool IsConstArg = isa<ConstantInt>(BO->RHS);
switch (BO->Opcode) {
case Instruction::Add: {
do {
if (BO->Op) {
if (BO->Op != V && getExistingSCEV(BO->Op)) {
Ops.push_back(BO->Op);
break;
}
}
Ops.push_back(BO->RHS);
auto NewBO = MatchBinaryOp(BO->LHS, DT);
if (!NewBO || (NewBO->Opcode != Instruction::Add &&
NewBO->Opcode != Instruction::Sub)) {
Ops.push_back(BO->LHS);
break;
}
BO = NewBO;
} while (true);
return nullptr;
}
case Instruction::Mul: {
do {
if (BO->Op) {
if (BO->Op != V && getExistingSCEV(BO->Op)) {
Ops.push_back(BO->Op);
break;
}
}
Ops.push_back(BO->RHS);
auto NewBO = MatchBinaryOp(BO->LHS, DT);
if (!NewBO || NewBO->Opcode != Instruction::Mul) {
Ops.push_back(BO->LHS);
break;
}
BO = NewBO;
} while (true);
return nullptr;
}
case Instruction::Sub:
case Instruction::UDiv:
case Instruction::URem:
break;
case Instruction::AShr:
case Instruction::Shl:
case Instruction::Xor:
if (!IsConstArg)
return nullptr;
break;
case Instruction::And:
case Instruction::Or:
if (!IsConstArg && BO->LHS->getType()->isIntegerTy(1))
return nullptr;
break;
case Instruction::LShr:
return getUnknown(V);
default:
llvm_unreachable("Unhandled binop");
break;
}
Ops.push_back(BO->LHS);
Ops.push_back(BO->RHS);
return nullptr;
}
switch (U->getOpcode()) {
case Instruction::Trunc:
case Instruction::ZExt:
case Instruction::SExt:
case Instruction::PtrToInt:
Ops.push_back(U->getOperand(0));
return nullptr;
case Instruction::BitCast:
if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
Ops.push_back(U->getOperand(0));
return nullptr;
}
return getUnknown(V);
case Instruction::SDiv:
case Instruction::SRem:
Ops.push_back(U->getOperand(0));
Ops.push_back(U->getOperand(1));
return nullptr;
case Instruction::GetElementPtr:
assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
"GEP source element type must be sized");
for (Value *Index : U->operands())
Ops.push_back(Index);
return nullptr;
case Instruction::IntToPtr:
return getUnknown(V);
case Instruction::PHI:
return nullptr;
case Instruction::Select: {
auto CanSimplifyToUnknown = [this, U]() {
if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
return false;
auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
if (!ICI)
return false;
Value *LHS = ICI->getOperand(0);
Value *RHS = ICI->getOperand(1);
if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
ICI->getPredicate() == CmpInst::ICMP_NE) {
if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
return true;
} else if (getTypeSizeInBits(LHS->getType()) >
getTypeSizeInBits(U->getType()))
return true;
return false;
};
if (CanSimplifyToUnknown())
return getUnknown(U);
for (Value *Inc : U->operands())
Ops.push_back(Inc);
return nullptr;
break;
}
case Instruction::Call:
case Instruction::Invoke:
if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
Ops.push_back(RV);
return nullptr;
}
if (auto *II = dyn_cast<IntrinsicInst>(U)) {
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
Ops.push_back(II->getArgOperand(0));
return nullptr;
case Intrinsic::umax:
case Intrinsic::umin:
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::usub_sat:
case Intrinsic::uadd_sat:
Ops.push_back(II->getArgOperand(0));
Ops.push_back(II->getArgOperand(1));
return nullptr;
case Intrinsic::start_loop_iterations:
case Intrinsic::annotation:
case Intrinsic::ptr_annotation:
Ops.push_back(II->getArgOperand(0));
return nullptr;
default:
break;
}
}
break;
}
return nullptr;
}
const SCEV *ScalarEvolution::createSCEV(Value *V) {
if (!isSCEVable(V->getType()))
return getUnknown(V);
if (Instruction *I = dyn_cast<Instruction>(V)) {
if (!DT.isReachableFromEntry(I->getParent()))
return getUnknown(PoisonValue::get(V->getType()));
} else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI);
else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
else if (!isa<ConstantExpr>(V))
return getUnknown(V);
const SCEV *LHS;
const SCEV *RHS;
Operator *U = cast<Operator>(V);
if (auto BO = MatchBinaryOp(U, DT)) {
switch (BO->Opcode) {
case Instruction::Add: {
SmallVector<const SCEV *, 4> AddOps;
do {
if (BO->Op) {
if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
AddOps.push_back(OpSCEV);
break;
}
const SCEV *RHS = getSCEV(BO->RHS);
SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
if (Flags != SCEV::FlagAnyWrap) {
const SCEV *LHS = getSCEV(BO->LHS);
if (BO->Opcode == Instruction::Sub)
AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
else
AddOps.push_back(getAddExpr(LHS, RHS, Flags));
break;
}
}
if (BO->Opcode == Instruction::Sub)
AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
else
AddOps.push_back(getSCEV(BO->RHS));
auto NewBO = MatchBinaryOp(BO->LHS, DT);
if (!NewBO || (NewBO->Opcode != Instruction::Add &&
NewBO->Opcode != Instruction::Sub)) {
AddOps.push_back(getSCEV(BO->LHS));
break;
}
BO = NewBO;
} while (true);
return getAddExpr(AddOps);
}
case Instruction::Mul: {
SmallVector<const SCEV *, 4> MulOps;
do {
if (BO->Op) {
if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
MulOps.push_back(OpSCEV);
break;
}
SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
if (Flags != SCEV::FlagAnyWrap) {
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
MulOps.push_back(getMulExpr(LHS, RHS, Flags));
break;
}
}
MulOps.push_back(getSCEV(BO->RHS));
auto NewBO = MatchBinaryOp(BO->LHS, DT);
if (!NewBO || NewBO->Opcode != Instruction::Mul) {
MulOps.push_back(getSCEV(BO->LHS));
break;
}
BO = NewBO;
} while (true);
return getMulExpr(MulOps);
}
case Instruction::UDiv:
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
return getUDivExpr(LHS, RHS);
case Instruction::URem:
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
return getURemExpr(LHS, RHS);
case Instruction::Sub: {
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
if (BO->Op)
Flags = getNoWrapFlagsFromUB(BO->Op);
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
return getMinusSCEV(LHS, RHS, Flags);
}
case Instruction::And:
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
if (CI->isZero())
return getSCEV(BO->RHS);
if (CI->isMinusOne())
return getSCEV(BO->LHS);
const APInt &A = CI->getValue();
unsigned LZ = A.countLeadingZeros();
unsigned TZ = A.countTrailingZeros();
unsigned BitWidth = A.getBitWidth();
KnownBits Known(BitWidth);
computeKnownBits(BO->LHS, Known, getDataLayout(),
0, &AC, nullptr, &DT);
APInt EffectiveMask =
APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
const SCEV *LHS = getSCEV(BO->LHS);
const SCEV *ShiftedLHS = nullptr;
if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
unsigned GCD = std::min(MulZeros, TZ);
APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
SmallVector<const SCEV*, 4> MulOps;
MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
}
}
if (!ShiftedLHS)
ShiftedLHS = getUDivExpr(LHS, MulCount);
return getMulExpr(
getZeroExtendExpr(
getTruncateExpr(ShiftedLHS,
IntegerType::get(getContext(), BitWidth - LZ - TZ)),
BO->LHS->getType()),
MulCount);
}
}
if (BO->LHS->getType()->isIntegerTy(1)) {
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
return getUMinExpr(LHS, RHS);
}
break;
case Instruction::Or:
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
const SCEV *LHS = getSCEV(BO->LHS);
const APInt &CIVal = CI->getValue();
if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
return getAddExpr(LHS, getSCEV(CI),
(SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
}
}
if (BO->LHS->getType()->isIntegerTy(1)) {
LHS = getSCEV(BO->LHS);
RHS = getSCEV(BO->RHS);
return getUMaxExpr(LHS, RHS);
}
break;
case Instruction::Xor:
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
if (CI->isMinusOne())
return getNotSCEV(getSCEV(BO->LHS));
if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
if (LBO->getOpcode() == Instruction::And &&
LCI->getValue() == CI->getValue())
if (const SCEVZeroExtendExpr *Z =
dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
Type *UTy = BO->LHS->getType();
const SCEV *Z0 = Z->getOperand();
Type *Z0Ty = Z0->getType();
unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
if (CI->getValue().isMask(Z0TySize))
return getZeroExtendExpr(getNotSCEV(Z0), UTy);
APInt Trunc = CI->getValue().trunc(Z0TySize);
if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
Trunc.isSignMask())
return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
UTy);
}
}
break;
case Instruction::Shl:
if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
if (SA->getValue().uge(BitWidth))
break;
auto Flags = SCEV::FlagAnyWrap;
if (BO->Op) {
auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
if ((MulFlags & SCEV::FlagNSW) &&
((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
if (MulFlags & SCEV::FlagNUW)
Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
}
ConstantInt *X = ConstantInt::get(
getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
}
break;
case Instruction::AShr: {
ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
if (!CI)
break;
Type *OuterTy = BO->LHS->getType();
uint64_t BitWidth = getTypeSizeInBits(OuterTy);
if (CI->getValue().uge(BitWidth))
break;
if (CI->isZero())
return getSCEV(BO->LHS);
uint64_t AShrAmt = CI->getZExtValue();
Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
Operator *L = dyn_cast<Operator>(BO->LHS);
if (L && L->getOpcode() == Instruction::Shl) {
const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
if (L->getOperand(1) == BO->RHS)
return getSignExtendExpr(
getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
uint64_t ShlAmt = ShlAmtCI->getZExtValue();
if (ShlAmt > AShrAmt) {
APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
ShlAmt - AShrAmt);
return getSignExtendExpr(
getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
getConstant(Mul)), OuterTy);
}
}
}
break;
}
}
}
switch (U->getOpcode()) {
case Instruction::Trunc:
return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::ZExt:
return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::SExt:
if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
Type *Ty = U->getType();
auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
return getMinusSCEV(V1, V2, SCEV::FlagNSW);
}
}
return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::BitCast:
if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
return getSCEV(U->getOperand(0));
break;
case Instruction::PtrToInt: {
const SCEV *Op = getSCEV(U->getOperand(0));
Type *DstIntTy = U->getType();
const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
if (isa<SCEVCouldNotCompute>(IntOp))
return getUnknown(V);
return IntOp;
}
case Instruction::IntToPtr:
return getUnknown(V);
case Instruction::SDiv:
if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
isKnownNonNegative(getSCEV(U->getOperand(1))))
return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
break;
case Instruction::SRem:
if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
isKnownNonNegative(getSCEV(U->getOperand(1))))
return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
break;
case Instruction::GetElementPtr:
return createNodeForGEP(cast<GEPOperator>(U));
case Instruction::PHI:
return createNodeForPHI(cast<PHINode>(U));
case Instruction::Select:
return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
U->getOperand(2));
case Instruction::Call:
case Instruction::Invoke:
if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
return getSCEV(RV);
if (auto *II = dyn_cast<IntrinsicInst>(U)) {
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
return getAbsExpr(
getSCEV(II->getArgOperand(0)),
cast<ConstantInt>(II->getArgOperand(1))->isOne());
case Intrinsic::umax:
LHS = getSCEV(II->getArgOperand(0));
RHS = getSCEV(II->getArgOperand(1));
return getUMaxExpr(LHS, RHS);
case Intrinsic::umin:
LHS = getSCEV(II->getArgOperand(0));
RHS = getSCEV(II->getArgOperand(1));
return getUMinExpr(LHS, RHS);
case Intrinsic::smax:
LHS = getSCEV(II->getArgOperand(0));
RHS = getSCEV(II->getArgOperand(1));
return getSMaxExpr(LHS, RHS);
case Intrinsic::smin:
LHS = getSCEV(II->getArgOperand(0));
RHS = getSCEV(II->getArgOperand(1));
return getSMinExpr(LHS, RHS);
case Intrinsic::usub_sat: {
const SCEV *X = getSCEV(II->getArgOperand(0));
const SCEV *Y = getSCEV(II->getArgOperand(1));
const SCEV *ClampedY = getUMinExpr(X, Y);
return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
}
case Intrinsic::uadd_sat: {
const SCEV *X = getSCEV(II->getArgOperand(0));
const SCEV *Y = getSCEV(II->getArgOperand(1));
const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
}
case Intrinsic::start_loop_iterations:
case Intrinsic::annotation:
case Intrinsic::ptr_annotation:
return getSCEV(II->getArgOperand(0));
default:
break;
}
}
break;
}
return getUnknown(V);
}
const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
bool Extend) {
if (isa<SCEVCouldNotCompute>(ExitCount))
return getCouldNotCompute();
auto *ExitCountType = ExitCount->getType();
assert(ExitCountType->isIntegerTy());
if (!Extend)
return getAddExpr(ExitCount, getOne(ExitCountType));
auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
1 + ExitCountType->getScalarSizeInBits());
return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
getOne(WiderType));
}
static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
if (!ExitCount)
return 0;
ConstantInt *ExitConst = ExitCount->getValue();
if (ExitConst->getValue().getActiveBits() > 32)
return 0;
return ((unsigned)ExitConst->getZExtValue()) + 1;
}
unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
return getConstantTripCount(ExitCount);
}
unsigned
ScalarEvolution::getSmallConstantTripCount(const Loop *L,
const BasicBlock *ExitingBlock) {
assert(ExitingBlock && "Must pass a non-null exiting block!");
assert(L->isLoopExiting(ExitingBlock) &&
"Exiting block must actually branch out of the loop!");
const SCEVConstant *ExitCount =
dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
return getConstantTripCount(ExitCount);
}
unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
const auto *MaxExitCount =
dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
return getConstantTripCount(MaxExitCount);
}
const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
if (!L->isLoopSimplifyForm() || !L->isInnermost())
return getCouldNotCompute();
const BasicBlock *LoopLatch = L->getLoopLatch();
assert(LoopLatch && "See defination of simplify form loop.");
if (L->getExitingBlock() != LoopLatch)
return getCouldNotCompute();
const DataLayout &DL = getDataLayout();
SmallVector<const SCEV *> InferCountColl;
for (auto *BB : L->getBlocks()) {
if (!DT.dominates(BB, LoopLatch))
continue;
for (Instruction &Inst : *BB) {
auto *GEP = getLoadStorePointerOperand(&Inst);
if (!GEP)
continue;
auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
if (!ElemSize)
continue;
auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
if (!AddRec)
continue;
auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
if (!ArrBase || !Step)
continue;
assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
if (AddRec->getStart() != ArrBase)
continue;
if (Step->getAPInt().getActiveBits() > 32 ||
Step->getAPInt().getZExtValue() !=
ElemSize->getAPInt().getZExtValue() ||
Step->isZero() || Step->getAPInt().isNegative())
continue;
AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
if (!AllocateInst || L->contains(AllocateInst->getParent()))
continue;
auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
if (!Ty || !ArrSize || !ArrSize->isOne())
continue;
const SCEV *MemSize =
getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
auto *MaxExeCount =
dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
continue;
auto *InferCount = dyn_cast<SCEVConstant>(
getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
continue;
InferCountColl.push_back(InferCount);
}
}
if (InferCountColl.size() == 0)
return getCouldNotCompute();
return getUMinFromMismatchedTypes(InferCountColl);
}
unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
Optional<unsigned> Res = None;
for (auto *ExitingBB : ExitingBlocks) {
unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
if (!Res)
Res = Multiple;
Res = (unsigned)GreatestCommonDivisor64(*Res, Multiple);
}
return Res.value_or(1);
}
unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
const SCEV *ExitCount) {
if (ExitCount == getCouldNotCompute())
return 1;
const SCEV *TCExpr = getTripCountFromExitCount(ExitCount);
const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
if (!TC)
return 1U << std::min((uint32_t)31,
GetMinTrailingZeros(applyLoopGuards(TCExpr, L)));
ConstantInt *Result = TC->getValue();
if (!Result || Result->getValue().getActiveBits() > 32 ||
Result->getValue().getActiveBits() == 0)
return 1;
return (unsigned)Result->getZExtValue();
}
unsigned
ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
const BasicBlock *ExitingBlock) {
assert(ExitingBlock && "Must pass a non-null exiting block!");
assert(L->isLoopExiting(ExitingBlock) &&
"Exiting block must actually branch out of the loop!");
const SCEV *ExitCount = getExitCount(L, ExitingBlock);
return getSmallConstantTripMultiple(L, ExitCount);
}
const SCEV *ScalarEvolution::getExitCount(const Loop *L,
const BasicBlock *ExitingBlock,
ExitCountKind Kind) {
switch (Kind) {
case Exact:
case SymbolicMaximum:
return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
case ConstantMaximum:
return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
};
llvm_unreachable("Invalid ExitCountKind!");
}
const SCEV *
ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
SmallVector<const SCEVPredicate *, 4> &Preds) {
return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
}
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
ExitCountKind Kind) {
switch (Kind) {
case Exact:
return getBackedgeTakenInfo(L).getExact(L, this);
case ConstantMaximum:
return getBackedgeTakenInfo(L).getConstantMax(this);
case SymbolicMaximum:
return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
};
llvm_unreachable("Invalid ExitCountKind!");
}
bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
}
static void PushLoopPHIs(const Loop *L,
SmallVectorImpl<Instruction *> &Worklist,
SmallPtrSetImpl<Instruction *> &Visited) {
BasicBlock *Header = L->getHeader();
for (PHINode &PN : Header->phis())
if (Visited.insert(&PN).second)
Worklist.push_back(&PN);
}
const ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
auto &BTI = getBackedgeTakenInfo(L);
if (BTI.hasFullInfo())
return BTI;
auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
if (!Pair.second)
return Pair.first->second;
BackedgeTakenInfo Result =
computeBackedgeTakenCount(L, true);
return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
}
ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
if (!Pair.second)
return Pair.first->second;
BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
(void)NumTripCountsComputed;
(void)NumTripCountsNotComputed;
#if LLVM_ENABLE_STATS || !defined(NDEBUG)
const SCEV *BEExact = Result.getExact(L, this);
if (BEExact != getCouldNotCompute()) {
assert(isLoopInvariant(BEExact, L) &&
isLoopInvariant(Result.getConstantMax(this), L) &&
"Computed backedge-taken count isn't loop invariant for loop!");
++NumTripCountsComputed;
} else if (Result.getConstantMax(this) == getCouldNotCompute() &&
isa<PHINode>(L->getHeader()->begin())) {
++NumTripCountsNotComputed;
}
#endif
if (Result.hasAnyInfo()) {
SmallVector<const SCEV *, 8> ToForget;
auto LoopUsersIt = LoopUsers.find(L);
if (LoopUsersIt != LoopUsers.end())
append_range(ToForget, LoopUsersIt->second);
forgetMemoizedResults(ToForget);
for (PHINode &PN : L->getHeader()->phis())
ConstantEvolutionLoopExitValue.erase(&PN);
}
return BackedgeTakenCounts.find(L)->second = std::move(Result);
}
void ScalarEvolution::forgetAllLoops() {
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
BECountUsers.clear();
LoopPropertiesCache.clear();
ConstantEvolutionLoopExitValue.clear();
ValueExprMap.clear();
ValuesAtScopes.clear();
ValuesAtScopesUsers.clear();
LoopDispositions.clear();
BlockDispositions.clear();
UnsignedRanges.clear();
SignedRanges.clear();
ExprValueMap.clear();
HasRecMap.clear();
MinTrailingZerosCache.clear();
PredicatedSCEVRewrites.clear();
}
void ScalarEvolution::forgetLoop(const Loop *L) {
SmallVector<const Loop *, 16> LoopWorklist(1, L);
SmallVector<Instruction *, 32> Worklist;
SmallPtrSet<Instruction *, 16> Visited;
SmallVector<const SCEV *, 16> ToForget;
while (!LoopWorklist.empty()) {
auto *CurrL = LoopWorklist.pop_back_val();
forgetBackedgeTakenCounts(CurrL, false);
forgetBackedgeTakenCounts(CurrL, true);
for (auto I = PredicatedSCEVRewrites.begin();
I != PredicatedSCEVRewrites.end();) {
std::pair<const SCEV *, const Loop *> Entry = I->first;
if (Entry.second == CurrL)
PredicatedSCEVRewrites.erase(I++);
else
++I;
}
auto LoopUsersItr = LoopUsers.find(CurrL);
if (LoopUsersItr != LoopUsers.end()) {
ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
LoopUsersItr->second.end());
}
PushLoopPHIs(CurrL, Worklist, Visited);
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
if (It != ValueExprMap.end()) {
eraseValueFromMap(It->first);
ToForget.push_back(It->second);
if (PHINode *PN = dyn_cast<PHINode>(I))
ConstantEvolutionLoopExitValue.erase(PN);
}
PushDefUseChildren(I, Worklist, Visited);
}
LoopPropertiesCache.erase(CurrL);
LoopWorklist.append(CurrL->begin(), CurrL->end());
}
forgetMemoizedResults(ToForget);
}
void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
forgetLoop(L->getOutermostLoop());
}
void ScalarEvolution::forgetValue(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return;
SmallVector<Instruction *, 16> Worklist;
SmallPtrSet<Instruction *, 8> Visited;
SmallVector<const SCEV *, 8> ToForget;
Worklist.push_back(I);
Visited.insert(I);
while (!Worklist.empty()) {
I = Worklist.pop_back_val();
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
if (It != ValueExprMap.end()) {
eraseValueFromMap(It->first);
ToForget.push_back(It->second);
if (PHINode *PN = dyn_cast<PHINode>(I))
ConstantEvolutionLoopExitValue.erase(PN);
}
PushDefUseChildren(I, Worklist, Visited);
}
forgetMemoizedResults(ToForget);
}
void ScalarEvolution::forgetLoopDispositions(const Loop *L) {
LoopDispositions.clear();
}
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
SmallVector<const SCEVPredicate *, 4> *Preds) const {
if (!isComplete() || ExitNotTaken.empty())
return SE->getCouldNotCompute();
const BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
return SE->getCouldNotCompute();
SmallVector<const SCEV *, 2> Ops;
for (const auto &ENT : ExitNotTaken) {
const SCEV *BECount = ENT.ExactNotTaken;
assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
"We should only have known counts for exiting blocks that dominate "
"latch!");
Ops.push_back(BECount);
if (Preds)
for (const auto *P : ENT.Predicates)
Preds->push_back(P);
assert((Preds || ENT.hasAlwaysTruePredicate()) &&
"Predicate should be always true!");
}
return SE->getUMinFromMismatchedTypes(Ops, true);
}
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
ScalarEvolution *SE) const {
for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
return ENT.ExactNotTaken;
return SE->getCouldNotCompute();
}
const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
return ENT.MaxNotTaken;
return SE->getCouldNotCompute();
}
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
return !ENT.hasAlwaysTruePredicate();
};
if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
return SE->getCouldNotCompute();
assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
isa<SCEVConstant>(getConstantMax())) &&
"No point in having a non-constant max backedge taken count!");
return getConstantMax();
}
const SCEV *
ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
ScalarEvolution *SE) {
if (!SymbolicMax)
SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
return SymbolicMax;
}
bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
ScalarEvolution *SE) const {
auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
return !ENT.hasAlwaysTruePredicate();
};
return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
}
ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
: ExitLimit(E, E, false, None) {
}
ScalarEvolution::ExitLimit::ExitLimit(
const SCEV *E, const SCEV *M, bool MaxOrZero,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
: ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
if (MaxNotTaken->isZero())
ExactNotTaken = MaxNotTaken;
assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
!isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
"Exact is not allowed to be less precise than Max");
assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
isa<SCEVConstant>(MaxNotTaken)) &&
"No point in having a non-constant max backedge taken count!");
for (const auto *PredSet : PredSetList)
for (const auto *P : *PredSet)
addPredicate(P);
assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
"Backedge count should be int");
assert((isa<SCEVCouldNotCompute>(M) || !M->getType()->isPointerTy()) &&
"Max backedge count should be int");
}
ScalarEvolution::ExitLimit::ExitLimit(
const SCEV *E, const SCEV *M, bool MaxOrZero,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
: ExitLimit(E, M, MaxOrZero, {&PredSet}) {
}
ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
bool MaxOrZero)
: ExitLimit(E, M, MaxOrZero, None) {
}
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
: ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
ExitNotTaken.reserve(ExitCounts.size());
std::transform(
ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
[&](const EdgeExitInfo &EEI) {
BasicBlock *ExitBB = EEI.first;
const ExitLimit &EL = EEI.second;
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
EL.Predicates);
});
assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
isa<SCEVConstant>(ConstantMax)) &&
"No point in having a non-constant max backedge taken count!");
}
ScalarEvolution::BackedgeTakenInfo
ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
bool AllowPredicates) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
SmallVector<EdgeExitInfo, 4> ExitCounts;
bool CouldComputeBECount = true;
BasicBlock *Latch = L->getLoopLatch(); const SCEV *MustExitMaxBECount = nullptr;
const SCEV *MayExitMaxBECount = nullptr;
bool MustExitMaxOrZero = false;
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
BasicBlock *ExitBB = ExitingBlocks[i];
if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
if (ExitIfTrue == CI->isZero())
continue;
}
ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
assert((AllowPredicates || EL.Predicates.empty()) &&
"Predicated exit limit when predicates are not allowed!");
if (EL.ExactNotTaken == getCouldNotCompute())
CouldComputeBECount = false;
else
ExitCounts.emplace_back(ExitBB, EL);
if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
DT.dominates(ExitBB, Latch)) {
if (!MustExitMaxBECount) {
MustExitMaxBECount = EL.MaxNotTaken;
MustExitMaxOrZero = EL.MaxOrZero;
} else {
MustExitMaxBECount =
getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
}
} else if (MayExitMaxBECount != getCouldNotCompute()) {
if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
MayExitMaxBECount = EL.MaxNotTaken;
else {
MayExitMaxBECount =
getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
}
}
}
const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
(MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
for (const auto &Pair : ExitCounts)
if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
MaxBECount, MaxOrZero);
}
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
bool AllowPredicates) {
assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
const BasicBlock *Latch = L->getLoopLatch();
if (!Latch || !DT.dominates(ExitingBlock, Latch))
return getCouldNotCompute();
bool IsOnlyExit = (L->getExitingBlock() != nullptr);
Instruction *Term = ExitingBlock->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
"It should have one successor in loop and one exit block!");
return computeExitLimitFromCond(
L, BI->getCondition(), ExitIfTrue,
IsOnlyExit, AllowPredicates);
}
if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
BasicBlock *Exit = nullptr;
for (auto *SBB : successors(ExitingBlock))
if (!L->contains(SBB)) {
if (Exit) return getCouldNotCompute();
Exit = SBB;
}
assert(Exit && "Exiting block must have at least one exit");
return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
IsOnlyExit);
}
return getCouldNotCompute();
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
ControlsExit, AllowPredicates);
}
Optional<ScalarEvolution::ExitLimit>
ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
bool ExitIfTrue, bool ControlsExit,
bool AllowPredicates) {
(void)this->L;
(void)this->ExitIfTrue;
(void)this->AllowPredicates;
assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
this->AllowPredicates == AllowPredicates &&
"Variance in assumed invariant key components!");
auto Itr = TripCountMap.find({ExitCond, ControlsExit});
if (Itr == TripCountMap.end())
return None;
return Itr->second;
}
void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
bool ExitIfTrue,
bool ControlsExit,
bool AllowPredicates,
const ExitLimit &EL) {
assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
this->AllowPredicates == AllowPredicates &&
"Variance in assumed invariant key components!");
auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
assert(InsertResult.second && "Expected successful insertion!");
(void)InsertResult;
(void)ExitIfTrue;
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
if (auto MaybeEL =
Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
return *MaybeEL;
ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
ControlsExit, AllowPredicates);
Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
return EL;
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
return *LimitFromBinOp;
if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
ExitLimit EL =
computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
if (EL.hasFullInfo() || !AllowPredicates)
return EL;
return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
true);
}
if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
if (ExitIfTrue == !CI->getZExtValue())
return getCouldNotCompute();
else
return getZero(CI->getType());
}
const WithOverflowInst *WO;
const APInt *C;
if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
match(WO->getRHS(), m_APInt(C))) {
ConstantRange NWR =
ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
WO->getNoWrapKind());
CmpInst::Predicate Pred;
APInt NewRHSC, Offset;
NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
if (!ExitIfTrue)
Pred = ICmpInst::getInversePredicate(Pred);
auto *LHS = getSCEV(WO->getLHS());
if (Offset != 0)
LHS = getAddExpr(LHS, getConstant(Offset));
auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
ControlsExit, AllowPredicates);
if (EL.hasAnyInfo()) return EL;
}
return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
}
Optional<ScalarEvolution::ExitLimit>
ScalarEvolution::computeExitLimitFromCondFromBinOp(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
Value *Op0, *Op1;
bool IsAnd = false;
if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
IsAnd = true;
else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
IsAnd = false;
else
return None;
bool EitherMayExit = IsAnd ^ ExitIfTrue;
ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
ControlsExit && !EitherMayExit,
AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue,
ControlsExit && !EitherMayExit,
AllowPredicates);
const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
if (isa<ConstantInt>(Op1))
return Op1 == NeutralElement ? EL0 : EL1;
if (isa<ConstantInt>(Op0))
return Op0 == NeutralElement ? EL1 : EL0;
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (EitherMayExit) {
if (EL0.ExactNotTaken != getCouldNotCompute() &&
EL1.ExactNotTaken != getCouldNotCompute()) {
BECount = getUMinFromMismatchedTypes(
EL0.ExactNotTaken, EL1.ExactNotTaken,
!isa<BinaryOperator>(ExitCond));
}
if (EL0.MaxNotTaken == getCouldNotCompute())
MaxBECount = EL1.MaxNotTaken;
else if (EL1.MaxNotTaken == getCouldNotCompute())
MaxBECount = EL0.MaxNotTaken;
else
MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
} else {
if (EL0.ExactNotTaken == EL1.ExactNotTaken)
BECount = EL0.ExactNotTaken;
}
if (isa<SCEVCouldNotCompute>(MaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = getConstant(getUnsignedRangeMax(BECount));
return ExitLimit(BECount, MaxBECount, false,
{ &EL0.Predicates, &EL1.Predicates });
}
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
bool ExitIfTrue,
bool ControlsExit,
bool AllowPredicates) {
ICmpInst::Predicate Pred;
if (!ExitIfTrue)
Pred = ExitCond->getPredicate();
else
Pred = ExitCond->getInversePredicate();
const ICmpInst::Predicate OriginalPred = Pred;
const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsExit,
AllowPredicates);
if (EL.hasAnyInfo()) return EL;
auto *ExhaustiveCount =
computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
return ExhaustiveCount;
return computeShiftCompareExitLimit(ExitCond->getOperand(0),
ExitCond->getOperand(1), L, OriginalPred);
}
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
bool ControlsExit,
bool AllowPredicates) {
LHS = getSCEVAtScope(LHS, L);
RHS = getSCEVAtScope(RHS, L);
if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
bool ControllingFiniteLoop =
ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L);
(void)SimplifyICmpOperands(Pred, LHS, RHS, 0,
(EnableFiniteLoopControl ? ControllingFiniteLoop
: false));
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
if (AddRec->getLoop() == L) {
ConstantRange CompRange =
ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
}
if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
auto *InnerLHS = LHS;
if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
InnerLHS = ZExt->getOperand();
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
StrideC && StrideC->getAPInt().isPowerOf2()) {
auto Flags = AR->getNoWrapFlags();
Flags = setFlags(Flags, SCEV::FlagNW);
SmallVector<const SCEV*> Operands{AR->operands()};
Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
}
}
}
switch (Pred) {
case ICmpInst::ICMP_NE: { if (LHS->getType()->isPointerTy()) {
LHS = getLosslessPtrToIntExpr(LHS);
if (isa<SCEVCouldNotCompute>(LHS))
return LHS;
}
if (RHS->getType()->isPointerTy()) {
RHS = getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(RHS))
return RHS;
}
ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_EQ: { if (LHS->getType()->isPointerTy()) {
LHS = getLosslessPtrToIntExpr(LHS);
if (isa<SCEVCouldNotCompute>(LHS))
return LHS;
}
if (RHS->getType()->isPointerTy()) {
RHS = getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(RHS))
return RHS;
}
ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { bool IsSigned = Pred == ICmpInst::ICMP_SLT;
ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT: { bool IsSigned = Pred == ICmpInst::ICMP_SGT;
ExitLimit EL =
howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
AllowPredicates);
if (EL.hasAnyInfo()) return EL;
break;
}
default:
break;
}
return getCouldNotCompute();
}
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
SwitchInst *Switch,
BasicBlock *ExitingBlock,
bool ControlsExit) {
assert(!L->contains(ExitingBlock) && "Not an exiting block!");
if (Switch->getDefaultDest() == ExitingBlock)
return getCouldNotCompute();
assert(L->contains(Switch->getDefaultDest()) &&
"Default case must not exit the loop!");
const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
if (EL.hasAnyInfo())
return EL;
return getCouldNotCompute();
}
static ConstantInt *
EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
ScalarEvolution &SE) {
const SCEV *InVal = SE.getConstant(C);
const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
assert(isa<SCEVConstant>(Val) &&
"Evaluation of SCEV at constant didn't fold correctly?");
return cast<SCEVConstant>(Val)->getValue();
}
ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
if (!RHS)
return getCouldNotCompute();
const BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
return getCouldNotCompute();
const BasicBlock *Predecessor = L->getLoopPredecessor();
if (!Predecessor)
return getCouldNotCompute();
auto MatchPositiveShift =
[](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
using namespace PatternMatch;
ConstantInt *ShiftAmt;
if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
OutOpCode = Instruction::LShr;
else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
OutOpCode = Instruction::AShr;
else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
OutOpCode = Instruction::Shl;
else
return false;
return ShiftAmt->getValue().isStrictlyPositive();
};
auto MatchShiftRecurrence =
[&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
Optional<Instruction::BinaryOps> PostShiftOpCode;
{
Instruction::BinaryOps OpC;
Value *V;
if (MatchPositiveShift(LHS, V, OpC)) {
PostShiftOpCode = OpC;
LHS = V;
}
}
PNOut = dyn_cast<PHINode>(LHS);
if (!PNOut || PNOut->getParent() != L->getHeader())
return false;
Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
Value *OpLHS;
return
MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
OpLHS == PNOut &&
(!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
};
PHINode *PN;
Instruction::BinaryOps OpCode;
if (!MatchShiftRecurrence(LHS, PN, OpCode))
return getCouldNotCompute();
const DataLayout &DL = getDataLayout();
ConstantInt *StableValue = nullptr;
switch (OpCode) {
default:
llvm_unreachable("Impossible case!");
case Instruction::AShr: {
Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
Predecessor->getTerminator(), &DT);
auto *Ty = cast<IntegerType>(RHS->getType());
if (Known.isNonNegative())
StableValue = ConstantInt::get(Ty, 0);
else if (Known.isNegative())
StableValue = ConstantInt::get(Ty, -1, true);
else
return getCouldNotCompute();
break;
}
case Instruction::LShr:
case Instruction::Shl:
StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
break;
}
auto *Result =
ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
assert(Result->getType()->isIntegerTy(1) &&
"Otherwise cannot be an operand to a branch instruction");
if (Result->isZeroValue()) {
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *UpperBound =
getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
return ExitLimit(getCouldNotCompute(), UpperBound, false);
}
return getCouldNotCompute();
}
static bool CanConstantFold(const Instruction *I) {
if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
isa<LoadInst>(I) || isa<ExtractValueInst>(I))
return true;
if (const CallInst *CI = dyn_cast<CallInst>(I))
if (const Function *F = CI->getCalledFunction())
return canConstantFoldCallTo(CI, F);
return false;
}
static bool canConstantEvolve(Instruction *I, const Loop *L) {
if (!L->contains(I)) return false;
if (isa<PHINode>(I)) {
return L->getHeader() == I->getParent();
}
return CanConstantFold(I);
}
static PHINode *
getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
DenseMap<Instruction *, PHINode *> &PHIMap,
unsigned Depth) {
if (Depth > MaxConstantEvolvingDepth)
return nullptr;
PHINode *PHI = nullptr;
for (Value *Op : UseInst->operands()) {
if (isa<Constant>(Op)) continue;
Instruction *OpInst = dyn_cast<Instruction>(Op);
if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
PHINode *P = dyn_cast<PHINode>(OpInst);
if (!P)
P = PHIMap.lookup(OpInst);
if (!P) {
P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
PHIMap[OpInst] = P;
}
if (!P)
return nullptr; if (PHI && PHI != P)
return nullptr; PHI = P;
}
return PHI;
}
static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I || !canConstantEvolve(I, L)) return nullptr;
if (PHINode *PN = dyn_cast<PHINode>(I))
return PN;
DenseMap<Instruction *, PHINode *> PHIMap;
return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
}
static Constant *EvaluateExpression(Value *V, const Loop *L,
DenseMap<Instruction *, Constant *> &Vals,
const DataLayout &DL,
const TargetLibraryInfo *TLI) {
if (Constant *C = dyn_cast<Constant>(V)) return C;
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return nullptr;
if (Constant *C = Vals.lookup(I)) return C;
if (!canConstantEvolve(I, L)) return nullptr;
if (isa<PHINode>(I)) return nullptr;
std::vector<Constant*> Operands(I->getNumOperands());
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
if (!Operand) {
Operands[i] = dyn_cast<Constant>(I->getOperand(i));
if (!Operands[i]) return nullptr;
continue;
}
Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
Vals[Operand] = C;
if (!C) return nullptr;
Operands[i] = C;
}
return ConstantFoldInstOperands(I, Operands, DL, TLI);
}
static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
Constant *IncomingVal = nullptr;
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
if (PN->getIncomingBlock(i) == BB)
continue;
auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
if (!CurrentVal)
return nullptr;
if (IncomingVal != CurrentVal) {
if (IncomingVal)
return nullptr;
IncomingVal = CurrentVal;
}
}
return IncomingVal;
}
Constant *
ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
const APInt &BEs,
const Loop *L) {
auto I = ConstantEvolutionLoopExitValue.find(PN);
if (I != ConstantEvolutionLoopExitValue.end())
return I->second;
if (BEs.ugt(MaxBruteForceIterations))
return ConstantEvolutionLoopExitValue[PN] = nullptr;
Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
DenseMap<Instruction *, Constant *> CurrentIterVals;
BasicBlock *Header = L->getHeader();
assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
return nullptr;
for (PHINode &PHI : Header->phis()) {
if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
CurrentIterVals[&PHI] = StartCST;
}
if (!CurrentIterVals.count(PN))
return RetVal = nullptr;
Value *BEValue = PN->getIncomingValueForBlock(Latch);
assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
"BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
unsigned NumIterations = BEs.getZExtValue(); unsigned IterationNum = 0;
const DataLayout &DL = getDataLayout();
for (; ; ++IterationNum) {
if (IterationNum == NumIterations)
return RetVal = CurrentIterVals[PN];
DenseMap<Instruction *, Constant *> NextIterVals;
Constant *NextPHI =
EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
if (!NextPHI)
return nullptr; NextIterVals[PN] = NextPHI;
bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
for (const auto &I : CurrentIterVals) {
PHINode *PHI = dyn_cast<PHINode>(I.first);
if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
PHIsToCompute.emplace_back(PHI, I.second);
}
for (const auto &I : PHIsToCompute) {
PHINode *PHI = I.first;
Constant *&NextPHI = NextIterVals[PHI];
if (!NextPHI) { Value *BEValue = PHI->getIncomingValueForBlock(Latch);
NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
}
if (NextPHI != I.second)
StoppedEvolving = false;
}
if (StoppedEvolving)
return RetVal = CurrentIterVals[PN];
CurrentIterVals.swap(NextIterVals);
}
}
const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
Value *Cond,
bool ExitWhen) {
PHINode *PN = getConstantEvolvingPHI(Cond, L);
if (!PN) return getCouldNotCompute();
if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
DenseMap<Instruction *, Constant *> CurrentIterVals;
BasicBlock *Header = L->getHeader();
assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Should follow from NumIncomingValues == 2!");
for (PHINode &PHI : Header->phis()) {
if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
CurrentIterVals[&PHI] = StartCST;
}
if (!CurrentIterVals.count(PN))
return getCouldNotCompute();
unsigned MaxIterations = MaxBruteForceIterations; const DataLayout &DL = getDataLayout();
for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
auto *CondVal = dyn_cast_or_null<ConstantInt>(
EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
if (!CondVal) return getCouldNotCompute();
if (CondVal->getValue() == uint64_t(ExitWhen)) {
++NumBruteForceTripCountsComputed;
return getConstant(Type::getInt32Ty(getContext()), IterationNum);
}
DenseMap<Instruction *, Constant *> NextIterVals;
SmallVector<PHINode *, 8> PHIsToCompute;
for (const auto &I : CurrentIterVals) {
PHINode *PHI = dyn_cast<PHINode>(I.first);
if (!PHI || PHI->getParent() != Header) continue;
PHIsToCompute.push_back(PHI);
}
for (PHINode *PHI : PHIsToCompute) {
Constant *&NextPHI = NextIterVals[PHI];
if (NextPHI) continue;
Value *BEValue = PHI->getIncomingValueForBlock(Latch);
NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
}
CurrentIterVals.swap(NextIterVals);
}
return getCouldNotCompute();
}
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
ValuesAtScopes[V];
for (auto &LS : Values)
if (LS.first == L)
return LS.second ? LS.second : V;
Values.emplace_back(L, nullptr);
const SCEV *C = computeSCEVAtScope(V, L);
for (auto &LS : reverse(ValuesAtScopes[V]))
if (LS.first == L) {
LS.second = C;
if (!isa<SCEVConstant>(C))
ValuesAtScopesUsers[C].push_back({L, V});
break;
}
return C;
}
static Constant *BuildConstantFromSCEV(const SCEV *V) {
switch (V->getSCEVType()) {
case scCouldNotCompute:
case scAddRecExpr:
return nullptr;
case scConstant:
return cast<SCEVConstant>(V)->getValue();
case scUnknown:
return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
case scSignExtend: {
const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
return ConstantExpr::getSExt(CastOp, SS->getType());
return nullptr;
}
case scZeroExtend: {
const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
return ConstantExpr::getZExt(CastOp, SZ->getType());
return nullptr;
}
case scPtrToInt: {
const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
return nullptr;
}
case scTruncate: {
const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
return ConstantExpr::getTrunc(CastOp, ST->getType());
return nullptr;
}
case scAddExpr: {
const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
Constant *C = nullptr;
for (const SCEV *Op : SA->operands()) {
Constant *OpC = BuildConstantFromSCEV(Op);
if (!OpC)
return nullptr;
if (!C) {
C = OpC;
continue;
}
assert(!C->getType()->isPointerTy() &&
"Can only have one pointer, and it must be last");
if (auto *PT = dyn_cast<PointerType>(OpC->getType())) {
Type *DestPtrTy =
Type::getInt8PtrTy(PT->getContext(), PT->getAddressSpace());
OpC = ConstantExpr::getBitCast(OpC, DestPtrTy);
C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
OpC, C);
} else {
C = ConstantExpr::getAdd(C, OpC);
}
}
return C;
}
case scMulExpr: {
const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
Constant *C = nullptr;
for (const SCEV *Op : SM->operands()) {
assert(!Op->getType()->isPointerTy() && "Can't multiply pointers");
Constant *OpC = BuildConstantFromSCEV(Op);
if (!OpC)
return nullptr;
C = C ? ConstantExpr::getMul(C, OpC) : OpC;
}
return C;
}
case scUDivExpr:
case scSMaxExpr:
case scUMaxExpr:
case scSMinExpr:
case scUMinExpr:
case scSequentialUMinExpr:
return nullptr; }
llvm_unreachable("Unknown SCEV kind!");
}
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
if (isa<SCEVConstant>(V)) return V;
if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
if (PHINode *PN = dyn_cast<PHINode>(I)) {
const Loop *CurrLoop = this->LI[I->getParent()];
if (CurrLoop && CurrLoop->getParentLoop() == L &&
PN->getParent() == CurrLoop->getHeader()) {
const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
if (BackedgeTakenCount->isZero()) {
Value *InitValue = nullptr;
bool MultipleInitValues = false;
for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
if (!InitValue)
InitValue = PN->getIncomingValue(i);
else if (InitValue != PN->getIncomingValue(i)) {
MultipleInitValues = true;
break;
}
}
}
if (!MultipleInitValues && InitValue)
return getSCEV(InitValue);
}
if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
isKnownPositive(BackedgeTakenCount) &&
PN->getNumIncomingValues() == 2) {
unsigned InLoopPred =
CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
if (CurrLoop->isLoopInvariant(BackedgeVal))
return getSCEV(BackedgeVal);
}
if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
Constant *RV = getConstantEvolutionLoopExitValue(
PN, BTCC->getAPInt(), CurrLoop);
if (RV) return getSCEV(RV);
}
}
if (PN->getNumOperands() == 1) {
const SCEV *Input = getSCEV(PN->getOperand(0));
const SCEV *InputAtScope = getSCEVAtScope(Input, L);
if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
}
}
if (CanConstantFold(I)) {
SmallVector<Constant *, 4> Operands;
bool MadeImprovement = false;
for (Value *Op : I->operands()) {
if (Constant *C = dyn_cast<Constant>(Op)) {
Operands.push_back(C);
continue;
}
if (!isSCEVable(Op->getType()))
return V;
const SCEV *OrigV = getSCEV(Op);
const SCEV *OpV = getSCEVAtScope(OrigV, L);
MadeImprovement |= OrigV != OpV;
Constant *C = BuildConstantFromSCEV(OpV);
if (!C) return V;
if (C->getType() != Op->getType())
C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
Op->getType(),
false),
C, Op->getType());
Operands.push_back(C);
}
if (MadeImprovement) {
Constant *C = nullptr;
const DataLayout &DL = getDataLayout();
C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
if (!C) return V;
return getSCEV(C);
}
}
}
return V;
}
if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
const auto *Comm = cast<SCEVNAryExpr>(V);
for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
if (OpAtScope != Comm->getOperand(i)) {
SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
Comm->op_begin()+i);
NewOps.push_back(OpAtScope);
for (++i; i != e; ++i) {
OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
NewOps.push_back(OpAtScope);
}
if (isa<SCEVAddExpr>(Comm))
return getAddExpr(NewOps, Comm->getNoWrapFlags());
if (isa<SCEVMulExpr>(Comm))
return getMulExpr(NewOps, Comm->getNoWrapFlags());
if (isa<SCEVMinMaxExpr>(Comm))
return getMinMaxExpr(Comm->getSCEVType(), NewOps);
if (isa<SCEVSequentialMinMaxExpr>(Comm))
return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
}
}
return Comm;
}
if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
if (LHS == Div->getLHS() && RHS == Div->getRHS())
return Div; return getUDivExpr(LHS, RHS);
}
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
if (OpAtScope == AddRec->getOperand(i))
continue;
SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
AddRec->op_begin()+i);
NewOps.push_back(OpAtScope);
for (++i; i != e; ++i)
NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
const SCEV *FoldedRec =
getAddRecExpr(NewOps, AddRec->getLoop(),
AddRec->getNoWrapFlags(SCEV::FlagNW));
AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
if (!AddRec)
return FoldedRec;
break;
}
if (!AddRec->getLoop()->contains(L)) {
const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
}
return AddRec;
}
if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
if (Op == Cast->getOperand())
return Cast; return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
}
llvm_unreachable("Unknown SCEV type!");
}
const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
return getSCEVAtScope(getSCEV(V), L);
}
const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
return stripInjectiveFunctions(ZExt->getOperand());
if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
return stripInjectiveFunctions(SExt->getOperand());
return S;
}
static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
ScalarEvolution &SE) {
uint32_t BW = A.getBitWidth();
assert(BW == SE.getTypeSizeInBits(B->getType()));
assert(A != 0 && "A must be non-zero.");
uint32_t Mult2 = A.countTrailingZeros();
if (SE.GetMinTrailingZeros(B) < Mult2)
return SE.getCouldNotCompute();
APInt AD = A.lshr(Mult2).zext(BW + 1); APInt Mod(BW + 1, 0);
Mod.setBit(BW - Mult2); APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}
static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
<< *AddRec << '\n');
if (!LC || !MC || !NC) {
LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
return None;
}
APInt L = LC->getAPInt();
APInt M = MC->getAPInt();
APInt N = NC->getAPInt();
assert(!N.isZero() && "This is not a quadratic addrec");
unsigned BitWidth = LC->getAPInt().getBitWidth();
unsigned NewWidth = BitWidth + 1;
LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
<< BitWidth << '\n');
N = N.sext(NewWidth);
M = M.sext(NewWidth);
L = L.sext(NewWidth);
APInt A = N;
APInt B = 2 * M - A;
APInt C = 2 * L;
APInt T = APInt(NewWidth, 2);
LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
<< "x + " << C << ", coeff bw: " << NewWidth
<< ", multiplied by " << T << '\n');
return std::make_tuple(A, B, C, T, BitWidth);
}
static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
if (X && Y) {
unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
APInt XW = X->sext(W);
APInt YW = Y->sext(W);
return XW.slt(YW) ? *X : *Y;
}
if (!X && !Y)
return None;
return X ? *X : *Y;
}
static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
if (!X)
return None;
unsigned W = X->getBitWidth();
if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
return X->trunc(BitWidth);
return X;
}
static Optional<APInt>
SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
APInt A, B, C, M;
unsigned BitWidth;
auto T = GetQuadraticEquation(AddRec);
if (!T)
return None;
std::tie(A, B, C, M, BitWidth) = *T;
LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
if (!X)
return None;
ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
if (!V->isZero())
return None;
return TruncIfPossible(X, BitWidth);
}
static Optional<APInt>
SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
const ConstantRange &Range, ScalarEvolution &SE) {
assert(AddRec->getOperand(0)->isZero() &&
"Starting value of addrec should be 0");
LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
<< Range << ", addrec " << *AddRec << '\n');
assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
"Addrec's initial value should be in range");
APInt A, B, C, M;
unsigned BitWidth;
auto T = GetQuadraticEquation(AddRec);
if (!T)
return None;
auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
<< Bound << " (before multiplying by " << M << ")\n");
Bound *= M;
Optional<APInt> SO = None;
if (BitWidth > 1) {
LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
"signed overflow\n");
SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
}
LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
"unsigned overflow\n");
Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
BitWidth+1);
auto LeavesRange = [&] (const APInt &X) {
ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
if (Range.contains(V0->getValue()))
return false;
ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
if (Range.contains(V1->getValue()))
return true;
return false;
};
if (!SO || !UO)
return { None, false };
Optional<APInt> Min = MinOptional(SO, UO);
if (LeavesRange(*Min))
return { Min, true };
Optional<APInt> Max = Min == SO ? UO : SO;
if (LeavesRange(*Max))
return { Max, true };
return { None, true };
};
std::tie(A, B, C, M, BitWidth) = *T;
APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
APInt Upper = Range.getUpper().sext(A.getBitWidth());
auto SL = SolveForBoundary(Lower);
auto SU = SolveForBoundary(Upper);
if (!SL.second || !SU.second)
return None;
return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
if (C->getValue()->isZero()) return C;
return getCouldNotCompute(); }
const SCEVAddRecExpr *AddRec =
dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
if (!AddRec && AllowPredicates)
AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
if (!AddRec || AddRec->getLoop() != L)
return getCouldNotCompute();
if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
const auto *R = cast<SCEVConstant>(getConstant(*S));
return ExitLimit(R, R, false, Predicates);
}
return getCouldNotCompute();
}
if (!AddRec->isAffine())
return getCouldNotCompute();
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!StepC || StepC->getValue()->isZero())
return getCouldNotCompute();
bool CountDown = StepC->getAPInt().isNegative();
const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
const SCEV *Zero = getZero(Distance->getType());
const SCEV *One = getOne(Distance->getType());
const SCEV *DistancePlusOne = getAddExpr(Distance, One);
if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
ConstantRange CR = getUnsignedRange(DistancePlusOne);
MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
}
return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
}
if (ControlsExit && AddRec->hasNoSelfWrap() &&
loopHasNoAbnormalExits(AddRec->getLoop())) {
const SCEV *Exact =
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
const SCEV *Max = getCouldNotCompute();
if (Exact != getCouldNotCompute()) {
APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
Max = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
}
return ExitLimit(Exact, Max, false, Predicates);
}
const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
getNegativeSCEV(Start), *this);
const SCEV *M = E;
if (E != getCouldNotCompute()) {
APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
}
return ExitLimit(E, M, false, Predicates);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
if (!C->getValue()->isZero())
return getZero(C->getType());
return getCouldNotCompute(); }
return getCouldNotCompute();
}
std::pair<const BasicBlock *, const BasicBlock *>
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
const {
if (const BasicBlock *Pred = BB->getSinglePredecessor())
return {Pred, BB};
if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};
return {nullptr, nullptr};
}
static bool HasSameValue(const SCEV *A, const SCEV *B) {
if (A == B) return true;
auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
};
if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
if (ComputesEqualValues(AI, BI))
return true;
return false;
}
bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
const SCEV *&LHS, const SCEV *&RHS,
unsigned Depth,
bool ControllingFiniteLoop) {
bool Changed = false;
auto TrivialCase = [&](bool TriviallyTrue) {
LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
return true;
};
if (Depth >= 3)
return false;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (ConstantExpr::getICmp(Pred,
LHSC->getValue(),
RHSC->getValue())->isNullValue())
return TrivialCase(false);
else
return TrivialCase(true);
}
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
Changed = true;
}
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
const Loop *L = AR->getLoop();
if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
Changed = true;
}
}
if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
const APInt &RA = RC->getAPInt();
bool SimplifiedByConstantRange = false;
if (!ICmpInst::isEquality(Pred)) {
ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
if (ExactCR.isFullSet())
return TrivialCase(true);
else if (ExactCR.isEmptySet())
return TrivialCase(false);
APInt NewRHS;
CmpInst::Predicate NewPred;
if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
ICmpInst::isEquality(NewPred)) {
Pred = NewPred;
RHS = getConstant(NewRHS);
Changed = SimplifiedByConstantRange = true;
}
}
if (!SimplifiedByConstantRange) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE:
if (!RA)
if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
if (const SCEVMulExpr *ME =
dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
ME->getOperand(0)->isAllOnesValue()) {
RHS = AE->getOperand(1);
LHS = ME->getOperand(1);
Changed = true;
}
break;
case ICmpInst::ICMP_UGE:
assert(!RA.isMinValue() && "Should have been caught earlier!");
Pred = ICmpInst::ICMP_UGT;
RHS = getConstant(RA - 1);
Changed = true;
break;
case ICmpInst::ICMP_ULE:
assert(!RA.isMaxValue() && "Should have been caught earlier!");
Pred = ICmpInst::ICMP_ULT;
RHS = getConstant(RA + 1);
Changed = true;
break;
case ICmpInst::ICMP_SGE:
assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
Pred = ICmpInst::ICMP_SGT;
RHS = getConstant(RA - 1);
Changed = true;
break;
case ICmpInst::ICMP_SLE:
assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
Pred = ICmpInst::ICMP_SLT;
RHS = getConstant(RA + 1);
Changed = true;
break;
}
}
}
if (HasSameValue(LHS, RHS)) {
if (ICmpInst::isTrueWhenEqual(Pred))
return TrivialCase(true);
if (ICmpInst::isFalseWhenEqual(Pred))
return TrivialCase(false);
}
switch (Pred) {
case ICmpInst::ICMP_SLE:
if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) {
RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
SCEV::FlagNSW);
Pred = ICmpInst::ICMP_SLT;
Changed = true;
} else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
SCEV::FlagNSW);
Pred = ICmpInst::ICMP_SLT;
Changed = true;
}
break;
case ICmpInst::ICMP_SGE:
if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) {
RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
SCEV::FlagNSW);
Pred = ICmpInst::ICMP_SGT;
Changed = true;
} else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
SCEV::FlagNSW);
Pred = ICmpInst::ICMP_SGT;
Changed = true;
}
break;
case ICmpInst::ICMP_ULE:
if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) {
RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
SCEV::FlagNUW);
Pred = ICmpInst::ICMP_ULT;
Changed = true;
} else if (!getUnsignedRangeMin(LHS).isMinValue()) {
LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
Pred = ICmpInst::ICMP_ULT;
Changed = true;
}
break;
case ICmpInst::ICMP_UGE:
if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) {
RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
Pred = ICmpInst::ICMP_UGT;
Changed = true;
} else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
SCEV::FlagNUW);
Pred = ICmpInst::ICMP_UGT;
Changed = true;
}
break;
default:
break;
}
if (Changed)
return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1,
ControllingFiniteLoop);
return Changed;
}
bool ScalarEvolution::isKnownNegative(const SCEV *S) {
return getSignedRangeMax(S).isNegative();
}
bool ScalarEvolution::isKnownPositive(const SCEV *S) {
return getSignedRangeMin(S).isStrictlyPositive();
}
bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
return !getSignedRangeMin(S).isNegative();
}
bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
return !getSignedRangeMax(S).isStrictlyPositive();
}
bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
return getUnsignedRangeMin(S) != 0;
}
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
if (Start == getCouldNotCompute())
return { Start, Start };
const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
return { Start, PostInc };
}
bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
SmallPtrSet<const Loop *, 8> LoopsUsed;
getUsedLoops(LHS, LoopsUsed);
getUsedLoops(RHS, LoopsUsed);
if (LoopsUsed.empty())
return false;
#ifndef NDEBUG
for (const auto *L1 : LoopsUsed)
for (const auto *L2 : LoopsUsed)
assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
DT.dominates(L2->getHeader(), L1->getHeader())) &&
"Domination relationship is not a linear order");
#endif
const Loop *MDL =
*std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
[&](const Loop *L1, const Loop *L2) {
return DT.properlyDominates(L1->getHeader(), L2->getHeader());
});
auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
if (SplitLHS.first == getCouldNotCompute())
return false;
assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
if (SplitRHS.first == getCouldNotCompute())
return false;
assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
!isAvailableAtLoopEntry(SplitRHS.first, MDL))
return false;
return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
SplitRHS.second) &&
isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
}
bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
(void)SimplifyICmpOperands(Pred, LHS, RHS);
if (isKnownViaInduction(Pred, LHS, RHS))
return true;
if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
return true;
return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
}
Optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
if (isKnownPredicate(Pred, LHS, RHS))
return true;
else if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
return false;
return None;
}
bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const Instruction *CtxI) {
return isKnownPredicate(Pred, LHS, RHS) ||
isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
}
Optional<bool> ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS,
const Instruction *CtxI) {
Optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
if (KnownWithoutContext)
return KnownWithoutContext;
if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
return true;
else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
ICmpInst::getInversePredicate(Pred),
LHS, RHS))
return false;
return None;
}
bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
const SCEVAddRecExpr *LHS,
const SCEV *RHS) {
const Loop *L = LHS->getLoop();
return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}
Optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
ICmpInst::Predicate Pred) {
auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
#ifndef NDEBUG
if (Result) {
auto ResultSwapped =
getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
assert(ResultSwapped && "should be able to analyze both!");
assert(ResultSwapped.value() != Result.value() &&
"monotonicity should flip as we flip the predicate");
}
#endif
return Result;
}
Optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
ICmpInst::Predicate Pred) {
if (!ICmpInst::isRelational(Pred))
return None;
bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
"Should be greater or less!");
if (ICmpInst::isUnsigned(Pred)) {
if (!LHS->hasNoUnsignedWrap())
return None;
return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
} else {
assert(ICmpInst::isSigned(Pred) &&
"Relational predicate is either signed or unsigned!");
if (!LHS->hasNoSignedWrap())
return None;
const SCEV *Step = LHS->getStepRecurrence(*this);
if (isKnownNonNegative(Step))
return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
if (isKnownNonPositive(Step))
return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
return None;
}
}
Optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const Loop *L) {
if (!isLoopInvariant(RHS, L)) {
if (!isLoopInvariant(LHS, L))
return None;
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
if (!ArLHS || ArLHS->getLoop() != L)
return None;
auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
if (!MonotonicType)
return None;
bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
return None;
return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS);
}
Optional<ScalarEvolution::LoopInvariantPredicate>
ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
const Instruction *CtxI, const SCEV *MaxIter) {
if (!isLoopInvariant(RHS, L)) {
if (!isLoopInvariant(LHS, L))
return None;
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
if (!AR || AR->getLoop() != L)
return None;
if (!ICmpInst::isRelational(Pred))
return None;
const SCEV *Step = AR->getStepRecurrence(*this);
auto *One = getOne(Step->getType());
auto *MinusOne = getNegativeSCEV(One);
if (Step != One && Step != MinusOne)
return None;
if (AR->getType() != MaxIter->getType())
return None;
const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
return None;
ICmpInst::Predicate NoOverflowPred =
CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
if (Step == MinusOne)
NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
const SCEV *Start = AR->getStart();
if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
return None;
return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
}
bool ScalarEvolution::isKnownPredicateViaConstantRanges(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
if (HasSameValue(LHS, RHS))
return ICmpInst::isTrueWhenEqual(Pred);
auto CheckRanges = [&](const ConstantRange &RangeLHS,
const ConstantRange &RangeRHS) {
return RangeLHS.icmp(Pred, RangeRHS);
};
if (Pred == CmpInst::ICMP_EQ)
return false;
if (Pred == CmpInst::ICMP_NE) {
auto SL = getSignedRange(LHS);
auto SR = getSignedRange(RHS);
if (CheckRanges(SL, SR))
return true;
auto UL = getUnsignedRange(LHS);
auto UR = getUnsignedRange(RHS);
if (CheckRanges(UL, UR))
return true;
auto *Diff = getMinusSCEV(LHS, RHS);
return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
}
if (CmpInst::isSigned(Pred)) {
auto SL = getSignedRange(LHS);
auto SR = getSignedRange(RHS);
return CheckRanges(SL, SR);
}
auto UL = getUnsignedRange(LHS);
auto UR = getUnsignedRange(RHS);
return CheckRanges(UL, UR);
}
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
APInt &OutC1, APInt &OutC2,
SCEV::NoWrapFlags ExpectedFlags) {
const SCEV *XNonConstOp, *XConstOp;
const SCEV *YNonConstOp, *YConstOp;
SCEV::NoWrapFlags XFlagsPresent;
SCEV::NoWrapFlags YFlagsPresent;
if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
XConstOp = getZero(X->getType());
XNonConstOp = X;
XFlagsPresent = ExpectedFlags;
}
if (!isa<SCEVConstant>(XConstOp) ||
(XFlagsPresent & ExpectedFlags) != ExpectedFlags)
return false;
if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
YConstOp = getZero(Y->getType());
YNonConstOp = Y;
YFlagsPresent = ExpectedFlags;
}
if (!isa<SCEVConstant>(YConstOp) ||
(YFlagsPresent & ExpectedFlags) != ExpectedFlags)
return false;
if (YNonConstOp != XNonConstOp)
return false;
OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
return true;
};
APInt C1;
APInt C2;
switch (Pred) {
default:
break;
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLE:
if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
return true;
break;
case ICmpInst::ICMP_SGT:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLT:
if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
return true;
break;
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULE:
if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
return true;
break;
case ICmpInst::ICMP_UGT:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULT:
if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
return true;
break;
}
return false;
}
bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
return false;
SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
return isKnownNonNegative(RHS) &&
isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
if (!HasGuards)
return false;
return any_of(*BB, [&](const Instruction &I) {
using namespace llvm::PatternMatch;
Value *Condition;
return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
m_Value(Condition))) &&
isImpliedCond(Pred, LHS, RHS, Condition, false);
});
}
bool
ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
if (!L || !DT.isReachableFromEntry(L->getHeader()))
return true;
if (VerifyIR)
assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
"This cannot be done on broken IR!");
if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
return false;
BranchInst *LoopContinuePredicate =
dyn_cast<BranchInst>(Latch->getTerminator());
if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
isImpliedCond(Pred, LHS, RHS,
LoopContinuePredicate->getCondition(),
LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
return true;
if (WalkingBEDominatingConds)
return false;
SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
const auto &BETakenInfo = getBackedgeTakenInfo(L);
const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
if (LatchBECount != getCouldNotCompute()) {
Type *Ty = LatchBECount->getType();
auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
const SCEV *LoopCounter =
getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
LatchBECount))
return true;
}
for (auto &AssumeVH : AC.assumptions()) {
if (!AssumeVH)
continue;
auto *CI = cast<CallInst>(AssumeVH);
if (!DT.dominates(CI, Latch->getTerminator()))
continue;
if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
return true;
}
if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
return true;
for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
DTN != HeaderDTN; DTN = DTN->getIDom()) {
assert(DTN && "should reach the loop header before reaching the root!");
BasicBlock *BB = DTN->getBlock();
if (isImpliedViaGuard(BB, Pred, LHS, RHS))
return true;
BasicBlock *PBB = BB->getSinglePredecessor();
if (!PBB)
continue;
BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
if (!ContinuePredicate || !ContinuePredicate->isConditional())
continue;
Value *Condition = ContinuePredicate->getCondition();
BasicBlockEdge DominatingEdge(PBB, BB);
if (DominatingEdge.isSingleEdge()) {
assert(DT.dominates(DominatingEdge, Latch) && "should be!");
if (isImpliedCond(Pred, LHS, RHS, Condition,
BB != ContinuePredicate->getSuccessor(0)))
return true;
}
}
return false;
}
bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
if (!DT.isReachableFromEntry(BB))
return true;
if (VerifyIR)
assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
"This cannot be done on broken IR!");
auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
bool ProvedNonStrictComparison = false;
bool ProvedNonEquality = false;
auto SplitAndProve =
[&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
if (!ProvedNonStrictComparison)
ProvedNonStrictComparison = Fn(NonStrictPredicate);
if (!ProvedNonEquality)
ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
if (ProvedNonStrictComparison && ProvedNonEquality)
return true;
return false;
};
if (ProvingStrictComparison) {
auto ProofFn = [&](ICmpInst::Predicate P) {
return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
};
if (SplitAndProve(ProofFn))
return true;
}
auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
const Instruction *CtxI = &BB->front();
if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
return true;
if (ProvingStrictComparison) {
auto ProofFn = [&](ICmpInst::Predicate P) {
return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
};
if (SplitAndProve(ProofFn))
return true;
}
return false;
};
const Loop *ContainingLoop = LI.getLoopFor(BB);
const BasicBlock *PredBB;
if (ContainingLoop && ContainingLoop->getHeader() == BB)
PredBB = ContainingLoop->getLoopPredecessor();
else
PredBB = BB->getSinglePredecessor();
for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
const BranchInst *BlockEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
continue;
if (ProveViaCond(BlockEntryPredicate->getCondition(),
BlockEntryPredicate->getSuccessor(0) != Pair.second))
return true;
}
for (auto &AssumeVH : AC.assumptions()) {
if (!AssumeVH)
continue;
auto *CI = cast<CallInst>(AssumeVH);
if (!DT.dominates(CI, BB))
continue;
if (ProveViaCond(CI->getArgOperand(0), false))
return true;
}
auto *GuardDecl = F.getParent()->getFunction(
Intrinsic::getName(Intrinsic::experimental_guard));
if (GuardDecl)
for (const auto *GU : GuardDecl->users())
if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
if (ProveViaCond(Guard->getArgOperand(0), false))
return true;
return false;
}
bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
if (!L)
return false;
assert(isAvailableAtLoopEntry(LHS, L) &&
"LHS is not available at Loop Entry");
assert(isAvailableAtLoopEntry(RHS, L) &&
"RHS is not available at Loop Entry");
if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS,
const Value *FoundCondValue, bool Inverse,
const Instruction *CtxI) {
if (FoundCondValue ==
ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
return true;
if (!PendingLoopPredicates.insert(FoundCondValue).second)
return false;
auto ClearOnExit =
make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
const Value *Op0, *Op1;
if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
if (!Inverse)
return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
} else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
if (Inverse)
return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
}
const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
if (!ICI) return false;
ICmpInst::Predicate FoundPred;
if (Inverse)
FoundPred = ICI->getInversePredicate();
else
FoundPred = ICI->getPredicate();
const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS,
ICmpInst::Predicate FoundPred,
const SCEV *FoundLHS, const SCEV *FoundRHS,
const Instruction *CtxI) {
if (getTypeSizeInBits(LHS->getType()) <
getTypeSizeInBits(FoundLHS->getType())) {
if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
!FoundRHS->getType()->isPointerTy()) {
auto *NarrowType = LHS->getType();
auto *WideType = FoundLHS->getType();
auto BitWidth = getTypeSizeInBits(NarrowType);
const SCEV *MaxValue = getZeroExtendExpr(
getConstant(APInt::getMaxValue(BitWidth)), WideType);
if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
MaxValue) &&
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
MaxValue)) {
const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
TruncFoundRHS, CtxI))
return true;
}
}
if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
return false;
if (CmpInst::isSigned(Pred)) {
LHS = getSignExtendExpr(LHS, FoundLHS->getType());
RHS = getSignExtendExpr(RHS, FoundLHS->getType());
} else {
LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
}
} else if (getTypeSizeInBits(LHS->getType()) >
getTypeSizeInBits(FoundLHS->getType())) {
if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
return false;
if (CmpInst::isSigned(FoundPred)) {
FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
} else {
FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
}
}
return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
FoundRHS, CtxI);
}
bool ScalarEvolution::isImpliedCondBalancedTypes(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
const Instruction *CtxI) {
assert(getTypeSizeInBits(LHS->getType()) ==
getTypeSizeInBits(FoundLHS->getType()) &&
"Types should be balanced!");
if (SimplifyICmpOperands(Pred, LHS, RHS))
if (LHS == RHS)
return CmpInst::isTrueWhenEqual(Pred);
if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
if (FoundLHS == FoundRHS)
return CmpInst::isFalseWhenEqual(FoundPred);
if (LHS == FoundRHS || RHS == FoundLHS) {
if (isa<SCEVConstant>(RHS)) {
std::swap(FoundLHS, FoundRHS);
FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
} else {
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
}
if (FoundPred == Pred)
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
CtxI);
if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
FoundLHS, FoundRHS, CtxI))
return true;
if (!FoundLHS->getType()->isPointerTy() &&
!FoundRHS->getType()->isPointerTy() &&
isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
getNotSCEV(FoundRHS), CtxI))
return true;
return false;
}
auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
CmpInst::Predicate P2) {
assert(P1 != P2 && "Handled earlier!");
return CmpInst::isRelational(P2) &&
P1 == CmpInst::getFlippedSignednessPredicate(P2);
};
if (IsSignFlippedPredicate(Pred, FoundPred)) {
if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
(isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
*CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
std::swap(CanonicalLHS, CanonicalRHS);
std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
}
assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
"Must be!");
assert((ICmpInst::isLT(CanonicalFoundPred) ||
ICmpInst::isLE(CanonicalFoundPred)) &&
"Must be!");
if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
CanonicalRHS, CanonicalFoundLHS,
CanonicalFoundRHS);
if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
CanonicalRHS, CanonicalFoundLHS,
CanonicalFoundRHS);
}
if (FoundPred == ICmpInst::ICMP_NE &&
(isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
const SCEVConstant *C = nullptr;
const SCEV *V = nullptr;
if (isa<SCEVConstant>(FoundLHS)) {
C = cast<SCEVConstant>(FoundLHS);
V = FoundRHS;
} else {
C = cast<SCEVConstant>(FoundRHS);
V = FoundLHS;
}
APInt Min = ICmpInst::isSigned(Pred) ?
getSignedRangeMin(V) : getUnsignedRangeMin(V);
if (Min == C->getAPInt()) {
APInt SharperMin = Min + 1;
switch (Pred) {
case ICmpInst::ICMP_SGE:
case ICmpInst::ICMP_UGE:
if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
CtxI))
return true;
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT:
if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
return true;
break;
case ICmpInst::ICMP_SLE:
case ICmpInst::ICMP_ULE:
if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
LHS, V, getConstant(SharperMin), CtxI))
return true;
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT:
if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
LHS, V, getConstant(Min), CtxI))
return true;
break;
default:
break;
}
}
}
if (FoundPred == ICmpInst::ICMP_EQ)
if (ICmpInst::isTrueWhenEqual(Pred))
if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
return true;
if (Pred == ICmpInst::ICMP_NE)
if (!ICmpInst::isTrueWhenEqual(FoundPred))
if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
return true;
return false;
}
bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
const SCEV *&L, const SCEV *&R,
SCEV::NoWrapFlags &Flags) {
const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
if (!AE || AE->getNumOperands() != 2)
return false;
L = AE->getOperand(0);
R = AE->getOperand(1);
Flags = AE->getNoWrapFlags();
return true;
}
Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
const SCEV *Less) {
if (More == Less)
return APInt(getTypeSizeInBits(More->getType()), 0);
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
const auto *MAR = cast<SCEVAddRecExpr>(More);
if (LAR->getLoop() != MAR->getLoop())
return None;
if (!LAR->isAffine() || !MAR->isAffine())
return None;
if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
return None;
Less = LAR->getStart();
More = MAR->getStart();
}
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
const auto &M = cast<SCEVConstant>(More)->getAPInt();
const auto &L = cast<SCEVConstant>(Less)->getAPInt();
return M - L;
}
SCEV::NoWrapFlags Flags;
const SCEV *LLess = nullptr, *RLess = nullptr;
const SCEV *LMore = nullptr, *RMore = nullptr;
const SCEVConstant *C1 = nullptr, *C2 = nullptr;
if (splitBinaryAdd(Less, LLess, RLess, Flags))
if ((C1 = dyn_cast<SCEVConstant>(LLess)))
if (RLess == More)
return -(C1->getAPInt());
if (splitBinaryAdd(More, LMore, RMore, Flags))
if ((C2 = dyn_cast<SCEVConstant>(LMore)))
if (RMore == Less)
return C2->getAPInt();
if (C1 && C2 && RLess == RMore)
return C2->getAPInt() - C1->getAPInt();
return None;
}
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
if (!CtxI)
return false;
const BasicBlock *ContextBB = CtxI->getParent();
if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
const Loop *L = AR->getLoop();
if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
return false;
if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
return false;
return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
}
if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
const Loop *L = AR->getLoop();
if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
return false;
if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
return false;
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
}
return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS, const SCEV *FoundRHS) {
if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
return false;
const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
if (!AddRecLHS)
return false;
const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
if (!AddRecFoundLHS)
return false;
const Loop *L = AddRecFoundLHS->getLoop();
if (L != AddRecLHS->getLoop())
return false;
Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
if (!LDiff || !RDiff || *LDiff != *RDiff)
return false;
if (LDiff->isMinValue())
return true;
APInt FoundRHSLimit;
if (Pred == CmpInst::ICMP_ULT) {
FoundRHSLimit = -(*RDiff);
} else {
assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
}
return isAvailableAtLoopEntry(FoundRHS, L) &&
isLoopEntryGuardedByCond(L, Pred, FoundRHS,
getConstant(FoundRHSLimit));
}
bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS, unsigned Depth) {
const PHINode *LPhi = nullptr, *RPhi = nullptr;
auto ClearOnExit = make_scope_exit([&]() {
if (LPhi) {
bool Erased = PendingMerges.erase(LPhi);
assert(Erased && "Failed to erase LPhi!");
(void)Erased;
}
if (RPhi) {
bool Erased = PendingMerges.erase(RPhi);
assert(Erased && "Failed to erase RPhi!");
(void)Erased;
}
});
if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
if (!PendingMerges.insert(Phi).second)
return false;
LPhi = Phi;
}
if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
if (!PendingMerges.insert(Phi).second)
return false;
RPhi = Phi;
}
if (!LPhi && !RPhi)
return false;
if (!LPhi) {
std::swap(LHS, RHS);
std::swap(FoundLHS, FoundRHS);
std::swap(LPhi, RPhi);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
const BasicBlock *LBB = LPhi->getParent();
const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
};
if (RPhi && RPhi->getParent() == LBB) {
for (const BasicBlock *IncBB : predecessors(LBB)) {
const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
if (!ProvedEasily(L, R))
return false;
}
} else if (RAR && RAR->getLoop()->getHeader() == LBB) {
if (LPhi->getNumIncomingValues() != 2) return false;
auto *RLoop = RAR->getLoop();
auto *Predecessor = RLoop->getLoopPredecessor();
assert(Predecessor && "Loop with AddRec with no predecessor?");
const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
if (!ProvedEasily(L1, RAR->getStart()))
return false;
auto *Latch = RLoop->getLoopLatch();
assert(Latch && "Loop with AddRec with no latch?");
const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
return false;
} else {
for (const BasicBlock *IncBB : predecessors(LBB)) {
if (!dominates(RHS, IncBB))
return false;
const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
if (!properlyDominates(L, LBB))
return false;
if (!ProvedEasily(L, RHS))
return false;
}
}
return true;
}
bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
if (RHS == FoundRHS) {
std::swap(LHS, RHS);
std::swap(FoundLHS, FoundRHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
}
if (LHS != FoundLHS)
return false;
auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
if (!SUFoundRHS)
return false;
Value *Shiftee, *ShiftValue;
using namespace PatternMatch;
if (match(SUFoundRHS->getValue(),
m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
auto *ShifteeS = getSCEV(Shiftee);
if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
if (isKnownNonNegative(ShifteeS))
return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
}
return false;
}
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS,
const Instruction *CtxI) {
if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
CtxI))
return true;
return isImpliedCondOperandsHelper(Pred, LHS, RHS,
FoundLHS, FoundRHS);
}
template <typename MinMaxExprType>
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
const SCEV *Candidate) {
const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
if (!MinMaxExpr)
return false;
return is_contained(MinMaxExpr->operands(), Candidate);
}
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
if (!ICmpInst::isRelational(Pred))
return false;
const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
if (!LAR)
return false;
const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
if (!RAR)
return false;
if (LAR->getLoop() != RAR->getLoop())
return false;
if (!LAR->isAffine() || !RAR->isAffine())
return false;
if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
return false;
SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
SCEV::FlagNSW : SCEV::FlagNUW;
if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
return false;
return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
}
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
switch (Pred) {
default:
return false;
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLE:
return
IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULE:
return
IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
}
llvm_unreachable("covered switch fell through?!");
}
bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS,
unsigned Depth) {
assert(getTypeSizeInBits(LHS->getType()) ==
getTypeSizeInBits(RHS->getType()) &&
"LHS and RHS have different sizes?");
assert(getTypeSizeInBits(FoundLHS->getType()) ==
getTypeSizeInBits(FoundRHS->getType()) &&
"FoundLHS and FoundRHS have different sizes?");
if (Depth > MaxSCEVOperationsImplicationDepth)
return false;
if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
Pred = CmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
std::swap(FoundLHS, FoundRHS);
}
if (Pred == ICmpInst::ICMP_UGT)
if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
const SCEV *MinusOne = getMinusOne(LHS->getType());
if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
FoundRHS) &&
isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
FoundRHS))
Pred = ICmpInst::ICMP_SGT;
}
if (Pred != ICmpInst::ICMP_SGT)
return false;
auto GetOpFromSExt = [&](const SCEV *S) {
if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
return Ext->getOperand();
return S;
};
auto *OrigLHS = LHS;
auto *OrigFoundLHS = FoundLHS;
LHS = GetOpFromSExt(LHS);
FoundLHS = GetOpFromSExt(FoundLHS);
auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
FoundRHS, Depth + 1);
};
if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
return false;
if (!LHSAddExpr->hasNoSignedWrap())
return false;
auto *LL = LHSAddExpr->getOperand(0);
auto *LR = LHSAddExpr->getOperand(1);
auto *MinusOne = getMinusOne(RHS->getType());
auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
};
if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
return true;
} else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
Value *LL, *LR;
using namespace llvm::PatternMatch;
if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
if (!isa<ConstantInt>(LR))
return false;
auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
auto *Numerator = getExistingSCEV(LL);
if (!Numerator || Numerator->getType() != FoundLHS->getType())
return false;
if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
return false;
auto *DTy = Denominator->getType();
auto *FRHSTy = FoundRHS->getType();
if (DTy->isPointerTy() != FRHSTy->isPointerTy())
return false;
auto *WTy = getWiderType(DTy, FRHSTy);
auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
if (isKnownNonPositive(RHS) &&
IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
return true;
auto *MinusOne = getMinusOne(WTy);
auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
if (isKnownNegative(RHS) &&
IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
return true;
}
}
if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
return true;
return false;
}
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
switch (Pred) {
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLE: {
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
return true;
break;
}
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULE: {
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
return true;
break;
}
default:
break;
};
return false;
}
bool
ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
}
bool
ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
switch (Pred) {
default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE:
if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
return true;
break;
}
if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
return false;
}
bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
return false;
Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
if (!Addend)
return false;
const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
ConstantRange FoundLHSRange =
ConstantRange::makeExactICmpRegion(Pred, ConstFoundRHS);
ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
return LHSRange.icmp(Pred, ConstRHS);
}
bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned) {
assert(isKnownPositive(Stride) && "Positive stride expected!");
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *One = getOne(Stride->getType());
if (IsSigned) {
APInt MaxRHS = getSignedRangeMax(RHS);
APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
}
APInt MaxRHS = getUnsignedRangeMax(RHS);
APInt MaxValue = APInt::getMaxValue(BitWidth);
APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
}
bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned) {
unsigned BitWidth = getTypeSizeInBits(RHS->getType());
const SCEV *One = getOne(Stride->getType());
if (IsSigned) {
APInt MinRHS = getSignedRangeMin(RHS);
APInt MinValue = APInt::getSignedMinValue(BitWidth);
APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
}
APInt MinRHS = getUnsignedRangeMin(RHS);
APInt MinValue = APInt::getMinValue(BitWidth);
APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
}
const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
const SCEV *Stride,
const SCEV *End,
unsigned BitWidth,
bool IsSigned) {
if (IsSigned && BitWidth == 1)
return getZero(Stride->getType());
assert((!IsSigned || !isKnownNonPositive(Stride)) &&
"Stride is expected strictly positive for signed case!");
APInt MinStart =
IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
APInt MinStride =
IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
APInt One(BitWidth, 1);
APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
: APIntOps::umax(One, MinStride);
APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
: APInt::getMaxValue(BitWidth);
APInt Limit = MaxValue - (StrideForMaxBECount - 1);
APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
: APIntOps::umin(getUnsignedRangeMax(End), Limit);
MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
: APIntOps::umax(MaxEnd, MinStart);
return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) ,
getConstant(StrideForMaxBECount) );
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
bool PredicatedIV = false;
auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
if (!isLoopInvariant(RHS, L))
return false;
auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
if (!StrideC || !StrideC->getAPInt().isPowerOf2())
return false;
if (!ControlsExit || !loopHasNoAbnormalExits(L))
return false;
return loopIsFiniteByAssumption(L);
};
if (!IV) {
if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
if (AR && AR->getLoop() == L && AR->isAffine()) {
auto canProveNUW = [&]() {
if (!isLoopInvariant(RHS, L))
return false;
if (!isKnownNonZero(AR->getStepRecurrence(*this)))
return false;
const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
Limit = Limit.zext(OuterBitWidth);
return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
};
auto Flags = AR->getNoWrapFlags();
if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
Flags = setFlags(Flags, SCEV::FlagNUW);
setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
if (AR->hasNoUnsignedWrap()) {
const SCEV *Step = AR->getStepRecurrence(*this);
Type *Ty = ZExt->getType();
auto *S = getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
IV = dyn_cast<SCEVAddRecExpr>(S);
}
}
}
}
if (!IV && AllowPredicates) {
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
PredicatedIV = true;
}
if (!IV || IV->getLoop() != L || !IV->isAffine())
return getCouldNotCompute();
auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
const SCEV *Stride = IV->getStepRecurrence(*this);
bool PositiveStride = isKnownPositive(Stride);
if (!PositiveStride) {
if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
!loopHasNoAbnormalExits(L))
return getCouldNotCompute();
if (IsSigned && isKnownNonPositive(Stride))
return getCouldNotCompute();
if (!isKnownNonZero(Stride)) {
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
auto wouldZeroStrideBeUB = [&]() {
auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
};
if (!wouldZeroStrideBeUB()) {
Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
}
}
} else if (!Stride->isOne() && !NoWrap) {
auto isUBOnWrap = [&]() {
return canAssumeNoSelfWrap(IV);
};
if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
return getCouldNotCompute();
}
const SCEV *Start = IV->getStart();
const SCEV *OrigStart = Start;
const SCEV *OrigRHS = RHS;
if (Start->getType()->isPointerTy()) {
Start = getLosslessPtrToIntExpr(Start);
if (isa<SCEVCouldNotCompute>(Start))
return Start;
}
if (RHS->getType()->isPointerTy()) {
RHS = getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(RHS))
return RHS;
}
if (!isLoopInvariant(RHS, L)) {
const SCEV *MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
return ExitLimit(getCouldNotCompute() , MaxBECount,
false , Predicates);
}
const SCEV *BECount = nullptr;
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
const SCEV *MinusOne = getMinusOne(Stride->getType());
const SCEV *Numerator =
getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
BECount = getUDivExpr(Numerator, Stride);
}
const SCEV *BECountIfBackedgeTaken = nullptr;
if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
return true;
auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
auto *StartMinusOne = getAddExpr(OrigStart,
getMinusOne(OrigStart->getType()));
return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
};
const SCEV *End;
if (canProveRHSGreaterThanEqualStart()) {
End = RHS;
} else {
End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
}
const SCEV *One = getOne(Stride->getType());
bool MayAddOverflow = [&] {
if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
if (StrideC->getAPInt().isPowerOf2()) {
return false;
}
}
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
return false;
}
return true;
}();
const SCEV *Delta = getMinusSCEV(End, Start);
if (!MayAddOverflow) {
BECount =
getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
} else {
BECount = getUDivCeilSCEV(Delta, Stride);
}
}
const SCEV *MaxBECount;
bool MaxOrZero = false;
if (isa<SCEVConstant>(BECount)) {
MaxBECount = BECount;
} else if (BECountIfBackedgeTaken &&
isa<SCEVConstant>(BECountIfBackedgeTaken)) {
MaxBECount = BECountIfBackedgeTaken;
MaxOrZero = true;
} else {
MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
}
if (isa<SCEVCouldNotCompute>(MaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
MaxBECount = getConstant(getUnsignedRangeMax(BECount));
return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
}
ScalarEvolution::ExitLimit
ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
if (!IV && AllowPredicates)
IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
if (!IV || IV->getLoop() != L || !IV->isAffine())
return getCouldNotCompute();
auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
if (!isKnownPositive(Stride))
return getCouldNotCompute();
if (!Stride->isOne() && !NoWrap)
if (canIVOverflowOnGT(RHS, Stride, IsSigned))
return getCouldNotCompute();
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
if (isLoopEntryGuardedByCond(
L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
End = RHS;
else
End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
}
if (Start->getType()->isPointerTy()) {
Start = getLosslessPtrToIntExpr(Start);
if (isa<SCEVCouldNotCompute>(Start))
return Start;
}
if (End->getType()->isPointerTy()) {
End = getLosslessPtrToIntExpr(End);
if (isa<SCEVCouldNotCompute>(End))
return End;
}
const SCEV *One = getOne(Stride->getType());
const SCEV *BECount = getUDivExpr(
getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
: getUnsignedRangeMax(Start);
APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
: getUnsignedRangeMin(Stride);
unsigned BitWidth = getTypeSizeInBits(LHS->getType());
APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
: APInt::getMinValue(BitWidth) + (MinStride - 1);
APInt MinEnd =
IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
: APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
? BECount
: getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
getConstant(MinStride));
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
return ExitLimit(BECount, MaxBECount, false, Predicates);
}
const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
ScalarEvolution &SE) const {
if (Range.isFullSet()) return SE.getCouldNotCompute();
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
if (!SC->getValue()->isZero()) {
SmallVector<const SCEV *, 4> Operands(operands());
Operands[0] = SE.getZero(SC->getType());
const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
getNoWrapFlags(FlagNW));
if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
return ShiftedAddRec->getNumIterationsInRange(
Range.subtract(SC->getAPInt()), SE);
return SE.getCouldNotCompute();
}
if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
return SE.getCouldNotCompute();
unsigned BitWidth = SE.getTypeSizeInBits(getType());
if (!Range.contains(APInt(BitWidth, 0)))
return SE.getZero(getType());
if (isAffine()) {
APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
APInt ExitVal = (End + A).udiv(A);
ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
if (Range.contains(Val->getValue()))
return SE.getCouldNotCompute();
assert(Range.contains(
EvaluateConstantChrecAtConstant(this,
ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
"Linear scev computation is off in a bad way!");
return SE.getConstant(ExitValue);
}
if (isQuadratic()) {
if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
return SE.getConstant(*S);
}
return SE.getCouldNotCompute();
}
const SCEVAddRecExpr *
SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
assert(getNumOperands() > 1 && "AddRec with zero step?");
SmallVector<const SCEV *, 3> Ops;
for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
const SCEV *Last = getOperand(getNumOperands() - 1);
assert(!Last->isZero() && "Recurrency with zero step?");
Ops.push_back(Last);
return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
SCEV::FlagAnyWrap));
}
bool ScalarEvolution::containsUndefs(const SCEV *S) const {
return SCEVExprContains(S, [](const SCEV *S) {
if (const auto *SU = dyn_cast<SCEVUnknown>(S))
return isa<UndefValue>(SU->getValue());
return false;
});
}
bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
return SCEVExprContains(S, [](const SCEV *S) {
if (const auto *SU = dyn_cast<SCEVUnknown>(S))
return SU->getValue() == nullptr;
return false;
});
}
const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
Type *Ty;
if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
Ty = Store->getValueOperand()->getType();
else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
Ty = Load->getType();
else
return nullptr;
Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
return getSizeOfExpr(ETy, Ty);
}
void ScalarEvolution::SCEVCallbackVH::deleted() {
assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
SE->ConstantEvolutionLoopExitValue.erase(PN);
SE->eraseValueFromMap(getValPtr());
}
void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
Value *Old = getValPtr();
SmallVector<User *, 16> Worklist(Old->users());
SmallPtrSet<User *, 8> Visited;
while (!Worklist.empty()) {
User *U = Worklist.pop_back_val();
if (U == Old)
continue;
if (!Visited.insert(U).second)
continue;
if (PHINode *PN = dyn_cast<PHINode>(U))
SE->ConstantEvolutionLoopExitValue.erase(PN);
SE->eraseValueFromMap(U);
llvm::append_range(Worklist, U->users());
}
if (PHINode *PN = dyn_cast<PHINode>(Old))
SE->ConstantEvolutionLoopExitValue.erase(PN);
SE->eraseValueFromMap(Old);
}
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
: CallbackVH(V), SE(se) {}
ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
AssumptionCache &AC, DominatorTree &DT,
LoopInfo &LI)
: F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
LoopDispositions(64), BlockDispositions(64) {
auto *GuardDecl = F.getParent()->getFunction(
Intrinsic::getName(Intrinsic::experimental_guard));
HasGuards = GuardDecl && !GuardDecl->use_empty();
}
ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
: F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
PendingMerges(std::move(Arg.PendingMerges)),
MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
std::move(Arg.PredicatedBackedgeTakenCounts)),
BECountUsers(std::move(Arg.BECountUsers)),
ConstantEvolutionLoopExitValue(
std::move(Arg.ConstantEvolutionLoopExitValue)),
ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
LoopDispositions(std::move(Arg.LoopDispositions)),
LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
BlockDispositions(std::move(Arg.BlockDispositions)),
SCEVUsers(std::move(Arg.SCEVUsers)),
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
LoopUsers(std::move(Arg.LoopUsers)),
PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
}
ScalarEvolution::~ScalarEvolution() {
for (SCEVUnknown *U = FirstUnknown; U;) {
SCEVUnknown *Tmp = U;
U = U->Next;
Tmp->~SCEVUnknown();
}
FirstUnknown = nullptr;
ExprValueMap.clear();
ValueExprMap.clear();
HasRecMap.clear();
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
assert(PendingPhiRanges.empty() && "getRangeRef garbage");
assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
}
bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
const Loop *L) {
for (Loop *I : *L)
PrintLoopInfo(OS, SE, I);
OS << "Loop ";
L->getHeader()->printAsOperand(OS, false);
OS << ": ";
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
if (ExitingBlocks.size() != 1)
OS << "<multiple exits> ";
if (SE->hasLoopInvariantBackedgeTakenCount(L))
OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
else
OS << "Unpredictable backedge-taken count.\n";
if (ExitingBlocks.size() > 1)
for (BasicBlock *ExitingBlock : ExitingBlocks) {
OS << " exit count for " << ExitingBlock->getName() << ": "
<< *SE->getExitCount(L, ExitingBlock) << "\n";
}
OS << "Loop ";
L->getHeader()->printAsOperand(OS, false);
OS << ": ";
if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
if (SE->isBackedgeTakenCountMaxOrZero(L))
OS << ", actual taken count either this or zero.";
} else {
OS << "Unpredictable max backedge-taken count. ";
}
OS << "\n"
"Loop ";
L->getHeader()->printAsOperand(OS, false);
OS << ": ";
SmallVector<const SCEVPredicate *, 4> Preds;
auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
if (!isa<SCEVCouldNotCompute>(PBT)) {
OS << "Predicated backedge-taken count is " << *PBT << "\n";
OS << " Predicates:\n";
for (const auto *P : Preds)
P->print(OS, 4);
} else {
OS << "Unpredictable predicated backedge-taken count. ";
}
OS << "\n";
if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
OS << "Loop ";
L->getHeader()->printAsOperand(OS, false);
OS << ": ";
OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
}
}
static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
switch (LD) {
case ScalarEvolution::LoopVariant:
return "Variant";
case ScalarEvolution::LoopInvariant:
return "Invariant";
case ScalarEvolution::LoopComputable:
return "Computable";
}
llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
}
void ScalarEvolution::print(raw_ostream &OS) const {
ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
if (ClassifyExpressions) {
OS << "Classifying expressions for: ";
F.printAsOperand(OS, false);
OS << "\n";
for (Instruction &I : instructions(F))
if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
OS << I << '\n';
OS << " --> ";
const SCEV *SV = SE.getSCEV(&I);
SV->print(OS);
if (!isa<SCEVCouldNotCompute>(SV)) {
OS << " U: ";
SE.getUnsignedRange(SV).print(OS);
OS << " S: ";
SE.getSignedRange(SV).print(OS);
}
const Loop *L = LI.getLoopFor(I.getParent());
const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
if (AtUse != SV) {
OS << " --> ";
AtUse->print(OS);
if (!isa<SCEVCouldNotCompute>(AtUse)) {
OS << " U: ";
SE.getUnsignedRange(AtUse).print(OS);
OS << " S: ";
SE.getSignedRange(AtUse).print(OS);
}
}
if (L) {
OS << "\t\t" "Exits: ";
const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
if (!SE.isLoopInvariant(ExitValue, L)) {
OS << "<<Unknown>>";
} else {
OS << *ExitValue;
}
bool First = true;
for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
if (First) {
OS << "\t\t" "LoopDispositions: { ";
First = false;
} else {
OS << ", ";
}
Iter->getHeader()->printAsOperand(OS, false);
OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
}
for (const auto *InnerL : depth_first(L)) {
if (InnerL == L)
continue;
if (First) {
OS << "\t\t" "LoopDispositions: { ";
First = false;
} else {
OS << ", ";
}
InnerL->getHeader()->printAsOperand(OS, false);
OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
}
OS << " }";
}
OS << "\n";
}
}
OS << "Determining loop execution counts for: ";
F.printAsOperand(OS, false);
OS << "\n";
for (Loop *I : LI)
PrintLoopInfo(OS, &SE, I);
}
ScalarEvolution::LoopDisposition
ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
auto &Values = LoopDispositions[S];
for (auto &V : Values) {
if (V.getPointer() == L)
return V.getInt();
}
Values.emplace_back(L, LoopVariant);
LoopDisposition D = computeLoopDisposition(S, L);
auto &Values2 = LoopDispositions[S];
for (auto &V : llvm::reverse(Values2)) {
if (V.getPointer() == L) {
V.setInt(D);
break;
}
}
return D;
}
ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
switch (S->getSCEVType()) {
case scConstant:
return LoopInvariant;
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
case scAddRecExpr: {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
if (AR->getLoop() == L)
return LoopComputable;
if (!L)
return LoopVariant;
if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
return LoopVariant;
assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
" dominate the contained loop's header?");
if (AR->getLoop()->contains(L))
return LoopInvariant;
for (const auto *Op : AR->operands())
if (!isLoopInvariant(Op, L))
return LoopVariant;
return LoopInvariant;
}
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
case scSequentialUMinExpr: {
bool HasVarying = false;
for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
LoopDisposition D = getLoopDisposition(Op, L);
if (D == LoopVariant)
return LoopVariant;
if (D == LoopComputable)
HasVarying = true;
}
return HasVarying ? LoopComputable : LoopInvariant;
}
case scUDivExpr: {
const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
if (LD == LoopVariant)
return LoopVariant;
LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
if (RD == LoopVariant)
return LoopVariant;
return (LD == LoopInvariant && RD == LoopInvariant) ?
LoopInvariant : LoopComputable;
}
case scUnknown:
if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
return LoopInvariant;
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}
bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
return getLoopDisposition(S, L) == LoopInvariant;
}
bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
return getLoopDisposition(S, L) == LoopComputable;
}
ScalarEvolution::BlockDisposition
ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
auto &Values = BlockDispositions[S];
for (auto &V : Values) {
if (V.getPointer() == BB)
return V.getInt();
}
Values.emplace_back(BB, DoesNotDominateBlock);
BlockDisposition D = computeBlockDisposition(S, BB);
auto &Values2 = BlockDispositions[S];
for (auto &V : llvm::reverse(Values2)) {
if (V.getPointer() == BB) {
V.setInt(D);
break;
}
}
return D;
}
ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
switch (S->getSCEVType()) {
case scConstant:
return ProperlyDominatesBlock;
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
case scAddRecExpr: {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
if (!DT.dominates(AR->getLoop()->getHeader(), BB))
return DoesNotDominateBlock;
LLVM_FALLTHROUGH;
}
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
case scSequentialUMinExpr: {
const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool Proper = true;
for (const SCEV *NAryOp : NAry->operands()) {
BlockDisposition D = getBlockDisposition(NAryOp, BB);
if (D == DoesNotDominateBlock)
return DoesNotDominateBlock;
if (D == DominatesBlock)
Proper = false;
}
return Proper ? ProperlyDominatesBlock : DominatesBlock;
}
case scUDivExpr: {
const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
BlockDisposition LD = getBlockDisposition(LHS, BB);
if (LD == DoesNotDominateBlock)
return DoesNotDominateBlock;
BlockDisposition RD = getBlockDisposition(RHS, BB);
if (RD == DoesNotDominateBlock)
return DoesNotDominateBlock;
return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
ProperlyDominatesBlock : DominatesBlock;
}
case scUnknown:
if (Instruction *I =
dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
if (I->getParent() == BB)
return DominatesBlock;
if (DT.properlyDominates(I->getParent(), BB))
return ProperlyDominatesBlock;
return DoesNotDominateBlock;
}
return ProperlyDominatesBlock;
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}
bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
return getBlockDisposition(S, BB) >= DominatesBlock;
}
bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
}
bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}
void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
bool Predicated) {
auto &BECounts =
Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
auto It = BECounts.find(L);
if (It != BECounts.end()) {
for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
assert(UserIt != BECountUsers.end());
UserIt->second.erase({L, Predicated});
}
}
BECounts.erase(It);
}
}
void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
while (!Worklist.empty()) {
const SCEV *Curr = Worklist.pop_back_val();
auto Users = SCEVUsers.find(Curr);
if (Users != SCEVUsers.end())
for (const auto *User : Users->second)
if (ToForget.insert(User).second)
Worklist.push_back(User);
}
for (const auto *S : ToForget)
forgetMemoizedResultsImpl(S);
for (auto I = PredicatedSCEVRewrites.begin();
I != PredicatedSCEVRewrites.end();) {
std::pair<const SCEV *, const Loop *> Entry = I->first;
if (ToForget.count(Entry.first))
PredicatedSCEVRewrites.erase(I++);
else
++I;
}
}
void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
LoopDispositions.erase(S);
BlockDispositions.erase(S);
UnsignedRanges.erase(S);
SignedRanges.erase(S);
HasRecMap.erase(S);
MinTrailingZerosCache.erase(S);
auto ExprIt = ExprValueMap.find(S);
if (ExprIt != ExprValueMap.end()) {
for (Value *V : ExprIt->second) {
auto ValueIt = ValueExprMap.find_as(V);
if (ValueIt != ValueExprMap.end())
ValueExprMap.erase(ValueIt);
}
ExprValueMap.erase(ExprIt);
}
auto ScopeIt = ValuesAtScopes.find(S);
if (ScopeIt != ValuesAtScopes.end()) {
for (const auto &Pair : ScopeIt->second)
if (!isa_and_nonnull<SCEVConstant>(Pair.second))
erase_value(ValuesAtScopesUsers[Pair.second],
std::make_pair(Pair.first, S));
ValuesAtScopes.erase(ScopeIt);
}
auto ScopeUserIt = ValuesAtScopesUsers.find(S);
if (ScopeUserIt != ValuesAtScopesUsers.end()) {
for (const auto &Pair : ScopeUserIt->second)
erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
ValuesAtScopesUsers.erase(ScopeUserIt);
}
auto BEUsersIt = BECountUsers.find(S);
if (BEUsersIt != BECountUsers.end()) {
auto Copy = BEUsersIt->second;
for (const auto &Pair : Copy)
forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
BECountUsers.erase(BEUsersIt);
}
}
void
ScalarEvolution::getUsedLoops(const SCEV *S,
SmallPtrSetImpl<const Loop *> &LoopsUsed) {
struct FindUsedLoops {
FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
: LoopsUsed(LoopsUsed) {}
SmallPtrSetImpl<const Loop *> &LoopsUsed;
bool follow(const SCEV *S) {
if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
LoopsUsed.insert(AR->getLoop());
return true;
}
bool isDone() const { return false; }
};
FindUsedLoops F(LoopsUsed);
SCEVTraversal<FindUsedLoops>(F).visitAll(S);
}
void ScalarEvolution::getReachableBlocks(
SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
SmallVector<BasicBlock *> Worklist;
Worklist.push_back(&F.getEntryBlock());
while (!Worklist.empty()) {
BasicBlock *BB = Worklist.pop_back_val();
if (!Reachable.insert(BB).second)
continue;
Value *Cond;
BasicBlock *TrueBB, *FalseBB;
if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
m_BasicBlock(FalseBB)))) {
if (auto *C = dyn_cast<ConstantInt>(Cond)) {
Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
continue;
}
if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
const SCEV *L = getSCEV(Cmp->getOperand(0));
const SCEV *R = getSCEV(Cmp->getOperand(1));
if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
Worklist.push_back(TrueBB);
continue;
}
if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
R)) {
Worklist.push_back(FalseBB);
continue;
}
}
}
append_range(Worklist, successors(BB));
}
}
void ScalarEvolution::verify() const {
ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
ScalarEvolution SE2(F, TLI, AC, DT, LI);
SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
const SCEV *visitConstant(const SCEVConstant *Constant) {
return SE.getConstant(Constant->getAPInt());
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
return SE.getUnknown(Expr->getValue());
}
const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return SE.getCouldNotCompute();
}
};
SCEVMapper SCM(SE2);
SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
SE2.getReachableBlocks(ReachableBlocks, F);
auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
if (containsUndefs(Old) || containsUndefs(New)) {
return nullptr;
}
const SCEV *Delta = SE2.getMinusSCEV(Old, New);
if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
return nullptr;
return Delta;
};
while (!LoopStack.empty()) {
auto *L = LoopStack.pop_back_val();
llvm::append_range(LoopStack, *L);
if (!ReachableBlocks.contains(L->getHeader()))
continue;
auto It = BackedgeTakenCounts.find(L);
if (It == BackedgeTakenCounts.end())
continue;
auto *CurBECount =
SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
auto *NewBECount = SE2.getBackedgeTakenCount(L);
if (CurBECount == SE2.getCouldNotCompute() ||
NewBECount == SE2.getCouldNotCompute()) {
continue;
}
if (SE.getTypeSizeInBits(CurBECount->getType()) >
SE.getTypeSizeInBits(NewBECount->getType()))
NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
else if (SE.getTypeSizeInBits(CurBECount->getType()) <
SE.getTypeSizeInBits(NewBECount->getType()))
CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
const SCEV *Delta = GetDelta(CurBECount, NewBECount);
if (Delta && !Delta->isZero()) {
dbgs() << "Trip Count for " << *L << " Changed!\n";
dbgs() << "Old: " << *CurBECount << "\n";
dbgs() << "New: " << *NewBECount << "\n";
dbgs() << "Delta: " << *Delta << "\n";
std::abort();
}
}
SmallPtrSet<Loop *, 32> ValidLoops;
SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
while (!Worklist.empty()) {
Loop *L = Worklist.pop_back_val();
if (ValidLoops.insert(L).second)
Worklist.append(L->begin(), L->end());
}
for (const auto &KV : ValueExprMap) {
#ifndef NDEBUG
if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
assert(ValidLoops.contains(AR->getLoop()) &&
"AddRec references invalid loop");
}
#endif
auto It = ExprValueMap.find(KV.second);
if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
dbgs() << "Value " << *KV.first
<< " is in ValueExprMap but not in ExprValueMap\n";
std::abort();
}
if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
if (!ReachableBlocks.contains(I->getParent()))
continue;
const SCEV *OldSCEV = SCM.visit(KV.second);
const SCEV *NewSCEV = SE2.getSCEV(I);
const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
if (Delta && !Delta->isZero()) {
dbgs() << "SCEV for value " << *I << " changed!\n"
<< "Old: " << *OldSCEV << "\n"
<< "New: " << *NewSCEV << "\n"
<< "Delta: " << *Delta << "\n";
std::abort();
}
}
}
for (const auto &KV : ExprValueMap) {
for (Value *V : KV.second) {
auto It = ValueExprMap.find_as(V);
if (It == ValueExprMap.end()) {
dbgs() << "Value " << *V
<< " is in ExprValueMap but not in ValueExprMap\n";
std::abort();
}
if (It->second != KV.first) {
dbgs() << "Value " << *V << " mapped to " << *It->second
<< " rather than " << *KV.first << "\n";
std::abort();
}
}
}
for (const auto &S : UniqueSCEVs) {
SmallVector<const SCEV *, 4> Ops;
collectUniqueOps(&S, Ops);
for (const auto *Op : Ops) {
if (isa<SCEVConstant>(Op))
continue;
auto It = SCEVUsers.find(Op);
if (It != SCEVUsers.end() && It->second.count(&S))
continue;
dbgs() << "Use of operand " << *Op << " by user " << S
<< " is not being tracked!\n";
std::abort();
}
}
for (const auto &ValueAndVec : ValuesAtScopes) {
const SCEV *Value = ValueAndVec.first;
for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
const Loop *L = LoopAndValueAtScope.first;
const SCEV *ValueAtScope = LoopAndValueAtScope.second;
if (!isa<SCEVConstant>(ValueAtScope)) {
auto It = ValuesAtScopesUsers.find(ValueAtScope);
if (It != ValuesAtScopesUsers.end() &&
is_contained(It->second, std::make_pair(L, Value)))
continue;
dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
<< *ValueAtScope << " missing in ValuesAtScopesUsers\n";
std::abort();
}
}
}
for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
const Loop *L = LoopAndValue.first;
const SCEV *Value = LoopAndValue.second;
assert(!isa<SCEVConstant>(Value));
auto It = ValuesAtScopes.find(Value);
if (It != ValuesAtScopes.end() &&
is_contained(It->second, std::make_pair(L, ValueAtScope)))
continue;
dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
<< *ValueAtScope << " missing in ValuesAtScopes\n";
std::abort();
}
}
auto VerifyBECountUsers = [&](bool Predicated) {
auto &BECounts =
Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
for (const auto &LoopAndBEInfo : BECounts) {
for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
if (UserIt != BECountUsers.end() &&
UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
continue;
dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
<< *LoopAndBEInfo.first << " missing from BECountUsers\n";
std::abort();
}
}
}
};
VerifyBECountUsers( false);
VerifyBECountUsers( true);
}
bool ScalarEvolution::invalidate(
Function &F, const PreservedAnalyses &PA,
FunctionAnalysisManager::Invalidator &Inv) {
auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
Inv.invalidate<AssumptionAnalysis>(F, PA) ||
Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
Inv.invalidate<LoopAnalysis>(F, PA);
}
AnalysisKey ScalarEvolutionAnalysis::Key;
ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
FunctionAnalysisManager &AM) {
return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
AM.getResult<AssumptionAnalysis>(F),
AM.getResult<DominatorTreeAnalysis>(F),
AM.getResult<LoopAnalysis>(F));
}
PreservedAnalyses
ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
AM.getResult<ScalarEvolutionAnalysis>(F).verify();
return PreservedAnalyses::all();
}
PreservedAnalyses
ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
<< F.getName() << "':\n";
AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
return PreservedAnalyses::all();
}
INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
"Scalar Evolution Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
"Scalar Evolution Analysis", false, true)
char ScalarEvolutionWrapperPass::ID = 0;
ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
}
bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
SE.reset(new ScalarEvolution(
F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
return false;
}
void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
SE->print(OS);
}
void ScalarEvolutionWrapperPass::verifyAnalysis() const {
if (!VerifySCEV)
return;
SE->verify();
}
void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequiredTransitive<AssumptionCacheTracker>();
AU.addRequiredTransitive<LoopInfoWrapperPass>();
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
const SCEV *RHS) {
return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
}
const SCEVPredicate *
ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
FoldingSetNodeID ID;
assert(LHS->getType() == RHS->getType() &&
"Type mismatch between LHS and RHS");
ID.AddInteger(SCEVPredicate::P_Compare);
ID.AddInteger(Pred);
ID.AddPointer(LHS);
ID.AddPointer(RHS);
void *IP = nullptr;
if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
return S;
SCEVComparePredicate *Eq = new (SCEVAllocator)
SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
UniquePreds.InsertNode(Eq, IP);
return Eq;
}
const SCEVPredicate *ScalarEvolution::getWrapPredicate(
const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
FoldingSetNodeID ID;
ID.AddInteger(SCEVPredicate::P_Wrap);
ID.AddPointer(AR);
ID.AddInteger(AddedFlags);
void *IP = nullptr;
if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
return S;
auto *OF = new (SCEVAllocator)
SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
UniquePreds.InsertNode(OF, IP);
return OF;
}
namespace {
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
const SCEVPredicate *Pred) {
SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
return Rewriter.visit(S);
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (Pred) {
if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
for (const auto *Pred : U->getPredicates())
if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
if (IPred->getLHS() == Expr &&
IPred->getPredicate() == ICmpInst::ICMP_EQ)
return IPred->getRHS();
} else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
if (IPred->getLHS() == Expr &&
IPred->getPredicate() == ICmpInst::ICMP_EQ)
return IPred->getRHS();
}
}
return convertToAddRecWithPreds(Expr);
}
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
if (AR && AR->getLoop() == L && AR->isAffine()) {
const SCEV *Step = AR->getStepRecurrence(SE);
Type *Ty = Expr->getType();
if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
SE.getSignExtendExpr(Step, Ty), L,
AR->getNoWrapFlags());
}
return SE.getZeroExtendExpr(Operand, Expr->getType());
}
const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
if (AR && AR->getLoop() == L && AR->isAffine()) {
const SCEV *Step = AR->getStepRecurrence(SE);
Type *Ty = Expr->getType();
if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
SE.getSignExtendExpr(Step, Ty), L,
AR->getNoWrapFlags());
}
return SE.getSignExtendExpr(Operand, Expr->getType());
}
private:
explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
const SCEVPredicate *Pred)
: SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
return Pred && Pred->implies(P);
}
NewPreds->insert(P);
return true;
}
bool addOverflowAssumption(const SCEVAddRecExpr *AR,
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
auto *A = SE.getWrapPredicate(AR, AddedFlags);
return addOverflowAssumption(A);
}
const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
if (!isa<PHINode>(Expr->getValue()))
return Expr;
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
if (!PredicatedRewrite)
return Expr;
for (const auto *P : PredicatedRewrite->second){
if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
if (L != WP->getExpr()->getLoop())
return Expr;
}
if (!addOverflowAssumption(P))
return Expr;
}
return PredicatedRewrite->first;
}
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
const SCEVPredicate *Pred;
const Loop *L;
};
}
const SCEV *
ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
const SCEVPredicate &Preds) {
return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
}
const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
if (!AddRec)
return nullptr;
for (const auto *P : TransformPreds)
Preds.insert(P);
return AddRec;
}
SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
SCEVPredicateKind Kind)
: FastID(ID), Kind(Kind) {}
SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
const ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS)
: SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
return false;
if (Pred != ICmpInst::ICMP_EQ)
return false;
return Op->LHS == LHS && Op->RHS == RHS;
}
bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
if (Pred == ICmpInst::ICMP_EQ)
OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
else
OS.indent(Depth) << "Compare predicate: " << *LHS
<< " " << CmpInst::getPredicateName(Pred) << ") "
<< *RHS << "\n";
}
SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *AR,
IncrementWrapFlags Flags)
: SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
IncrementWrapFlags IFlags = Flags;
if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
IFlags = clearFlags(IFlags, IncrementNSSW);
return IFlags == IncrementAnyWrap;
}
void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
OS.indent(Depth) << *getExpr() << " Added Flags: ";
if (SCEVWrapPredicate::IncrementNUSW & getFlags())
OS << "<nusw>";
if (SCEVWrapPredicate::IncrementNSSW & getFlags())
OS << "<nssw>";
OS << "\n";
}
SCEVWrapPredicate::IncrementWrapFlags
SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
ScalarEvolution &SE) {
IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
ImpliedFlags = IncrementNSSW;
if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
if (Step->getValue()->getValue().isNonNegative())
ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
}
return ImpliedFlags;
}
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
add(P);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
return all_of(Preds,
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
return all_of(Set->Preds,
[this](const SCEVPredicate *I) { return this->implies(I); });
return any_of(Preds,
[N](const SCEVPredicate *I) { return I->implies(N); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
for (const auto *Pred : Preds)
Pred->print(OS, Depth);
}
void SCEVUnionPredicate::add(const SCEVPredicate *N) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
add(Pred);
return;
}
Preds.push_back(N);
}
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
}
void ScalarEvolution::registerUser(const SCEV *User,
ArrayRef<const SCEV *> Ops) {
for (const auto *Op : Ops)
if (!isa<SCEVConstant>(Op))
SCEVUsers[Op].insert(User);
}
const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
const SCEV *Expr = SE.getSCEV(V);
RewriteEntry &Entry = RewriteMap[Expr];
if (Entry.second && Generation == Entry.first)
return Entry.second;
if (Entry.second)
Expr = Entry.second;
const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
Entry = {Generation, NewSCEV};
return NewSCEV;
}
const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
if (!BackedgeCount) {
SmallVector<const SCEVPredicate *, 4> Preds;
BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
for (const auto *P : Preds)
addPredicate(*P);
}
return BackedgeCount;
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds->implies(&Pred))
return;
auto &OldPreds = Preds->getPredicates();
SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
NewPreds.push_back(&Pred);
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
updateGeneration();
}
const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
return *Preds;
}
void PredicatedScalarEvolution::updateGeneration() {
if (++Generation == 0) {
for (auto &II : RewriteMap) {
const SCEV *Rewritten = II.second.second;
II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
}
}
}
void PredicatedScalarEvolution::setNoOverflow(
Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
const SCEV *Expr = getSCEV(V);
const auto *AR = cast<SCEVAddRecExpr>(Expr);
auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
addPredicate(*SE.getWrapPredicate(AR, Flags));
auto II = FlagsMap.insert({V, Flags});
if (!II.second)
II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
}
bool PredicatedScalarEvolution::hasNoOverflow(
Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
const SCEV *Expr = getSCEV(V);
const auto *AR = cast<SCEVAddRecExpr>(Expr);
Flags = SCEVWrapPredicate::clearFlags(
Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
auto II = FlagsMap.find(V);
if (II != FlagsMap.end())
Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
return Flags == SCEVWrapPredicate::IncrementAnyWrap;
}
const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
const SCEV *Expr = this->getSCEV(V);
SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
if (!New)
return nullptr;
for (const auto *P : NewPreds)
addPredicate(*P);
RewriteMap[SE.getSCEV(V)] = {Generation, New};
return New;
}
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}
void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
for (auto *BB : L.getBlocks())
for (auto &I : *BB) {
if (!SE.isSCEVable(I.getType()))
continue;
auto *Expr = SE.getSCEV(&I);
auto II = RewriteMap.find(Expr);
if (II == RewriteMap.end())
continue;
if (II->second.second == Expr)
continue;
OS.indent(Depth) << "[PSE]" << I << ":\n";
OS.indent(Depth + 2) << *Expr << "\n";
OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
}
}
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
const SCEV *&RHS) {
if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
LHS = Trunc->getOperand();
if (getTypeSizeInBits(LHS->getType()) >
getTypeSizeInBits(Expr->getType()))
return false;
if (LHS->getType() != Expr->getType())
LHS = getZeroExtendExpr(LHS, Expr->getType());
RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
<< getTypeSizeInBits(Trunc->getType()));
return true;
}
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
if (Add == nullptr || Add->getNumOperands() != 2)
return false;
const SCEV *A = Add->getOperand(1);
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
if (Mul == nullptr)
return false;
const auto MatchURemWithDivisor = [&](const SCEV *B) {
if (Expr == getURemExpr(A, B)) {
LHS = A;
RHS = B;
return true;
}
return false;
};
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(2));
if (Mul->getNumOperands() == 2)
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(0)) ||
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
return false;
}
const SCEV *
ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
SmallVector<BasicBlock*, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
SmallVector<const SCEV*, 4> ExitCounts;
for (BasicBlock *ExitingBB : ExitingBlocks) {
const SCEV *ExitCount = getExitCount(L, ExitingBB);
if (isa<SCEVCouldNotCompute>(ExitCount))
ExitCount = getExitCount(L, ExitingBB,
ScalarEvolution::ConstantMaximum);
if (!isa<SCEVCouldNotCompute>(ExitCount)) {
assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
"We should only have known counts for exiting blocks that "
"dominate latch!");
ExitCounts.push_back(ExitCount);
}
}
if (ExitCounts.empty())
return getCouldNotCompute();
return getUMinFromMismatchedTypes(ExitCounts);
}
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> ⤅
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
DenseMap<const SCEV *, const SCEV *> &M)
: SCEVRewriteVisitor(SE), Map(M) {}
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
auto I = Map.find(Expr);
if (I == Map.end())
return Expr;
return I->second;
}
const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
auto I = Map.find(Expr);
if (I == Map.end())
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
Expr);
return I->second;
}
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *>
&RewriteMap) {
if (isa<SCEVConstant>(LHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
&ExprsToRewrite]() {
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
if (!AddExpr || AddExpr->getNumOperands() != 2)
return false;
auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
auto *C2 = dyn_cast<SCEVConstant>(RHS);
if (!C1 || !C2 || !LHSUnknown)
return false;
auto ExactRegion =
ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
.sub(C1->getAPInt());
if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
return false;
auto I = RewriteMap.find(LHSUnknown);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
RewriteMap[LHSUnknown] = getUMaxExpr(
getConstant(ExactRegion.getUnsignedMin()),
getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
ExprsToRewrite.push_back(LHSUnknown);
return true;
};
if (MatchRangeCheckIdiom())
return;
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
if (Predicate == CmpInst::ICMP_EQ && RHSC &&
RHSC->getValue()->isNullValue()) {
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (matchURem(LHS, URemLHS, URemRHS)) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
ExprsToRewrite.push_back(LHSUnknown);
return;
}
}
}
if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
return;
if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
return;
auto I = RewriteMap.find(LHS);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
const SCEV *RewrittenRHS = nullptr;
switch (Predicate) {
case CmpInst::ICMP_ULT:
RewrittenRHS =
getUMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
break;
case CmpInst::ICMP_SLT:
RewrittenRHS =
getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
break;
case CmpInst::ICMP_ULE:
RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
break;
case CmpInst::ICMP_SLE:
RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
break;
case CmpInst::ICMP_UGT:
RewrittenRHS =
getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
break;
case CmpInst::ICMP_SGT:
RewrittenRHS =
getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
break;
case CmpInst::ICMP_UGE:
RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
break;
case CmpInst::ICMP_SGE:
RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
break;
case CmpInst::ICMP_EQ:
if (isa<SCEVConstant>(RHS))
RewrittenRHS = RHS;
break;
case CmpInst::ICMP_NE:
if (isa<SCEVConstant>(RHS) &&
cast<SCEVConstant>(RHS)->getValue()->isNullValue())
RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
break;
default:
break;
}
if (RewrittenRHS) {
RewriteMap[LHS] = RewrittenRHS;
if (LHS == RewrittenLHS)
ExprsToRewrite.push_back(LHS);
}
};
SmallVector<std::pair<Value *, bool>> Terms;
for (auto &AssumeVH : AC.assumptions()) {
if (!AssumeVH)
continue;
auto *AssumeI = cast<CallInst>(AssumeVH);
if (!DT.dominates(AssumeI, L->getHeader()))
continue;
Terms.emplace_back(AssumeI->getOperand(0), true);
}
for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
L->getLoopPredecessor(), L->getHeader());
Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
const BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
continue;
Terms.emplace_back(LoopEntryPredicate->getCondition(),
LoopEntryPredicate->getSuccessor(0) == Pair.second);
}
DenseMap<const SCEV *, const SCEV *> RewriteMap;
for (auto &E : reverse(Terms)) {
bool EnterIfTrue = E.second;
SmallVector<Value *, 8> Worklist;
SmallPtrSet<Value *, 8> Visited;
Worklist.push_back(E.first);
while (!Worklist.empty()) {
Value *Cond = Worklist.pop_back_val();
if (!Visited.insert(Cond).second)
continue;
if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
auto Predicate =
EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
const auto *LHS = getSCEV(Cmp->getOperand(0));
const auto *RHS = getSCEV(Cmp->getOperand(1));
CollectCondition(Predicate, LHS, RHS, RewriteMap);
continue;
}
Value *L, *R;
if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
: match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
Worklist.push_back(L);
Worklist.push_back(R);
}
}
}
if (RewriteMap.empty())
return Expr;
if (ExprsToRewrite.size() > 1) {
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = RewriteMap[Expr];
RewriteMap.erase(Expr);
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
}
}
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
return Rewriter.visit(Expr);
}