#include "llvm/Transforms/Coroutines/CoroElide.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
using namespace llvm;
#define DEBUG_TYPE "coro-elide"
// Incremented once per coro.id whose coroutine frame was moved from a heap
// allocation into an alloca in the caller (see processCoroId).
STATISTIC(NumOfCoroElided, "The # of coroutine get elided.");
#ifndef NDEBUG
// Debug builds only: when non-empty, every successful elision appends a line
// "Elide <coroutine> in <caller>" to this file (opened in append mode).
static cl::opt<std::string> CoroElideInfoOutputFilename(
"coro-elide-info-output-file", cl::value_desc("filename"),
cl::desc("File to record the coroutines got elided"), cl::Hidden);
#endif
namespace {
/// Per-function working state for eliding the heap allocation of post-split
/// coroutines that are created and destroyed inside a caller.
struct Lowerer : coro::LowererBase {
  /// Post-split `coro.id` intrinsics in the function being processed whose
  /// coroutine is a different function than the one containing them.
  SmallVector<CoroIdInst *, 4> CoroIds;
  /// `coro.begin` users of the `coro.id` currently being processed.
  SmallVector<CoroBeginInst *, 1> CoroBegins;
  /// `coro.alloc` users of the `coro.id` currently being processed.
  SmallVector<CoroAllocInst *, 1> CoroAllocs;
  /// `coro.subfn.addr` calls (on any collected `coro.begin`) that request the
  /// resume function.
  SmallVector<CoroSubFnInst *, 4> ResumeAddr;
  /// For each `coro.begin`, the `coro.subfn.addr` calls on it that request the
  /// destroy function.
  DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
  /// Switches that are the only user of a `coro.suspend` and have exactly two
  /// cases; hasEscapePath follows only their case successors (1 and 2), not
  /// the default destination.
  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;

  // explicit: a Module should not implicitly convert into a Lowerer.
  explicit Lowerer(Module &M) : LowererBase(M) {}

