#include "llvm/Analysis/SparsePropagation.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/IR/IRBuilder.h"
#include "gtest/gtest.h"
using namespace llvm;
namespace {
/// Which kind of lattice cell a key names for a given IR value:
///   Register - the SSA value (or formal argument) itself,
///   Return   - the value returned by a function,
///   Memory   - the contents of a global variable.
/// The enumerators are stored in the 2 integer bits of TestLatticeKey, so
/// they must stay within [0, 3].
enum class IPOGrouping { Register = 0, Return = 1, Memory = 2 };

/// Lattice key: an IR value tagged with its IPOGrouping.
using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
} // namespace
namespace llvm {
/// Tells the sparse solver how to convert between plain IR values and the
/// test's tagged lattice keys.
template <> struct LatticeKeyInfo<TestLatticeKey> {
  /// A bare Value is always interpreted as a Register key.
  static inline TestLatticeKey getLatticeKeyFromValue(Value *Val) {
    return TestLatticeKey(Val, IPOGrouping::Register);
  }

  /// Returns the IR value underlying a key, discarding its grouping tag.
  static inline Value *getValueFromLatticeKey(TestLatticeKey K) {
    return K.getPointer();
  }
};
} // namespace llvm
namespace {
class TestLatticeVal {
public:
enum TestLatticeStateTy {
UndefinedVal,
ConstantVal,
OverdefinedVal,
UntrackedVal
};
TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
TestLatticeVal(Constant *C, TestLatticeStateTy State)
: LatticeVal(C, State) {}
bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
bool operator==(const TestLatticeVal &RHS) const {
return LatticeVal == RHS.LatticeVal;
}
bool operator!=(const TestLatticeVal &RHS) const {
return LatticeVal != RHS.LatticeVal;
}
private:
PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
};
/// Interprocedural lattice function driving the sparse solver in these tests.
///
/// Keys pair an IR value with an IPOGrouping:
///   - Register keys track SSA values, call results and formal arguments,
///   - Return keys track the value a function returns (see visitReturn),
///   - Memory keys track the value stored to a global variable
///     (see visitStore).
class TestLatticeFunc
    : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
public:
  /// Hands the base class the undefined, overdefined and untracked sentinel
  /// values, in that order; each sentinel carries a null constant.
  TestLatticeFunc()
      : AbstractLatticeFunction(
            TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
            TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
            TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}

  /// Initial value of a key: a Register key wrapping a Constant starts as
  /// that constant; every other key starts undefined.
  TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
    if (Key.getInt() == IPOGrouping::Register)
      if (auto *C = dyn_cast<Constant>(Key.getPointer()))
        return TestLatticeVal(C, TestLatticeVal::ConstantVal);
    return getUndefVal();
  }

  /// Merges two lattice values. Untracked wins over everything, then
  /// overdefined. Undefined is the identity element. Equal values merge to
  /// themselves, and two different values merge to overdefined.
  TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
    if (X == getUntrackedVal() || Y == getUntrackedVal())
      return getUntrackedVal();
    if (X == getOverdefinedVal() || Y == getOverdefinedVal())
      return getOverdefinedVal();
    if (X == getUndefVal() && Y == getUndefVal())
      return getUndefVal();
    if (X == getUndefVal())
      return Y;
    if (Y == getUndefVal())
      return X;
    if (X == Y)
      return X;
    return getOverdefinedVal();
  }

  /// Dispatches on the opcode of an instruction that has become executable.
  /// Calls, returns and stores get dedicated transfer functions. Any other
  /// instruction (including invoke) marks its own Register cell overdefined.
  void ComputeInstructionState(
      Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
      SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
    switch (I.getOpcode()) {
    case Instruction::Call:
      return visitCallBase(cast<CallBase>(I), ChangedValues, SS);
    case Instruction::Ret:
      return visitReturn(cast<ReturnInst>(I), ChangedValues, SS);
    case Instruction::Store:
      return visitStore(cast<StoreInst>(I), ChangedValues, SS);
    default:
      return visitInst(I, ChangedValues, SS);
    }
  }

