#include "llvm/Transforms/Instrumentation/PoisonChecking.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"
using namespace llvm;
#define DEBUG_TYPE "poison-checking"
/// When set, the operand of every `ret` is additionally asserted to be
/// non-poison, so poison is caught at the function boundary (testing aid).
static cl::opt<bool> LocalCheck(
    "poison-checking-function-local",
    cl::desc("Check that returns are non-poison (for testing)"),
    cl::init(false));
/// Returns true only if \p V is the i1 constant `false`. Any non-constant value
/// is reported as "not known to be false".
static bool isConstantFalse(Value* V) {
  assert(V->getType()->isIntegerTy(1));
  auto *AsInt = dyn_cast<ConstantInt>(V);
  return AsInt != nullptr && AsInt->isZero();
}
/// Emits, via \p B, the logical OR of the i1 values in \p Ops. Operands that
/// are the constant `false` are dropped. If nothing survives, the constant
/// `false` is returned. If exactly one operand survives, it is returned as-is
/// with no `or` emitted.
static Value *buildOrChain(IRBuilder<> &B, ArrayRef<Value*> Ops) {
  Value *Accum = nullptr;
  for (Value *Op : Ops) {
    if (isConstantFalse(Op))
      continue;
    if (Accum)
      Accum = B.CreateOr(Accum, Op);
    else
      Accum = Op;
  }
  if (!Accum)
    return B.getFalse();
  return Accum;
}
static void generateCreationChecksForBinOp(Instruction &I,
SmallVectorImpl<Value*> &Checks) {
assert(isa<BinaryOperator>(I));
IRBuilder<> B(&I);
Value *LHS = I.getOperand(0);
Value *RHS = I.getOperand(1);
switch (I.getOpcode()) {
default:
return;
case Instruction::Add: {
if (I.hasNoSignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::sadd_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
if (I.hasNoUnsignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::uadd_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
break;
}
case Instruction::Sub: {
if (I.hasNoSignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::ssub_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
if (I.hasNoUnsignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::usub_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
break;
}
case Instruction::Mul: {
if (I.hasNoSignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::smul_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
if (I.hasNoUnsignedWrap()) {
auto *OverflowOp =
B.CreateBinaryIntrinsic(Intrinsic::umul_with_overflow, LHS, RHS);
Checks.push_back(B.CreateExtractValue(OverflowOp, 1));
}
break;
}
case Instruction::UDiv: {
if (I.isExact()) {
auto *Check =
B.CreateICmp(ICmpInst::ICMP_NE, B.CreateURem(LHS, RHS),
ConstantInt::get(LHS->getType(), 0));
Checks.push_back(Check);
}
break;
}
case Instruction::SDiv: {
if (I.isExact()) {
auto *Check =
B.CreateICmp(ICmpInst::ICMP_NE, B.CreateSRem(LHS, RHS),
ConstantInt::get(LHS->getType(), 0));
Checks.push_back(Check);
}
break;
}
case Instruction::AShr:
case Instruction::LShr:
case Instruction::Shl: {
Value *ShiftCheck =
B.CreateICmp(ICmpInst::ICMP_UGE, RHS,
ConstantInt::get(RHS->getType(),
LHS->getType()->getScalarSizeInBits()));
Checks.push_back(ShiftCheck);
break;
}
};
}
/// Appends to \p Checks the i1 conditions under which \p I itself creates
/// poison, as opposed to merely propagating it from an operand. Scalar binary
/// operators are delegated to generateCreationChecksForBinOp. For
/// extractelement / insertelement on a fixed-width vector, an index that is
/// >= the element count is flagged. New instructions are inserted before \p I.
static void generateCreationChecks(Instruction &I,
                                   SmallVectorImpl<Value*> &Checks) {
  IRBuilder<> B(&I);
  if (isa<BinaryOperator>(I) && !I.getType()->isVectorTy())
    generateCreationChecksForBinOp(I, Checks);

  // Locate the index operand; every other opcode has nothing more to check.
  unsigned IdxOperandNo;
  switch (I.getOpcode()) {
  case Instruction::ExtractElement:
    IdxOperandNo = 1;
    break;
  case Instruction::InsertElement:
    IdxOperandNo = 2;
    break;
  default:
    return;
  }

  // Scalable vectors have no compile-time element count to compare against.
  auto *FixedTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
  if (!FixedTy)
    return;

  Value *Index = I.getOperand(IdxOperandNo);
  Value *OutOfBounds = B.CreateICmp(
      ICmpInst::ICMP_UGE, Index,
      ConstantInt::get(Index->getType(), FixedTy->getNumElements()));
  Checks.push_back(OutOfBounds);
}
/// Returns the i1 shadow value that tells whether \p V is poison at run time.
///
/// Values with an entry in \p ValToPoison use that entry. A literal `poison`
/// constant is always poison. Every other value (other constants, arguments,
/// globals, anything without an entry) is treated as never poison. This is a
/// non-strict mode in which unhandled IR constructs are simply not checked.
static Value *getPoisonFor(DenseMap<Value *, Value *> &ValToPoison, Value *V) {
  auto Itr = ValToPoison.find(V);
  if (Itr != ValToPoison.end())
    return Itr->second;
  if (isa<PoisonValue>(V))
    return ConstantInt::getTrue(V->getContext());
  // NOTE(review): constant expressions that may fold to poison (e.g. with
  // nsw/nuw flags) are not modeled and are treated as non-poison here.
  return ConstantInt::getFalse(V->getContext());
}
/// Emits `call void @__poison_checker_assert(i1 Cond)` at \p B's insertion
/// point, declaring the hook in the module if it is not present. No call is
/// emitted when \p Cond is the constant `true`, since that assertion cannot
/// fail.
static void CreateAssert(IRBuilder<> &B, Value *Cond) {
  assert(Cond->getType()->isIntegerTy(1));
  if (auto *CI = dyn_cast<ConstantInt>(Cond))
    if (CI->isAllOnesValue())
      return;

  Module *M = B.GetInsertBlock()->getModule();
  // Call through the callee returned by getOrInsertFunction. It always carries
  // the requested `void (i1)` type. Re-fetching by name with getFunction would
  // yield null if the name belongs to a non-function global, or a function of
  // a different type if such a declaration already exists.
  FunctionCallee TrapFunc =
      M->getOrInsertFunction("__poison_checker_assert",
                             Type::getVoidTy(M->getContext()),
                             Type::getInt1Ty(M->getContext()));
  B.CreateCall(TrapFunc, Cond);
}
/// Emits a run-time assertion (via CreateAssert) that the i1 \p Cond is false.
/// The negation is built with \p B, which may constant-fold it.
static void CreateAssertNot(IRBuilder<> &B, Value *Cond) {
  assert(Cond->getType()->isIntegerTy(1));
  Value *Negated = B.CreateNot(Cond);
  CreateAssert(B, Negated);
}
static bool rewrite(Function &F) {
auto * const Int1Ty = Type::getInt1Ty(F.getContext());
DenseMap<Value *, Value *> ValToPoison;
for (BasicBlock &BB : F)
for (auto I = BB.begin(); isa<PHINode>(&*I); I++) {
auto *OldPHI = cast<PHINode>(&*I);
auto *NewPHI = PHINode::Create(Int1Ty, OldPHI->getNumIncomingValues());
for (unsigned i = 0; i < OldPHI->getNumIncomingValues(); i++)
NewPHI->addIncoming(UndefValue::get(Int1Ty),
OldPHI->getIncomingBlock(i));
NewPHI->insertBefore(OldPHI);
ValToPoison[OldPHI] = NewPHI;
}
for (BasicBlock &BB : F)
for (Instruction &I : BB) {
if (isa<PHINode>(I)) continue;
IRBuilder<> B(cast<Instruction>(&I));
SmallPtrSet<const Value *, 4> NonPoisonOps;
getGuaranteedNonPoisonOps(&I, NonPoisonOps);
for (const Value *Op : NonPoisonOps)
CreateAssertNot(B, getPoisonFor(ValToPoison, const_cast<Value *>(Op)));
if (LocalCheck)
if (auto *RI = dyn_cast<ReturnInst>(&I))
if (RI->getNumOperands() != 0) {
Value *Op = RI->getOperand(0);
CreateAssertNot(B, getPoisonFor(ValToPoison, Op));
}
SmallVector<Value*, 4> Checks;
if (propagatesPoison(cast<Operator>(&I)))
for (Value *V : I.operands())
Checks.push_back(getPoisonFor(ValToPoison, V));
if (canCreatePoison(cast<Operator>(&I)))
generateCreationChecks(I, Checks);
ValToPoison[&I] = buildOrChain(B, Checks);
}
for (BasicBlock &BB : F)
for (auto I = BB.begin(); isa<PHINode>(&*I); I++) {
auto *OldPHI = cast<PHINode>(&*I);
if (!ValToPoison.count(OldPHI))
continue; auto *NewPHI = cast<PHINode>(ValToPoison[OldPHI]);
for (unsigned i = 0; i < OldPHI->getNumIncomingValues(); i++) {
auto *OldVal = OldPHI->getIncomingValue(i);
NewPHI->setIncomingValue(i, getPoisonFor(ValToPoison, OldVal));
}
}
return true;
}
/// Module entry point: instruments every function in \p M. It preserves all
/// analyses only if no function reported a change.
PreservedAnalyses PoisonCheckingPass::run(Module &M,
                                          ModuleAnalysisManager &AM) {
  bool AnyFunctionChanged = false;
  for (Function &Fn : M)
    if (rewrite(Fn))
      AnyFunctionChanged = true;
  if (!AnyFunctionChanged)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
/// Function entry point: instruments \p F and invalidates all analyses if the
/// function was modified.
PreservedAnalyses PoisonCheckingPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
  const bool Modified = rewrite(F);
  if (!Modified)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}