#include "WebAssemblyExceptionInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "Utils/WebAssemblyUtilities.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/CodeGen/MachineDominanceFrontier.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/WasmEHFuncInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/Target/TargetMachine.h"
using namespace llvm;
#define DEBUG_TYPE "wasm-exception-info"
// Storage for the pass identifier of WebAssemblyExceptionInfo.
char WebAssemblyExceptionInfo::ID = 0;
// Register the pass under the command-line name DEBUG_TYPE
// ("wasm-exception-info"), declaring the analyses it depends on. These match
// the getAnalysis<> calls made in runOnMachineFunction.
// NOTE(review): the trailing `true, true` arguments are presumably the
// CFG-only / is-analysis flags of INITIALIZE_PASS_BEGIN/END; confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(WebAssemblyExceptionInfo, DEBUG_TYPE,
"WebAssembly Exception Information", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier)
INITIALIZE_PASS_END(WebAssemblyExceptionInfo, DEBUG_TYPE,
"WebAssembly Exception Information", true, true)
/// Pass entry point: recomputes exception info for \p MF.
///
/// Any state from a previous run is released first. The info is rebuilt only
/// when the target's exception-handling model is Wasm and the function has a
/// personality function; otherwise it is left empty. This pass never modifies
/// \p MF, so it always returns false.
bool WebAssemblyExceptionInfo::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "********** Exception Info Calculation **********\n"
                       "********** Function: "
                    << MF.getName() << '\n');
  releaseMemory();

  const bool UsesWasmEH =
      MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() ==
      ExceptionHandling::Wasm;
  if (UsesWasmEH && MF.getFunction().hasPersonalityFn()) {
    auto &DomTree = getAnalysis<MachineDominatorTree>();
    auto &DomFrontier = getAnalysis<MachineDominanceFrontier>();
    recalculate(MF, DomTree, DomFrontier);
    LLVM_DEBUG(dump());
  }
  return false;
}
/// Returns true if \p Dst can be reached from \p Src by following CFG
/// successor edges while only stepping into blocks dominated by \p Header.
///
/// \p Src itself is always explored; only the successors pushed onto the
/// worklist are filtered by dominance. \p Dst must be dominated by \p Header
/// (asserted).
static bool isReachableAmongDominated(const MachineBasicBlock *Src,
                                      const MachineBasicBlock *Dst,
                                      const MachineBasicBlock *Header,
                                      const MachineDominatorTree &MDT) {
  assert(MDT.dominates(Header, Dst));
  SmallVector<const MachineBasicBlock *, 8> WL;
  SmallPtrSet<const MachineBasicBlock *, 8> Visited;
  WL.push_back(Src);

  while (!WL.empty()) {
    const auto *MBB = WL.pop_back_val();
    if (MBB == Dst)
      return true;
    // A block can be pushed several times before its first expansion (for
    // example when several expanded blocks share it as a successor). Expand
    // each block only once instead of re-walking its successors every time
    // it is popped.
    if (!Visited.insert(MBB).second)
      continue;
    for (auto *Succ : MBB->successors())
      if (!Visited.count(Succ) && MDT.dominates(Header, Succ))
        WL.push_back(Succ);
  }
  return false;
}
/// Computes the exception tree for \p MF from scratch.
///
/// Exceptions are first grouped purely by dominance: an exception initially
/// holds every block its EH pad dominates. That grouping is wrong when an EH
/// pad's unwind destination is itself dominated by the pad. In that case the
/// destination's exception, and any blocks reachable from it, end up nested
/// inside the exception they are supposed to unwind out of. The middle phases
/// below move such exceptions and blocks up to the source exception's parent.
void WebAssemblyExceptionInfo::recalculate(
    MachineFunction &MF, MachineDominatorTree &MDT,
    const MachineDominanceFrontier &MDF) {
  // Phase 1: create one exception per EH pad. Post-order over the dominator
  // tree visits dominated (inner) pads before the pads dominating them, so
  // discoverAndMapException can adopt the already-built inner exceptions.
  SmallVector<std::unique_ptr<WebAssemblyException>, 8> Exceptions;
  for (auto *DomNode : post_order(&MDT)) {
    MachineBasicBlock *EHPad = DomNode->getBlock();
    if (!EHPad->isEHPad())
      continue;
    auto WE = std::make_unique<WebAssemblyException>(EHPad);
    discoverAndMapException(WE.get(), MDT, MDF);
    Exceptions.push_back(std::move(WE));
  }

  // Phase 2a: for every EH pad whose unwind destination's exception is
  // contained in the pad's own exception, re-parent the destination's
  // exception to the source exception's parent. Each such (source,
  // destination) pair is recorded for the fixes below.
  //
  // Pre-order (depth_first) matters. Take exceptions A > B > C where A unwinds
  // to B and B unwinds to C. Visiting A first lifts B out of A, so when B is
  // visited, C is re-parented to B's corrected parent instead of staying
  // inside A.
  const auto *EHInfo = MF.getWasmEHFuncInfo();
  SmallVector<std::pair<WebAssemblyException *, WebAssemblyException *>>
      UnwindWEVec;
  for (auto *DomNode : depth_first(&MDT)) {
    MachineBasicBlock *EHPad = DomNode->getBlock();
    if (!EHPad->isEHPad())
      continue;
    if (!EHInfo->hasUnwindDest(EHPad))
      continue;
    auto *UnwindDest = EHInfo->getUnwindDest(EHPad);
    auto *SrcWE = getExceptionFor(EHPad);
    auto *DstWE = getExceptionFor(UnwindDest);
    if (SrcWE->contains(DstWE)) {
      UnwindWEVec.push_back(std::make_pair(SrcWE, DstWE));
      LLVM_DEBUG(dbgs() << "Unwind destination ExceptionInfo fix:\n "
                        << DstWE->getEHPad()->getNumber() << "."
                        << DstWE->getEHPad()->getName()
                        << "'s exception is taken out of "
                        << SrcWE->getEHPad()->getNumber() << "."
                        << SrcWE->getEHPad()->getName() << "'s exception\n");
      DstWE->setParentException(SrcWE->getParentException());
    }
  }

  // Phase 2b: some other exceptions sit inside a source exception but not
  // inside its destination exception, yet are reachable from the
  // destination's EH pad within the region the source pad dominates. They do
  // not belong to the source exception either, so re-parent them the same
  // way. Only parent pointers change here; block sets are not built yet.
  for (auto *DomNode : depth_first(&MDT)) {
    MachineBasicBlock *EHPad = DomNode->getBlock();
    if (!EHPad->isEHPad())
      continue;
    auto *WE = getExceptionFor(EHPad);
    for (auto &P : UnwindWEVec) {
      auto *SrcWE = P.first;
      auto *DstWE = P.second;
      if (WE != SrcWE && SrcWE->contains(WE) && !DstWE->contains(WE) &&
          isReachableAmongDominated(DstWE->getEHPad(), EHPad, SrcWE->getEHPad(),
                                    MDT)) {
        LLVM_DEBUG(dbgs() << "Remaining reachable ExceptionInfo fix:\n "
                          << WE->getEHPad()->getNumber() << "."
                          << WE->getEHPad()->getName()
                          << "'s exception is taken out of "
                          << SrcWE->getEHPad()->getNumber() << "."
                          << SrcWE->getEHPad()->getName() << "'s exception\n");
        WE->setParentException(SrcWE->getParentException());
      }
    }
  }

  // Build each exception's block set. Every block is added to its innermost
  // exception and to all of that exception's ancestors.
  for (auto *DomNode : post_order(&MDT)) {
    MachineBasicBlock *MBB = DomNode->getBlock();
    WebAssemblyException *WE = getExceptionFor(MBB);
    for (; WE; WE = WE->getParentException())
      WE->addToBlocksSet(MBB);
  }

  // Phase 3: handle the remaining non-EH-pad blocks of a source exception
  // that are reachable from its destination's EH pad. Each such block is
  // removed from the source exception and from every exception between the
  // block's innermost exception and the source, then remapped to the
  // source's parent.
  for (auto &P : UnwindWEVec) {
    auto *SrcWE = P.first;
    auto *DstWE = P.second;
    // SrcWE's block set is being range-iterated below, so its erasures are
    // deferred until the walk finishes. Erasing from a set while iterating it
    // is unsafe. For example, SmallPtrSet (presumably the type behind
    // getBlocksSet(); confirm in the header) may move another element into
    // the erased slot, and that element would then never be visited.
    SmallVector<MachineBasicBlock *, 8> ToRemove;
    for (auto *MBB : SrcWE->getBlocksSet()) {
      if (MBB->isEHPad()) {
        assert(!isReachableAmongDominated(DstWE->getEHPad(), MBB,
                                          SrcWE->getEHPad(), MDT) &&
               "We already handled EH pads above");
        continue;
      }
      if (isReachableAmongDominated(DstWE->getEHPad(), MBB, SrcWE->getEHPad(),
                                    MDT)) {
        LLVM_DEBUG(dbgs() << "Remainder BB: " << MBB->getNumber() << "."
                          << MBB->getName() << " is\n");
        // Inner exceptions are not being iterated, so they can be updated
        // immediately.
        WebAssemblyException *InnerWE = getExceptionFor(MBB);
        while (InnerWE != SrcWE) {
          LLVM_DEBUG(dbgs()
                     << " removed from " << InnerWE->getEHPad()->getNumber()
                     << "." << InnerWE->getEHPad()->getName()
                     << "'s exception\n");
          InnerWE->removeFromBlocksSet(MBB);
          InnerWE = InnerWE->getParentException();
        }
        ToRemove.push_back(MBB);
        LLVM_DEBUG(dbgs() << " removed from " << SrcWE->getEHPad()->getNumber()
                          << "." << SrcWE->getEHPad()->getName()
                          << "'s exception\n");
        changeExceptionFor(MBB, SrcWE->getParentException());
        if (SrcWE->getParentException())
          SrcWE->getParentException()->addToBlocksSet(MBB);
      }
    }
    for (auto *MBB : ToRemove)
      SrcWE->removeFromBlocksSet(MBB);
  }

  // Phase 4a: build each exception's ordered block vector, filled in
  // dominator-tree post-order.
  for (auto *DomNode : post_order(&MDT)) {
    MachineBasicBlock *MBB = DomNode->getBlock();
    WebAssemblyException *WE = getExceptionFor(MBB);
    for (; WE; WE = WE->getParentException())
      WE->addToBlocksVector(MBB);
  }

  // Phase 4b: transfer ownership of each exception to its parent, or to the
  // top-level list if it has none. Raw pointers are captured first because
  // the unique_ptrs are moved out.
  SmallVector<WebAssemblyException *, 8> ExceptionPointers;
  ExceptionPointers.reserve(Exceptions.size());
  for (auto &WE : Exceptions) {
    ExceptionPointers.push_back(WE.get());
    if (WE->getParentException())
      WE->getParentException()->getSubExceptions().push_back(std::move(WE));
    else
      addTopLevelException(std::move(WE));
  }

  // Phase 4c: blocks and sub-exceptions were appended in post-order. Reverse
  // them (reverseBlock() presumably reverses the block vector).
  for (auto *WE : ExceptionPointers) {
    WE->reverseBlock();
    std::reverse(WE->getSubExceptions().begin(), WE->getSubExceptions().end());
  }
}
/// Discards all results of a previous run: the top-level exception list and
/// the block-to-exception map.
void WebAssemblyExceptionInfo::releaseMemory() {
  TopLevelExceptions.clear();
  BBMap.clear();
}
/// Declares that this pass preserves everything and needs the machine
/// dominator tree and dominance frontier (both consumed by recalculate()).
void WebAssemblyExceptionInfo::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineDominatorTree>()
      .addRequired<MachineDominanceFrontier>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