private:
  /// Transfer function for calls. The callee's entry block becomes
  /// executable, each actual argument is merged into the matching formal,
  /// and the callee's Return cell is merged into the call's Register cell.
  void visitCallBase(CallBase &I,
                     DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                     SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    Function *F = I.getCalledFunction();
    auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
    // Indirect calls have no known callee. A callee without a body has no
    // entry block, and F->front() on an empty function is invalid. Treat
    // both kinds of call as opaque.
    if (!F || F->empty()) {
      ChangedValues[RegI] = getOverdefinedVal();
      return;
    }
    SS.MarkBlockExecutable(&F->front());
    for (Argument &A : F->args()) {
      auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
      auto RegActual =
          TestLatticeKey(I.getArgOperand(A.getArgNo()), IPOGrouping::Register);
      ChangedValues[RegFormal] =
          MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
    }
    auto RetF = TestLatticeKey(F, IPOGrouping::Return);
    ChangedValues[RegI] =
        MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
  }

  /// Transfer function for returns: merges the returned value into the
  /// enclosing function's Return cell. Returns from void functions are
  /// ignored.
  void visitReturn(ReturnInst &I,
                   DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                   SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    Function *F = I.getParent()->getParent();
    if (F->getReturnType()->isVoidTy())
      return;
    auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
    auto RetF = TestLatticeKey(F, IPOGrouping::Return);
    ChangedValues[RetF] =
        MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
  }

  /// Transfer function for stores. A store directly to a GlobalVariable
  /// merges the stored value into that global's Memory cell. Any other
  /// store is ignored.
  void visitStore(StoreInst &I,
                  DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                  SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
    if (!GV)
      return;
    auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
    auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
    ChangedValues[MemPtr] =
        MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
  }

  /// Fallback transfer function: the instruction's result is unknown.
  void visitInst(Instruction &I,
                 DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
                 SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
    auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
    ChangedValues[RegI] = getOverdefinedVal();
  }
};
/// Fixture giving each test a fresh module, an IR builder, and a sparse
/// solver driven by TestLatticeFunc.
class SparsePropagationTest : public testing::Test {
protected:
  // Declaration order matters. Context is constructed before (and destroyed
  // after) the module and builder that reference it, and Lattice is declared
  // before the Solver that holds a pointer to it.
  LLVMContext Context;
  Module M{"", Context};
  IRBuilder<> Builder{Context};
  TestLatticeFunc Lattice;
  SparseSolver<TestLatticeKey, TestLatticeVal> Solver{&Lattice};
};
}
TEST_F(SparsePropagationTest, MarkBlockExecutable) {
  // f and g call each other. Seeding only f's entry block must make g's
  // entry block executable through the call in f.
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *F =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *G =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);

  // Emits `call Callee(); ret void` into BB.
  auto EmitCallThenRet = [&](BasicBlock *BB, Function *Callee) {
    Builder.SetInsertPoint(BB);
    Builder.CreateCall(Callee);
    Builder.CreateRetVoid();
  };
  EmitCallThenRet(FEntry, G);
  EmitCallThenRet(GEntry, F);

  Solver.MarkBlockExecutable(FEntry);
  Solver.Solve();
  EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
}
TEST_F(SparsePropagationTest, GlobalVariableConstant) {
  // f and g both store the same constant (1) to @gv, so the Memory cell of
  // @gv should stay constant.
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *F =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *G =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  auto *GV = new GlobalVariable(M, Builder.getInt64Ty(), false,
                                GlobalValue::InternalLinkage, nullptr, "gv");
  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);

  // Emits `store i64 Stored, @gv; ret void` into BB.
  auto EmitStoreThenRet = [&](BasicBlock *BB, uint64_t Stored) {
    Builder.SetInsertPoint(BB);
    Builder.CreateStore(Builder.getInt64(Stored), GV);
    Builder.CreateRetVoid();
  };
  EmitStoreThenRet(FEntry, 1);
  EmitStoreThenRet(GEntry, 1);

  for (BasicBlock *Entry : {FEntry, GEntry})
    Solver.MarkBlockExecutable(Entry);
  Solver.Solve();

  auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
  EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
}
TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
  // f stores 0 and g stores 1 to @gv. The two different constants must
  // merge to overdefined in the Memory cell of @gv.
  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
  Function *F =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
  Function *G =
      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
  auto *GV = new GlobalVariable(M, Builder.getInt64Ty(), false,
                                GlobalValue::InternalLinkage, nullptr, "gv");
  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);

  // Emits `store i64 Stored, @gv; ret void` into BB.
  auto EmitStoreThenRet = [&](BasicBlock *BB, uint64_t Stored) {
    Builder.SetInsertPoint(BB);
    Builder.CreateStore(Builder.getInt64(Stored), GV);
    Builder.CreateRetVoid();
  };
  EmitStoreThenRet(FEntry, 0);
  EmitStoreThenRet(GEntry, 1);

  for (BasicBlock *Entry : {FEntry, GEntry})
    Solver.MarkBlockExecutable(Entry);
  Solver.Solve();

  auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
  EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
}
TEST_F(SparsePropagationTest, FunctionDefined) {
  // f(cond) branches on a loaded i1 and returns 1 on both paths, so the
  // Return cell of f should be constant.
  FunctionType *FnTy = FunctionType::get(
      Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)}, false);
  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
  BasicBlock *If = BasicBlock::Create(Context, "if", F);
  BasicBlock *Then = BasicBlock::Create(Context, "then", F);
  BasicBlock *Else = BasicBlock::Create(Context, "else", F);
  Argument *CondPtr = F->arg_begin();
  CondPtr->setName("cond");

  Builder.SetInsertPoint(If);
  LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), CondPtr);
  Builder.CreateCondBr(Cond, Then, Else);

  for (BasicBlock *Exit : {Then, Else}) {
    Builder.SetInsertPoint(Exit);
    Builder.CreateRet(Builder.getInt64(1));
  }

  Solver.MarkBlockExecutable(If);
  Solver.Solve();

  auto RetF = TestLatticeKey(F, IPOGrouping::Return);
  EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
}
TEST_F(SparsePropagationTest, FunctionOverDefined) {
  // f(cond) returns 0 on one path and 1 on the other. The two returned
  // constants must merge to overdefined in f's Return cell.
  FunctionType *FnTy = FunctionType::get(
      Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)}, false);
  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
  BasicBlock *If = BasicBlock::Create(Context, "if", F);
  BasicBlock *Then = BasicBlock::Create(Context, "then", F);
  BasicBlock *Else = BasicBlock::Create(Context, "else", F);
  Argument *CondPtr = F->arg_begin();
  CondPtr->setName("cond");

  Builder.SetInsertPoint(If);
  LoadInst *Cond = Builder.CreateLoad(Type::getInt1Ty(Context), CondPtr);
  Builder.CreateCondBr(Cond, Then, Else);

  // Emits `ret i64 RetVal` into BB.
  auto EmitRetOf = [&](BasicBlock *BB, uint64_t RetVal) {
    Builder.SetInsertPoint(BB);
    Builder.CreateRet(Builder.getInt64(RetVal));
  };
  EmitRetOf(Then, 0);
  EmitRetOf(Else, 1);

  Solver.MarkBlockExecutable(If);
  Solver.Solve();

  auto RetF = TestLatticeKey(F, IPOGrouping::Return);
  EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
}
TEST_F(SparsePropagationTest, ComputeInstructionState) {
  // f calls g(0, 1) and then g(1, 1). Formal `a` receives two different
  // constants and must become overdefined. Formal `b` always receives 1 and
  // stays constant.
  Type *I64Ty = Builder.getInt64Ty();
  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
                                 GlobalValue::InternalLinkage, "f", &M);
  Function *G = Function::Create(
      FunctionType::get(Builder.getVoidTy(), {I64Ty, I64Ty}, false),
      GlobalValue::InternalLinkage, "g", &M);
  Argument *A = G->arg_begin();
  Argument *B = std::next(A);
  A->setName("a");
  B->setName("b");
  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);

  Builder.SetInsertPoint(FEntry);
  for (uint64_t FirstArg : {0, 1})
    Builder.CreateCall(G, {Builder.getInt64(FirstArg), Builder.getInt64(1)});
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(GEntry);
  Builder.CreateRetVoid();

  Solver.MarkBlockExecutable(FEntry);
  Solver.Solve();

  auto RegA = TestLatticeKey(A, IPOGrouping::Register);
  auto RegB = TestLatticeKey(B, IPOGrouping::Register);
  EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
  EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
}
// Checks that block executability propagates through exception-handling
// terminators: an invoke, a catchswitch and a catchret.
TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "p", &M);
Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "g", &M);
Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::InternalLinkage, "f", &M);
// f needs a personality function to contain EH pads. p (bitcast to i8*) is
// used for that. g is only invoked and is given no body.
Constant *C =
ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy());
F->setPersonalityFn(C);
BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
// entry: invoke g, normal destination `exit`, unwind destination `catch.pad`.
Builder.SetInsertPoint(Entry);
Builder.CreateInvoke(G, Exit, Pad);
// catch.pad: a top-level catchswitch (parent pad `none`, nullptr unwind
// destination) with a single handler, `catch.body`.
Builder.SetInsertPoint(Pad);
CatchSwitchInst *CatchSwitch =
Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
CatchSwitch->addHandler(Body);
// catch.body: enter the catchpad, then catchret back to `exit`.
Builder.SetInsertPoint(Body);
CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
Builder.CreateCatchRet(CatchPad, Exit);
Builder.SetInsertPoint(Exit);
Builder.CreateRetVoid();
// Seed only the entry block. Pad is reachable only through the invoke's
// unwind edge, and Body only through the catchswitch handler edge. Exit is
// both the invoke's normal destination and the catchret target.
Solver.MarkBlockExecutable(Entry);
Solver.Solve();
EXPECT_TRUE(Solver.isBlockExecutable(Pad));
EXPECT_TRUE(Solver.isBlockExecutable(Body));
EXPECT_TRUE(Solver.isBlockExecutable(Exit));
}