  void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
                            AAResults &AA);
  bool shouldElide(Function *F, DominatorTree &DT) const;
  void collectPostSplitCoroIds(Function *F);
  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
  bool hasEscapePath(const CoroBeginInst *CB,
                     const SmallPtrSetImpl<BasicBlock *> &TIs) const;
};
} // end anonymous namespace
/// Substitute \p Value for every call in \p Users, then simplify whatever used
/// those calls. If \p Value's (pointer) type differs from the calls' result
/// type, a constant bitcast is inserted first. Does nothing for empty \p Users.
static void replaceWithConstant(Constant *Value,
                                SmallVectorImpl<CoroSubFnInst *> &Users) {
  if (!Users.empty()) {
    Type *const ExpectedTy = Users.front()->getType();
    Constant *Replacement = Value;
    if (Replacement->getType() != ExpectedTy) {
      assert(Replacement->getType()->isPointerTy() &&
             ExpectedTy->isPointerTy());
      Replacement = ConstantExpr::getBitCast(Replacement, ExpectedTy);
    }
    for (CoroSubFnInst *Call : Users)
      replaceAndRecursivelySimplify(Call, Replacement);
  }
}
/// Return true when at least one operand of \p CI is not proven NoAlias with
/// \p Frame by \p AA.
static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) {
  return llvm::any_of(CI->operand_values(), [&](Value *Operand) {
    return !AA.isNoAlias(Operand, Frame);
  });
}
/// Clear the `tail` marker on calls in \p Frame's function whose operands may
/// alias \p Frame. Per the LangRef, `tail` promises that the callee does not
/// access the caller's allocas, which is no longer true once the frame is an
/// alloca. `musttail` calls are left untouched.
static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
  Function &Caller = *Frame->getFunction();
  for (Instruction &Inst : instructions(Caller)) {
    auto *Call = dyn_cast<CallInst>(&Inst);
    if (!Call || !Call->isTailCall())
      continue;
    if (operandReferences(Call, Frame, AA) && !Call->isMustTailCall())
      Call->setTailCall(false);
  }
}
/// Derive the coroutine frame's size and alignment from the attributes on the
/// first parameter of \p Resume. Returns None when that parameter carries no
/// dereferenceable byte count; a missing alignment defaults to 1.
static Optional<std::pair<uint64_t, Align>> getFrameLayout(Function *Resume) {
  const uint64_t FrameBytes = Resume->getParamDereferenceableBytes(0);
  if (FrameBytes == 0)
    return None;
  const Align FrameAlign = Resume->getParamAlign(0).valueOrOne();
  return std::make_pair(FrameBytes, FrameAlign);
}
/// Return the first instruction of \p F's entry block that is not an alloca.
/// A well-formed entry block ends in a terminator, which is never an alloca,
/// so falling off the loop is treated as unreachable.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
  BasicBlock &Entry = F->getEntryBlock();
  for (Instruction &Inst : Entry) {
    if (isa<AllocaInst>(Inst))
      continue;
    return &Inst;
  }
  llvm_unreachable("no terminator in the entry block");
}
#ifndef NDEBUG
/// Open CoroElideInfoOutputFilename for appending. If opening fails, report
/// the reason on stderr and return a stream over file descriptor 2 (stderr)
/// that is not closed on destruction, so the caller can still log.
static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() {
  assert(!CoroElideInfoOutputFilename.empty() &&
         "coro-elide-info-output-file shouldn't be empty");
  std::error_code EC;
  auto Result = std::make_unique<raw_fd_ostream>(CoroElideInfoOutputFilename,
                                                 EC, sys::fs::OF_Append);
  if (!EC)
    return Result;
  // Close the quote around the file name and include why the open failed.
  llvm::errs() << "Error opening coro-elide-info-output-file '"
               << CoroElideInfoOutputFilename << "' for appending: "
               << EC.message() << "\n";
  return std::make_unique<raw_fd_ostream>(2, false); // stderr, not owned
}
#endif
/// Put the coroutine frame on \p F's stack instead of the heap.
/// - Every collected coro.alloc is folded to `false`.
/// - An alloca of \p FrameSize bytes, aligned to \p FrameAlign, is created
///   after the existing entry-block allocas of \p F.
/// - Every collected coro.begin is replaced by a pointer to that alloca.
/// - Finally, `tail` is dropped from calls that may reference the new alloca.
void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
                                   Align FrameAlign, AAResults &AA) {
  LLVMContext &C = F->getContext();
  // All CoroIds were collected from F, so use F directly instead of
  // CoroIds.front()->getFunction(); this also avoids requiring a non-empty
  // CoroIds.
  Instruction *InsertPt = getFirstNonAllocaInTheEntryBlock(F);

  // With the frame on the stack, coro.alloc must report "do not allocate".
  auto *False = ConstantInt::getFalse(C);
  for (auto *CA : CoroAllocs) {
    CA->replaceAllUsesWith(False);
    CA->eraseFromParent();
  }

  const DataLayout &DL = F->getParent()->getDataLayout();
  auto *FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
  auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
  Frame->setAlignment(FrameAlign);

  // The coro.begin results are replaced with an i8* in address space 0, but
  // the alloca lives in the DataLayout's alloca address space. A plain bitcast
  // across address spaces is invalid IR, so emit an addrspacecast when the
  // spaces differ. When they match, a bitcast is emitted as before.
  auto *FrameVoidPtr = CastInst::CreatePointerBitCastOrAddrSpaceCast(
      Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);

  for (auto *CB : CoroBegins) {
    CB->replaceAllUsesWith(FrameVoidPtr);
    CB->eraseFromParent();
  }

  // The frame is now a caller alloca; calls that may see it must not be tail.
  removeTailCallAttribute(Frame, AA);
}
bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
const SmallPtrSetImpl<BasicBlock *> &TIs) const {
const auto &It = DestroyAddr.find(CB);
assert(It != DestroyAddr.end());
unsigned Limit = 32 * (1 + It->second.size());
SmallVector<const BasicBlock *, 32> Worklist;
Worklist.push_back(CB->getParent());
SmallPtrSet<const BasicBlock *, 32> Visited;
for (auto *DA : It->second)
Visited.insert(DA->getParent());
do {
const auto *BB = Worklist.pop_back_val();
if (!Visited.insert(BB).second)
continue;
if (TIs.count(BB))
return true;
if (!--Limit)
return true;
auto TI = BB->getTerminator();
if (isa<SwitchInst>(TI) &&
CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
} else
Worklist.append(succ_begin(BB), succ_end(BB));
} while (!Worklist.empty());
return false;
}
/// Decide whether the heap allocation for the current coro.id may be elided.
/// This requires a coro.alloc to suppress, and that every coro.begin has its
/// destroy request on every normal (non-exceptional, reachable) exit path of
/// \p F.
bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
  // Without a coro.alloc there is no allocation we could suppress.
  if (CoroAllocs.empty())
    return false;

  // Blocks that leave F normally: no successors, not an exceptional
  // terminator, and not `unreachable`.
  SmallPtrSet<BasicBlock *, 8> Terminators;
  for (BasicBlock &B : *F) {
    auto *TI = B.getTerminator();
    if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
        !isa<UnreachableInst>(TI))
      Terminators.insert(&B);
  }

  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
  for (const auto &It : DestroyAddr) {
    // Fast path: if each normal exit is dominated by *some* destroy of this
    // coro.begin, then every path from the coro.begin to an exit crosses a
    // destroy, so hasEscapePath would answer false. Requiring one destroy to
    // dominate *all* exits (as before) missed e.g. a destroy in front of each
    // of several returns. It also fell back to hasEscapePath, whose visit
    // budget can conservatively block elision.
    bool EveryExitCovered = llvm::all_of(Terminators, [&](BasicBlock *ExitBB) {
      return llvm::any_of(It.second, [&](Instruction *DA) {
        return DT.dominates(DA, ExitBB->getTerminator());
      });
    });
    if (EveryExitCovered || !hasEscapePath(It.first, Terminators))
      ReferencedCoroBegins.insert(It.first);
  }

  // Every coro.begin must be covered; one without destroys is never covered.
  return ReferencedCoroBegins.size() == CoroBegins.size();
}
/// Rebuild CoroIds and CoroSuspendSwitches from the instructions of \p F.
/// - CoroIds receives post-split coro.id intrinsics whose coroutine is a
///   function other than \p F.
/// - CoroSuspendSwitches receives two-case switches that are the only user of
///   a coro.suspend.
void Lowerer::collectPostSplitCoroIds(Function *F) {
  CoroIds.clear();
  CoroSuspendSwitches.clear();
  for (Instruction &Inst : instructions(F)) {
    if (auto *Id = dyn_cast<CoroIdInst>(&Inst)) {
      if (Id->getInfo().isPostSplit() &&
          Id->getCoroutine() != Id->getFunction())
        CoroIds.push_back(Id);
      continue;
    }

    auto *Suspend = dyn_cast<CoroSuspendInst>(&Inst);
    if (!Suspend || !Suspend->hasOneUse())
      continue;
    if (auto *Switch = dyn_cast<SwitchInst>(Suspend->use_begin()->getUser()))
      if (Switch->getNumCases() == 2)
        CoroSuspendSwitches.insert(Switch);
  }
}
/// Handle one post-split \p CoroId:
/// - Replace its resume-address requests with the resume function.
/// - Replace its destroy-address requests with the cleanup function (when
///   elision is approved by shouldElide) or the destroy function (otherwise).
/// - If elision is approved and the frame layout is known, move the frame onto
///   the caller's stack.
/// \returns true iff the IR was modified.
bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
                            DominatorTree &DT) {
  CoroBegins.clear();
  CoroAllocs.clear();
  ResumeAddr.clear();
  DestroyAddr.clear();

  // Collect the coro.begin and coro.alloc users of this coro.id.
  for (User *U : CoroId->users()) {
    if (auto *CB = dyn_cast<CoroBeginInst>(U))
      CoroBegins.push_back(CB);
    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
      CoroAllocs.push_back(CA);
  }

  // Collect the coro.subfn.addr requests made on each coro.begin.
  for (CoroBeginInst *CB : CoroBegins) {
    for (User *U : CB->users())
      if (auto *II = dyn_cast<CoroSubFnInst>(U))
        switch (II->getIndex()) {
        case CoroSubFnInst::ResumeIndex:
          ResumeAddr.push_back(II);
          break;
        case CoroSubFnInst::DestroyIndex:
          DestroyAddr[CB].push_back(II);
          break;
        default:
          llvm_unreachable("unexpected coro.subfn.addr constant");
        }
  }

  // Every request collected above is rewritten below. Without any, the IR only
  // changes if the frame gets elided. Reporting "changed" unconditionally
  // would needlessly invalidate all analyses.
  bool Changed = !ResumeAddr.empty() || !DestroyAddr.empty();

  ConstantArray *Resumers = CoroId->getInfo().Resumers;
  assert(Resumers && "PostSplit coro.id Info argument must refer to an array "
                     "of coroutine subfunctions");
  auto *ResumeAddrConstant =
      Resumers->getAggregateElement(CoroSubFnInst::ResumeIndex);

  replaceWithConstant(ResumeAddrConstant, ResumeAddr);

  bool ShouldElide = shouldElide(CoroId->getFunction(), DT);

  auto *DestroyAddrConstant = Resumers->getAggregateElement(
      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);

  for (auto &It : DestroyAddr)
    replaceWithConstant(DestroyAddrConstant, It.second);

  if (ShouldElide) {
    if (auto FrameSizeAndAlign =
            getFrameLayout(cast<Function>(ResumeAddrConstant))) {
      elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first,
                           FrameSizeAndAlign->second, AA);
      coro::replaceCoroFree(CoroId, /*Elide=*/true);
      NumOfCoroElided++;
      Changed = true;
#ifndef NDEBUG
      if (!CoroElideInfoOutputFilename.empty())
        *getOrCreateLogFile()
            << "Elide " << CoroId->getCoroutine()->getName() << " in "
            << CoroId->getFunction()->getName() << "\n";
#endif
    }
  }

  return Changed;
}
/// Return true if \p M declares `llvm.coro.id` or `llvm.coro.id.async`.
/// Without either declaration the pass has nothing to look for in \p M.
static bool declaresCoroElideIntrinsics(Module &M) {
  const bool HasCoroIdDecl =
      coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.id.async"});
  return HasCoroIdDecl;
}
/// Elide heap allocations of post-split coroutines called from \p F.
/// Analyses are preserved unless some coro.id was actually rewritten.
PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &M = *F.getParent();
  // Cheap module-level bail-out: no coro.id declarations, nothing to do.
  if (!declaresCoroElideIntrinsics(M))
    return PreservedAnalyses::all();

  Lowerer L(M);
  // collectPostSplitCoroIds resets CoroIds itself, so no separate clear() is
  // needed on the freshly constructed Lowerer.
  L.collectPostSplitCoroIds(&F);
  // F calls no post-split coroutines, so skip requesting AA and DomTree.
  if (L.CoroIds.empty())
    return PreservedAnalyses::all();

  AAResults &AA = AM.getResult<AAManager>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);

  bool Changed = false;
  for (auto *CII : L.CoroIds)
    Changed |= L.processCoroId(CII, AA, DT);

  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}