#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
#define DEBUG_TYPE "call-promotion-utils"
/// Retarget phi nodes in the normal destination of \p Invoke.
///
/// Every phi in that block which has an incoming entry for \p OrigBlock gets
/// that entry's block rewritten to \p MergeBlock; the incoming value is left
/// untouched. Phis without an entry for \p OrigBlock are skipped.
static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
                                      BasicBlock *MergeBlock) {
  BasicBlock *NormalDest = Invoke->getNormalDest();
  for (PHINode &PN : NormalDest->phis()) {
    const int IncomingIdx = PN.getBasicBlockIndex(OrigBlock);
    // -1 means this phi has no incoming entry for OrigBlock.
    if (IncomingIdx != -1)
      PN.setIncomingBlock(IncomingIdx, MergeBlock);
  }
}
/// Duplicate the unwind-edge phi entries of \p Invoke for two predecessors.
///
/// For every phi in the unwind destination that has an incoming entry for
/// \p OrigBlock, that entry is re-pointed at \p ThenBlock and a second entry
/// carrying the same value is appended for \p ElseBlock. Phis without an entry
/// for \p OrigBlock are skipped.
static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
                                      BasicBlock *ThenBlock,
                                      BasicBlock *ElseBlock) {
  BasicBlock *UnwindDest = Invoke->getUnwindDest();
  for (PHINode &PN : UnwindDest->phis()) {
    const int IncomingIdx = PN.getBasicBlockIndex(OrigBlock);
    // -1 means this phi has no incoming entry for OrigBlock.
    if (IncomingIdx == -1)
      continue;

    // Both new predecessors forward the value the single old one supplied.
    Value *Incoming = PN.getIncomingValue(IncomingIdx);
    PN.setIncomingBlock(IncomingIdx, ThenBlock);
    PN.addIncoming(Incoming, ElseBlock);
  }
}
/// Merge the results of \p OrigInst and \p NewInst with a phi in
/// \p MergeBlock.
///
/// Does nothing when \p OrigInst has void type or has no uses. Otherwise a phi
/// of \p OrigInst's type is created at the front of \p MergeBlock (moving
/// \p Builder's insertion point there), all existing users of \p OrigInst are
/// redirected to the phi, and the phi receives \p OrigInst and \p NewInst as
/// incoming values from their respective parent blocks.
static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst,
                             BasicBlock *MergeBlock, IRBuilder<> &Builder) {
  // No value is produced, so there is nothing to merge.
  if (OrigInst->getType()->isVoidTy())
    return;
  // The value is never read, so a phi would be dead.
  if (OrigInst->use_empty())
    return;

  Builder.SetInsertPoint(&MergeBlock->front());
  PHINode *Merged = Builder.CreatePHI(OrigInst->getType(), 0);

  // Snapshot the users first: rewriting them mutates OrigInst's use list.
  // This also has to happen before the phi gets its incoming values, so that
  // the phi's own use of OrigInst is not among the uses being rewritten.
  SmallVector<User *, 16> PendingUsers(OrigInst->users());
  for (User *U : PendingUsers)
    U->replaceUsesOfWith(OrigInst, Merged);

  Merged->addIncoming(OrigInst, OrigInst->getParent());
  Merged->addIncoming(NewInst, NewInst->getParent());
}
/// Cast the value produced by \p CB to \p RetTy and route existing users
/// through the cast.
///
/// For an invoke, the cast is placed at the front of the block returned by
/// SplitEdge for the edge from the invoke's block to its normal destination;
/// for any other call it is placed immediately after \p CB. If \p RetBitCast
/// is non-null, the created cast is stored through it.
static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
  // Capture the users before the cast exists, so the cast's own use of CB is
  // not rewritten below.
  SmallVector<User *, 16> PendingUsers(CB.users());

  // An invoke terminates its block, so the cast cannot simply follow it.
  auto *Invoke = dyn_cast<InvokeInst>(&CB);
  Instruction *InsertBefore =
      Invoke
          ? &SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->front()
          : &*std::next(CB.getIterator());

  auto *Cast = CastInst::CreateBitOrPointerCast(&CB, RetTy, "", InsertBefore);
  if (RetBitCast)
    *RetBitCast = Cast;

  for (User *U : PendingUsers)
    U->replaceUsesOfWith(&CB, Cast);
}
// Guard the call site CB with a runtime comparison of its called operand
// against Callee, and clone the call into the branch taken when they are equal.
//
// BranchWeights is forwarded to the block-splitting helper that creates the
// conditional branch. The returned instruction is the clone placed in the
// "if.true.direct_targ" block; it is an unmodified copy of CB.
//
// Two shapes are produced:
//  * musttail call: if-then. The original call stays where it is; the "then"
//    block receives clones of the call, of the optional bitcast following it,
//    and of the return instruction.
//  * otherwise: if-then-else. The original call moves into
//    "if.false.orig_indirect", the clone goes into "if.true.direct_targ", and
//    the results are merged by a phi in "if.end.icp" when the value is used.
CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
                                MDNode *BranchWeights) {
  IRBuilder<> Builder(&CB);
  CallBase *IndirectCall = &CB;
  BasicBlock *HeadBlock = IndirectCall->getParent();

  // The two sides of the comparison must share a type; cast Callee if needed.
  Value *CalledOp = CB.getCalledOperand();
  Type *CalledOpTy = CalledOp->getType();
  if (CalledOpTy != Callee->getType())
    Callee = Builder.CreateBitCast(Callee, CalledOpTy);
  auto *IsDirectTarget = Builder.CreateICmpEQ(CalledOp, Callee);

  if (IndirectCall->isMustTailCall()) {
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(IsDirectTarget, &CB, false, BranchWeights);
    BasicBlock *ThenBlock = ThenTerm->getParent();
    ThenBlock->setName("if.true.direct_targ");

    CallBase *DirectCall = cast<CallBase>(IndirectCall->clone());
    DirectCall->insertBefore(ThenTerm);

    // Clone the optional bitcast that sits between the call and the return.
    Value *DirectRetVal = DirectCall;
    auto *Next = IndirectCall->getNextNode();
    if (auto *BitCast = dyn_cast_or_null<BitCastInst>(Next)) {
      assert(BitCast->getOperand(0) == IndirectCall &&
             "bitcast following musttail call must use the call");
      auto *BitCastCopy = BitCast->clone();
      BitCastCopy->replaceUsesOfWith(IndirectCall, DirectCall);
      BitCastCopy->insertBefore(ThenTerm);
      DirectRetVal = BitCastCopy;
      Next = BitCast->getNextNode();
    }

    // Clone the return and make it return the cloned call's (cast) value.
    ReturnInst *Ret = dyn_cast_or_null<ReturnInst>(Next);
    assert(Ret && "musttail call must precede a ret with an optional bitcast");
    auto *RetCopy = Ret->clone();
    if (Value *OldRetVal = Ret->getReturnValue())
      RetCopy->replaceUsesOfWith(OldRetVal, DirectRetVal);
    RetCopy->insertBefore(ThenTerm);

    // The cloned return now ends the block; drop the terminator the split
    // helper inserted.
    ThenTerm->eraseFromParent();
    return *DirectCall;
  }

  Instruction *ThenTerm = nullptr;
  Instruction *ElseTerm = nullptr;
  SplitBlockAndInsertIfThenElse(IsDirectTarget, &CB, &ThenTerm, &ElseTerm,
                                BranchWeights);
  BasicBlock *ThenBlock = ThenTerm->getParent();
  BasicBlock *ElseBlock = ElseTerm->getParent();
  // After the split, the call lives in the tail block, which becomes the
  // merge point once the call has been moved out of it.
  BasicBlock *MergeBlock = IndirectCall->getParent();

  ThenBlock->setName("if.true.direct_targ");
  ElseBlock->setName("if.false.orig_indirect");
  MergeBlock->setName("if.end.icp");

  CallBase *DirectCall = cast<CallBase>(IndirectCall->clone());
  IndirectCall->moveBefore(ElseTerm);
  DirectCall->insertBefore(ThenTerm);

  if (auto *IndirectInvoke = dyn_cast<InvokeInst>(IndirectCall)) {
    auto *DirectInvoke = cast<InvokeInst>(DirectCall);

    // Both invokes terminate their blocks themselves, so the terminators the
    // split helper created are removed.
    ThenTerm->eraseFromParent();
    ElseTerm->eraseFromParent();

    // The merge block falls through to the original normal destination.
    Builder.SetInsertPoint(MergeBlock);
    Builder.CreateBr(IndirectInvoke->getNormalDest());

    fixupPHINodeForNormalDest(IndirectInvoke, HeadBlock, MergeBlock);
    fixupPHINodeForUnwindDest(IndirectInvoke, MergeBlock, ThenBlock,
                              ElseBlock);

    // Both invokes continue at the merge block on normal return.
    IndirectInvoke->setNormalDest(MergeBlock);
    DirectInvoke->setNormalDest(MergeBlock);
  }

  createRetPHINode(IndirectCall, DirectCall, MergeBlock, Builder);
  return *DirectCall;
}
bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee,
const char **FailureReason) {
assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
auto &DL = Callee->getParent()->getDataLayout();
Type *CallRetTy = CB.getType();
Type *FuncRetTy = Callee->getReturnType();
if (CallRetTy != FuncRetTy)
if (!CastInst::isBitOrNoopPointerCastable(FuncRetTy, CallRetTy, DL)) {
if (FailureReason)
*FailureReason = "Return type mismatch";
return false;
}
unsigned NumParams = Callee->getFunctionType()->getNumParams();
unsigned NumArgs = CB.arg_size();
if (NumArgs != NumParams && !Callee->isVarArg()) {
if (FailureReason)
*FailureReason = "The number of arguments mismatch";
return false;
}
unsigned I = 0;
for (; I < NumParams; ++I) {
Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I);
Type *ActualTy = CB.getArgOperand(I)->getType();
if (FormalTy == ActualTy)
continue;
if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) {
if (FailureReason)
*FailureReason = "Argument type mismatch";
return false;
}
if (Callee->hasParamAttribute(I, Attribute::ByVal) !=
CB.getAttributes().hasParamAttr(I, Attribute::ByVal)) {
if (FailureReason)
*FailureReason = "byval mismatch";
return false;
}
if (Callee->hasParamAttribute(I, Attribute::InAlloca) !=
CB.getAttributes().hasParamAttr(I, Attribute::InAlloca)) {
if (FailureReason)
*FailureReason = "inalloca mismatch";
return false;
}
}
for (; I < NumArgs; I++) {
assert(Callee->isVarArg());
if (CB.paramHasAttr(I, Attribute::StructRet)) {
if (FailureReason)
*FailureReason = "SRet arg to vararg function";
return false;
}
}
return true;
}
// Turn the indirect call site CB into a direct call to Callee and return CB.
//
// The called operand is replaced and the !prof and !callees metadata are
// cleared. If the call site's function type already equals Callee's, nothing
// else changes. Otherwise the call site's function type is mutated to
// Callee's, each fixed argument whose type differs from the formal type is
// cast before the call, and a differing non-void return value is cast back to
// the original call-site type via createRetBitCast (reported through
// RetBitCast when non-null). Attributes incompatible with the new types are
// dropped.
CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
                            CastInst **RetBitCast) {
  assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");

  CB.setCalledOperand(Callee);
  CB.setMetadata(LLVMContext::MD_prof, nullptr);
  CB.setMetadata(LLVMContext::MD_callees, nullptr);

  // Identical signatures: the operand swap above is all that is needed.
  auto CalleeFnTy = Callee->getFunctionType();
  if (CB.getFunctionType() == CalleeFnTy)
    return CB;

  // Remember the call site's result type before mutating the function type.
  Type *CallSiteRetTy = CB.getType();
  Type *CalleeRetTy = Callee->getReturnType();
  CB.mutateFunctionType(CalleeFnTy);

  auto NumFormals = CalleeFnTy->getNumParams();
  LLVMContext &Ctx = Callee->getContext();
  const AttributeList &OldAttrs = CB.getAttributes();
  SmallVector<AttributeSet, 4> ParamAttrs;
  bool AttrsDirty = false;

  for (unsigned Idx = 0; Idx != NumFormals; ++Idx) {
    auto *Actual = CB.getArgOperand(Idx);
    Type *FormalTy = CalleeFnTy->getParamType(Idx);

    // Matching type: keep the argument and its attributes as they are.
    if (FormalTy == Actual->getType()) {
      ParamAttrs.push_back(OldAttrs.getParamAttrs(Idx));
      continue;
    }

    auto *Cast = CastInst::CreateBitOrPointerCast(Actual, FormalTy, "", &CB);
    CB.setArgOperand(Idx, Cast);

    // Drop attributes that do not apply to the formal type, and take the
    // byval/inalloca element types from the callee's declaration.
    AttrBuilder Adjusted(Ctx, OldAttrs.getParamAttrs(Idx));
    Adjusted.remove(AttributeFuncs::typeIncompatible(FormalTy));
    if (Adjusted.getByValType())
      Adjusted.addByValAttr(Callee->getParamByValType(Idx));
    if (Adjusted.getInAllocaType())
      Adjusted.addInAllocaAttr(Callee->getParamInAllocaType(Idx));

    ParamAttrs.push_back(AttributeSet::get(Ctx, Adjusted));
    AttrsDirty = true;
  }

  // Cast the result back to what the call site's users expect, and drop
  // return attributes that do not apply to the callee's return type.
  AttrBuilder RetAttrs(Ctx, OldAttrs.getRetAttrs());
  if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) {
    createRetBitCast(CB, CallSiteRetTy, RetBitCast);
    RetAttrs.remove(AttributeFuncs::typeIncompatible(CalleeRetTy));
    AttrsDirty = true;
  }

  if (AttrsDirty)
    CB.setAttributes(AttributeList::get(Ctx, OldAttrs.getFnAttrs(),
                                        AttributeSet::get(Ctx, RetAttrs),
                                        ParamAttrs));

  return CB;
}
// Version the call site CB on "called operand == Callee" (versionCallSite),
// then promote the clone that versioning returns to a direct call to Callee
// (promoteCall). Returns the promoted call.
CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
                                          MDNode *BranchWeights) {
  return promoteCall(versionCallSite(CB, Callee, BranchWeights), Callee);
}
// Try to resolve the indirect call CB to a concrete function and promote it.
//
// The called operand must be a load whose address is, modulo constant offsets,
// itself a load from the start (offset zero) of an alloca. The value available
// at that inner load must be, modulo constant offsets, a constant global
// variable with a definitive initializer. The combined constant offset is then
// looked up in that initializer; if the entry found there is a function that
// is legal to promote to, CB is promoted and true is returned. Any failed step
// returns false and leaves CB untouched.
//
// NOTE(review): the offset accumulators are sized with the pointer type's bit
// width; confirm this matches what stripAndAccumulateConstantOffsets expects
// on targets where pointer and index widths differ.
bool llvm::tryPromoteCall(CallBase &CB) {
  assert(!CB.getCalledFunction());
  Module *M = CB.getCaller()->getParent();
  const DataLayout &DL = M->getDataLayout();

  // The callee has to be loaded from memory (the table entry).
  auto *VTableEntryLoad = dyn_cast<LoadInst>(CB.getCalledOperand());
  if (!VTableEntryLoad)
    return false;

  // Peel constant offsets off the entry address to reach the table base,
  // which itself has to be loaded (from the object).
  Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
  APInt VTableOffset(DL.getTypeSizeInBits(VTableEntryPtr->getType()), 0);
  Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
      DL, VTableOffset, true);
  auto *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
  if (!VTablePtrLoad)
    return false;

  // That load must read from offset zero of an alloca.
  Value *Object = VTablePtrLoad->getPointerOperand();
  APInt ObjectOffset(DL.getTypeSizeInBits(Object->getType()), 0);
  Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
      DL, ObjectOffset, true);
  if (!isa<AllocaInst>(ObjectBase) || ObjectOffset != 0)
    return false;

  // Find the value already available for the table-pointer load.
  BasicBlock::iterator BBI(VTablePtrLoad);
  Value *VTablePtr = FindAvailableLoadedValue(
      VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
  if (!VTablePtr)
    return false;

  // It must point into a constant global whose initializer is definitive.
  APInt VTableOffsetGVBase(DL.getTypeSizeInBits(VTablePtr->getType()), 0);
  Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
      DL, VTableOffsetGVBase, true);
  auto *GV = dyn_cast<GlobalVariable>(VTableGVBase);
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return false;

  // Combine both offsets; the lookup below takes a 64-bit offset.
  Constant *VTableGVInitializer = GV->getInitializer();
  APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
  if (VTableGVOffset.getActiveBits() > 64)
    return false;

  Constant *Ptr = getPointerAtOffset(VTableGVInitializer,
                                     VTableGVOffset.getZExtValue(), *M);
  if (!Ptr)
    return false;

  auto *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts());
  if (!DirectCallee)
    return false;

  if (!isLegalToPromote(CB, DirectCallee))
    return false;

  promoteCall(CB, DirectCallee);
  return true;
}
#undef DEBUG_TYPE