#include "llvm/Analysis/ScalarEvolutionNormalization.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;
namespace {
/// Direction of the rewrite performed by NormalizeDenormalizeRewriter.
/// Kept in an anonymous namespace so the type and its unscoped enumerators
/// get internal linkage and cannot collide with another translation unit's
/// global `TransformKind`, `Normalize` or `Denormalize`.
enum TransformKind {
  /// Rewrite selected add recurrences one iteration backwards
  /// (the inverse of Denormalize).
  Normalize,
  /// Rewrite selected add recurrences one iteration forwards
  /// (each operand absorbs the operand after it).
  Denormalize
};
} // end anonymous namespace
namespace {
/// SCEV rewriter that normalizes (Kind == Normalize) or denormalizes
/// (Kind == Denormalize) every add recurrence accepted by Pred. Only
/// visitAddRecExpr is overridden; every other expression kind is handled by
/// the SCEVRewriteVisitor base.
struct NormalizeDenormalizeRewriter
    : SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
  /// Which direction of the transform to apply.
  const TransformKind Kind;
  /// Decides which add recurrences get transformed.
  /// NOTE(review): if NormalizePredTy is a non-owning function_ref, the
  /// callable must outlive this rewriter -- confirm against the header.
  const NormalizePredTy Pred;

  NormalizeDenormalizeRewriter(TransformKind K, NormalizePredTy P,
                               ScalarEvolution &SE)
      : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(K),
        Pred(P) {}

  /// Rewrites AR's operands recursively and, when Pred(AR) holds, shifts AR
  /// itself according to Kind.
  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
};
} // end anonymous namespace
/// Visit every operand of AR first. Then, only if Pred accepts AR, shift the
/// recurrence with respect to its own loop:
///  - Denormalize: operand i becomes S_i + S_{i+1}, which advances the
///    recurrence by one iteration.
///  - Normalize: the inverse; operand i becomes S_i - S'_{i+1}, where S'_{i+1}
///    is the already-normalized next operand.
/// The result is always rebuilt with SCEV::FlagAnyWrap, so no wrap flags of
/// the original recurrence are carried over.
const SCEV *
NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
  SmallVector<const SCEV *, 8> Ops;
  for (const SCEV *Op : AR->operands())
    Ops.push_back(visit(Op));

  if (Pred(AR)) {
    if (Kind == Denormalize) {
      // Walk from the start value towards the step. Each sum reads the
      // operand to its right before that operand has been updated.
      for (size_t I = 0; I + 1 < Ops.size(); ++I)
        Ops[I] = SE.getAddExpr(Ops[I], Ops[I + 1]);
    } else {
      assert(Kind == Normalize && "Only two possibilities!");
      // Walk from the innermost step back to the start value. Each difference
      // therefore reads the operand to its right *after* it was normalized.
      // I counts down from size-1 to 1, so the updated slot is I - 1.
      for (size_t I = Ops.size(); I-- > 1;)
        Ops[I - 1] = SE.getMinusSCEV(Ops[I - 1], Ops[I]);
    }
  }

  return SE.getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagAnyWrap);
}
/// Normalize S for a post-increment use: every add recurrence whose loop is in
/// Loops is rewritten one iteration backwards by NormalizeDenormalizeRewriter.
///
/// \returns S itself when Loops is empty.
const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
                                         const PostIncLoopSet &Loops,
                                         ScalarEvolution &SE) {
  // With no loops the predicate below can never match. Skip the rewrite rather
  // than walking S and re-creating every add recurrence it contains.
  if (Loops.empty())
    return S;
  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
}
/// Normalize every add recurrence in S that Pred accepts. Recurrences Pred
/// rejects are not shifted, although their operands are still visited.
const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                           ScalarEvolution &SE) {
  NormalizeDenormalizeRewriter Rewriter(Normalize, Pred, SE);
  return Rewriter.visit(S);
}
/// Denormalize S: every add recurrence whose loop is in Loops is rewritten one
/// iteration forwards by NormalizeDenormalizeRewriter. This is the inverse of
/// normalizeForPostIncUse.
///
/// \returns S itself when Loops is empty.
const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
                                           const PostIncLoopSet &Loops,
                                           ScalarEvolution &SE) {
  // With no loops the predicate below can never match. Skip the rewrite rather
  // than walking S and re-creating every add recurrence it contains.
  if (Loops.empty())
    return S;
  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
}