/// Grows \p WE outward from its EH pad.
///
/// The walk starts at the pad and follows CFG successors that the pad
/// dominates:
///  - Blocks not yet mapped to any exception are mapped to \p WE.
///  - Blocks already covered by another exception tree make that tree's root
///    a sub-exception of \p WE. The walk then continues from the root pad's
///    dominance frontier, again limited to blocks the pad dominates.
///
/// Finally, space is reserved for \p WE's sub-exceptions and blocks.
void WebAssemblyExceptionInfo::discoverAndMapException(
    WebAssemblyException *WE, const MachineDominatorTree &MDT,
    const MachineDominanceFrontier &MDF) {
  MachineBasicBlock *const EHPad = WE->getEHPad();
  unsigned BlockEstimate = 0;
  unsigned SubExceptionCount = 0;

  SmallVector<MachineBasicBlock *, 8> Worklist;
  Worklist.push_back(EHPad);
  while (!Worklist.empty()) {
    MachineBasicBlock *Cur = Worklist.pop_back_val();
    WebAssemblyException *Outer = getOutermostException(Cur);

    if (!Outer) {
      // Unclaimed block: it belongs directly to WE.
      changeExceptionFor(Cur, WE);
      ++BlockEstimate;
      for (auto *Succ : Cur->successors())
        if (MDT.dominates(EHPad, Succ))
          Worklist.push_back(Succ);
      continue;
    }

    // Already part of WE, either directly or via an adopted sub-exception.
    if (Outer == WE)
      continue;

    // Part of another exception tree: adopt its root and resume from the edge
    // of that tree. The root's block-vector capacity is presumably what its
    // own discovery reserved, and it is used as a size hint.
    Outer->setParentException(WE);
    ++SubExceptionCount;
    BlockEstimate += Outer->getBlocksVector().capacity();
    for (auto &Frontier : MDF.find(Outer->getEHPad())->second)
      if (MDT.dominates(EHPad, Frontier))
        Worklist.push_back(Frontier);
  }

  WE->getSubExceptions().reserve(SubExceptionCount);
  WE->reserveBlocks(BlockEstimate);
}
/// Returns the root of the exception tree containing \p MBB, found by walking
/// parent links from \p MBB's innermost exception. Returns null if \p MBB is
/// not mapped to any exception.
WebAssemblyException *
WebAssemblyExceptionInfo::getOutermostException(MachineBasicBlock *MBB) const {
  WebAssemblyException *Root = getExceptionFor(MBB);
  if (!Root)
    return nullptr;
  while (WebAssemblyException *Up = Root->getParentException())
    Root = Up;
  return Root;
}
/// Prints this exception on one line: its depth, then its blocks as
/// "%bb.<num>[.<IR name>]" separated by ", ", with the EH pad tagged
/// " (landing-pad)". Each sub-exception is then printed recursively with
/// \p Depth + 2, which yields a deeper indent.
void WebAssemblyException::print(raw_ostream &OS, unsigned Depth) const {
  OS.indent(Depth * 2) << "Exception at depth " << getExceptionDepth()
                       << " containing: ";
  bool NeedSeparator = false;
  for (MachineBasicBlock *MBB : getBlocks()) {
    if (NeedSeparator)
      OS << ", ";
    NeedSeparator = true;
    OS << "%bb." << MBB->getNumber();
    const auto *IRBlock = MBB->getBasicBlock();
    if (IRBlock && IRBlock->hasName())
      OS << "." << IRBlock->getName();
    if (MBB == getEHPad())
      OS << " (landing-pad)";
  }
  OS << "\n";
  for (const auto &Sub : SubExceptions)
    Sub->print(OS, Depth + 2);
}
// dump() is compiled only when NDEBUG is not defined or LLVM_ENABLE_DUMP is
// defined. It prints this exception (and its sub-exceptions) to dbgs().
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void WebAssemblyException::dump() const { print(dbgs()); }
#endif
/// Stream insertion for WebAssemblyException. Prints \p WE via
/// WebAssemblyException::print with its default depth and returns \p OS.
///
/// The definition is qualified with `llvm::` on purpose. This file uses
/// `using namespace llvm`, so an unqualified definition at global scope
/// would introduce a separate ::operator<< instead of defining the operator
/// declared in namespace llvm.
/// NOTE(review): assumes WebAssemblyExceptionInfo.h declares this operator
/// inside namespace llvm (not visible here); confirm.
raw_ostream &llvm::operator<<(raw_ostream &OS,
                              const WebAssemblyException &WE) {
  WE.print(OS);
  return OS;
}
/// Prints every top-level exception tree in list order. The Module argument
/// is unused.
void WebAssemblyExceptionInfo::print(raw_ostream &OS, const Module *) const {
  for (const auto &TopLevel : TopLevelExceptions)
    TopLevel->print(OS);
}