#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
using namespace llvm;
static cl::opt<bool> AnyAddressSpace(
"amdgpu-any-address-space-out-arguments",
cl::desc("Replace pointer out arguments with "
"struct returns for non-private address space"),
cl::Hidden,
cl::init(false));
static cl::opt<unsigned> MaxNumRetRegs(
"amdgpu-max-return-arg-num-regs",
cl::desc("Approximately limit number of return registers for replacing out arguments"),
cl::Hidden,
cl::init(16));
STATISTIC(NumOutArgumentsReplaced,
"Number out arguments moved to struct return values");
STATISTIC(NumOutArgumentFunctionsReplaced,
"Number of functions with out arguments moved to struct return values");
namespace {
class AMDGPURewriteOutArguments : public FunctionPass {
private:
const DataLayout *DL = nullptr;
MemoryDependenceResults *MDA = nullptr;
Type *getStoredType(Value &Arg) const;
Type *getOutArgumentType(Argument &Arg) const;
public:
static char ID;
AMDGPURewriteOutArguments() : FunctionPass(ID) {}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MemoryDependenceWrapperPass>();
FunctionPass::getAnalysisUsage(AU);
}
bool doInitialization(Module &M) override;
bool runOnFunction(Function &F) override;
};
}
INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
"AMDGPU Rewrite Out Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
"AMDGPU Rewrite Out Arguments", false, false)
char AMDGPURewriteOutArguments::ID = 0;
Type *AMDGPURewriteOutArguments::getStoredType(Value &Arg) const {
const int MaxUses = 10;
int UseCount = 0;
SmallVector<Use *> Worklist;
for (Use &U : Arg.uses())
Worklist.push_back(&U);
Type *StoredType = nullptr;
while (!Worklist.empty()) {
Use *U = Worklist.pop_back_val();
if (auto *BCI = dyn_cast<BitCastInst>(U->getUser())) {
for (Use &U : BCI->uses())
Worklist.push_back(&U);
continue;
}
if (auto *SI = dyn_cast<StoreInst>(U->getUser())) {
if (UseCount++ > MaxUses)
return nullptr;
if (!SI->isSimple() ||
U->getOperandNo() != StoreInst::getPointerOperandIndex())
return nullptr;
if (StoredType && StoredType != SI->getValueOperand()->getType())
return nullptr; StoredType = SI->getValueOperand()->getType();
continue;
}
return nullptr;
}
return StoredType;
}
Type *AMDGPURewriteOutArguments::getOutArgumentType(Argument &Arg) const {
const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
!AnyAddressSpace) ||
Arg.hasByValAttr() || Arg.hasStructRetAttr()) {
return nullptr;
}
Type *StoredType = getStoredType(Arg);
if (!StoredType || DL->getTypeStoreSize(StoredType) > MaxOutArgSizeBytes)
return nullptr;
return StoredType;
}
bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
DL = &M.getDataLayout();
return false;
}
bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
if (F.isVarArg() || F.hasStructRetAttr() ||
AMDGPU::isEntryFunctionCC(F.getCallingConv()))
return false;
MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
unsigned ReturnNumRegs = 0;
SmallDenseMap<int, Type *, 4> OutArgIndexes;
SmallVector<Type *, 4> ReturnTypes;
Type *RetTy = F.getReturnType();
if (!RetTy->isVoidTy()) {
ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;
if (ReturnNumRegs >= MaxNumRetRegs)
return false;
ReturnTypes.push_back(RetTy);
}
SmallVector<std::pair<Argument *, Type *>, 4> OutArgs;
for (Argument &Arg : F.args()) {
if (Type *Ty = getOutArgumentType(Arg)) {
LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
<< " in function " << F.getName() << '\n');
OutArgs.push_back({&Arg, Ty});
}
}
if (OutArgs.empty())
return false;
using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
DenseMap<ReturnInst *, ReplacementVec> Replacements;
SmallVector<ReturnInst *, 4> Returns;
for (BasicBlock &BB : F) {
if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
Returns.push_back(RI);
}
if (Returns.empty())
return false;
bool Changing;
do {
Changing = false;
for (const auto &Pair : OutArgs) {
bool ThisReplaceable = true;
SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;
Argument *OutArg = Pair.first;
Type *ArgTy = Pair.second;
unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
continue;
for (ReturnInst *RI : Returns) {
BasicBlock *BB = RI->getParent();
MemDepResult Q = MDA->getPointerDependencyFrom(
MemoryLocation::getBeforeOrAfter(OutArg), true, BB->end(), BB, RI);
StoreInst *SI = nullptr;
if (Q.isDef())
SI = dyn_cast<StoreInst>(Q.getInst());
if (SI) {
LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
ReplaceableStores.emplace_back(RI, SI);
} else {
ThisReplaceable = false;
break;
}
}
if (!ThisReplaceable)
continue;
for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
Value *ReplVal = Store.second->getValueOperand();
auto &ValVec = Replacements[Store.first];
if (llvm::any_of(ValVec,
[OutArg](const std::pair<Argument *, Value *> &Entry) {
return Entry.first == OutArg;
})) {
LLVM_DEBUG(dbgs()
<< "Saw multiple out arg stores" << *OutArg << '\n');
ThisReplaceable = false;
break;
}
ValVec.emplace_back(OutArg, ReplVal);
Store.second->eraseFromParent();
}
if (ThisReplaceable) {
ReturnTypes.push_back(ArgTy);
OutArgIndexes.insert({OutArg->getArgNo(), ArgTy});
++NumOutArgumentsReplaced;
Changing = true;
}
}
} while (Changing);
if (Replacements.empty())
return false;
LLVMContext &Ctx = F.getParent()->getContext();
StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());
FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
F.getFunctionType()->params(),
F.isVarArg());
LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
F.getName() + ".body");
F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
NewFunc->copyAttributesFrom(&F);
NewFunc->setComdat(F.getComdat());
NewFunc->stealArgumentListFrom(F);
AttributeMask RetAttrs;
RetAttrs.addAttribute(Attribute::SExt);
RetAttrs.addAttribute(Attribute::ZExt);
RetAttrs.addAttribute(Attribute::NoAlias);
NewFunc->removeRetAttrs(RetAttrs);
NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());
for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
ReturnInst *RI = Replacement.first;
IRBuilder<> B(RI);
B.SetCurrentDebugLocation(RI->getDebugLoc());
int RetIdx = 0;
Value *NewRetVal = UndefValue::get(NewRetTy);
Value *RetVal = RI->getReturnValue();
if (RetVal)
NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second)
NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);
if (RetVal)
RI->setOperand(0, NewRetVal);
else {
B.CreateRet(NewRetVal);
RI->eraseFromParent();
}
}
SmallVector<Value *, 16> StubCallArgs;
for (Argument &Arg : F.args()) {
if (OutArgIndexes.count(Arg.getArgNo())) {
StubCallArgs.push_back(UndefValue::get(Arg.getType()));
} else {
StubCallArgs.push_back(&Arg);
}
}
BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
IRBuilder<> B(StubBB);
CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);
int RetIdx = RetTy->isVoidTy() ? 0 : 1;
for (Argument &Arg : F.args()) {
if (!OutArgIndexes.count(Arg.getArgNo()))
continue;
PointerType *ArgType = cast<PointerType>(Arg.getType());
Type *EltTy = OutArgIndexes[Arg.getArgNo()];
const auto Align =
DL->getValueOrABITypeAlignment(Arg.getParamAlign(), EltTy);
Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
B.CreateAlignedStore(Val, PtrVal, Align);
}
if (!RetTy->isVoidTy()) {
B.CreateRet(B.CreateExtractValue(StubCall, 0));
} else {
B.CreateRetVoid();
}
F.addFnAttr(Attribute::AlwaysInline);
++NumOutArgumentFunctionsReplaced;
return true;
}
FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
return new AMDGPURewriteOutArguments();
}