#include "llvm/CodeGen/SwiftErrorValueTracking.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Value.h"
using namespace llvm;
// Return the vreg that currently represents swifterror value Val in MBB.
//
// If MBB has no entry yet, Val is being read in MBB before any definition in
// MBB. In that case a fresh pointer-class vreg is created and recorded both as
// MBB's current definition (VRegDefMap) and as an upwards-exposed use
// (VRegUpwardsUse). propagateVRegs() later reads VRegUpwardsUse and emits a
// COPY or PHI defining such vregs for the blocks it visits.
Register SwiftErrorValueTracking::getOrCreateVReg(const MachineBasicBlock *MBB,
                                                  const Value *Val) {
  const auto Key = std::make_pair(MBB, Val);

  // Fast path: MBB already has a current vreg for Val.
  auto Found = VRegDefMap.find(Key);
  if (Found != VRegDefMap.end())
    return Found->second;

  const auto &Layout = MF->getDataLayout();
  const TargetRegisterClass *PtrRC =
      TLI->getRegClassFor(TLI->getPointerTy(Layout));
  Register Fresh = MF->getRegInfo().createVirtualRegister(PtrRC);
  VRegDefMap[Key] = Fresh;
  VRegUpwardsUse[Key] = Fresh;
  return Fresh;
}
// Record VReg as the current (downward-exposed) definition of swifterror
// value Val in MBB, overwriting any previous entry for that pair.
void SwiftErrorValueTracking::setCurrentVReg(const MachineBasicBlock *MBB,
                                             const Value *Val, Register VReg) {
  const auto Key = std::make_pair(MBB, Val);
  VRegDefMap[Key] = VReg;
}
// Return the vreg that instruction I defines for swifterror value Val.
// The result is cached per instruction. On first request a new pointer-class
// vreg is created, cached, and installed as Val's current definition in MBB.
Register SwiftErrorValueTracking::getOrCreateVRegDefAt(
    const Instruction *I, const MachineBasicBlock *MBB, const Value *Val) {
  // The bool bit distinguishes I's def (true) from its use (false).
  const PointerIntPair<const Instruction *, 1, bool> DefKey(I, true);
  auto Cached = VRegDefUses.find(DefKey);
  if (Cached == VRegDefUses.end()) {
    const auto &Layout = MF->getDataLayout();
    const TargetRegisterClass *PtrRC =
        TLI->getRegClassFor(TLI->getPointerTy(Layout));
    const Register NewVReg = MF->getRegInfo().createVirtualRegister(PtrRC);
    VRegDefUses[DefKey] = NewVReg;
    setCurrentVReg(MBB, Val, NewVReg);
    return NewVReg;
  }
  return Cached->second;
}
// Return the vreg that instruction I reads for swifterror value Val.
// The result is cached per instruction. On first request it is whatever vreg
// currently represents Val in MBB, obtained via getOrCreateVReg().
Register SwiftErrorValueTracking::getOrCreateVRegUseAt(
    const Instruction *I, const MachineBasicBlock *MBB, const Value *Val) {
  // The bool bit distinguishes I's use (false) from its def (true).
  const PointerIntPair<const Instruction *, 1, bool> UseKey(I, false);
  auto Cached = VRegDefUses.find(UseKey);
  if (Cached != VRegDefUses.end())
    return Cached->second;

  const Register Current = getOrCreateVReg(MBB, Val);
  VRegDefUses[UseKey] = Current;
  return Current;
}
// Bind the tracker to mf and cache its function, TargetLowering and
// TargetInstrInfo.
//
// If the target supports swifterror, also reset all per-function state and
// collect the function's swifterror values. The swifterror parameter (at most
// one) comes first, followed by every swifterror alloca in block/instruction
// order. When the target does not support swifterror, the remaining state is
// left untouched.
void SwiftErrorValueTracking::setFunction(MachineFunction &mf) {
  MF = &mf;
  Fn = &MF->getFunction();
  const auto &Subtarget = MF->getSubtarget();
  TLI = Subtarget.getTargetLowering();
  TII = Subtarget.getInstrInfo();

  if (!TLI->supportSwiftError())
    return;

  SwiftErrorVals.clear();
  VRegDefMap.clear();
  VRegUpwardsUse.clear();
  VRegDefUses.clear();
  SwiftErrorArg = nullptr;

  // SwiftErrorArg was just reset, so it doubles as the "already seen one"
  // flag for the uniqueness assertion.
  for (const auto &Arg : Fn->args()) {
    if (!Arg.hasSwiftErrorAttr())
      continue;
    assert(!SwiftErrorArg && "Must have only one swifterror parameter");
    SwiftErrorArg = &Arg;
    SwiftErrorVals.push_back(&Arg);
  }

  for (const auto &BB : *Fn)
    for (const auto &I : BB)
      if (const auto *AI = dyn_cast<AllocaInst>(&I))
        if (AI->isSwiftError())
          SwiftErrorVals.push_back(AI);
}
// For every swifterror value except the swifterror parameter, emit an
// IMPLICIT_DEF of a fresh pointer-class vreg at the start of the entry block
// (after any PHIs). That vreg becomes the value's current definition there.
// The swifterror parameter is skipped; its vreg is presumably established by
// argument lowering elsewhere (not visible here).
// Returns true iff at least one IMPLICIT_DEF was emitted.
bool SwiftErrorValueTracking::createEntriesInEntryBlock(DebugLoc DbgLoc) {
  if (!TLI->supportSwiftError() || SwiftErrorVals.empty())
    return false;

  MachineBasicBlock *EntryMBB = &*MF->begin();
  auto &MRI = MF->getRegInfo();
  const auto &Layout = MF->getDataLayout();
  const auto *PtrRC = TLI->getRegClassFor(TLI->getPointerTy(Layout));

  bool EmittedAny = false;
  for (const auto *Val : SwiftErrorVals) {
    if (SwiftErrorArg && Val == SwiftErrorArg)
      continue;

    // Build the MI directly (no higher-level helper).
    const Register UndefVReg = MRI.createVirtualRegister(PtrRC);
    BuildMI(*EntryMBB, EntryMBB->getFirstNonPHI(), DbgLoc,
            TII->get(TargetOpcode::IMPLICIT_DEF), UndefVReg);
    setCurrentVReg(EntryMBB, Val, UndefVReg);
    EmittedAny = true;
  }
  return EmittedAny;
}
void SwiftErrorValueTracking::propagateVRegs() {
if (!TLI->supportSwiftError())
return;
if (SwiftErrorVals.empty())
return;
ReversePostOrderTraversal<MachineFunction *> RPOT(MF);
for (MachineBasicBlock *MBB : RPOT) {
for (const auto *SwiftErrorVal : SwiftErrorVals) {
auto Key = std::make_pair(MBB, SwiftErrorVal);
auto UUseIt = VRegUpwardsUse.find(Key);
auto VRegDefIt = VRegDefMap.find(Key);
bool UpwardsUse = UUseIt != VRegUpwardsUse.end();
Register UUseVReg = UpwardsUse ? UUseIt->second : Register();
bool DownwardDef = VRegDefIt != VRegDefMap.end();
assert(!(UpwardsUse && !DownwardDef) &&
"We can't have an upwards use but no downwards def");
if (!UpwardsUse && DownwardDef)
continue;
SmallVector<std::pair<MachineBasicBlock *, Register>, 4> VRegs;
SmallSet<const MachineBasicBlock *, 8> Visited;
for (auto *Pred : MBB->predecessors()) {
if (!Visited.insert(Pred).second)
continue;
VRegs.push_back(std::make_pair(
Pred, getOrCreateVReg(Pred, SwiftErrorVal)));
if (Pred != MBB)
continue;
if (!UpwardsUse) {
UpwardsUse = true;
UUseIt = VRegUpwardsUse.find(Key);
assert(UUseIt != VRegUpwardsUse.end());
UUseVReg = UUseIt->second;
}
}
bool needPHI =
VRegs.size() >= 1 &&
llvm::find_if(
VRegs,
[&](const std::pair<const MachineBasicBlock *, Register> &V)
-> bool { return V.second != VRegs[0].second; }) !=
VRegs.end();
if (!UpwardsUse && !needPHI) {
assert(!VRegs.empty() &&
"No predecessors? The entry block should bail out earlier");
setCurrentVReg(MBB, SwiftErrorVal, VRegs[0].second);
continue;
}
auto DLoc = isa<Instruction>(SwiftErrorVal)
? cast<Instruction>(SwiftErrorVal)->getDebugLoc()
: DebugLoc();
const auto *TII = MF->getSubtarget().getInstrInfo();
if (!needPHI) {
assert(UpwardsUse);
assert(!VRegs.empty() &&
"No predecessors? Is the Calling Convention correct?");
Register DestReg = UUseVReg;
BuildMI(*MBB, MBB->getFirstNonPHI(), DLoc, TII->get(TargetOpcode::COPY),
DestReg)
.addReg(VRegs[0].second);
continue;
}
auto &DL = MF->getDataLayout();
auto const *RC = TLI->getRegClassFor(TLI->getPointerTy(DL));
Register PHIVReg =
UpwardsUse ? UUseVReg : MF->getRegInfo().createVirtualRegister(RC);
MachineInstrBuilder PHI =
BuildMI(*MBB, MBB->getFirstNonPHI(), DLoc,
TII->get(TargetOpcode::PHI), PHIVReg);
for (auto BBRegPair : VRegs) {
PHI.addReg(BBRegPair.second).addMBB(BBRegPair.first);
}
if (!UpwardsUse)
setCurrentVReg(MBB, SwiftErrorVal, PHIVReg);
}
}
}
// Pre-create the swifterror vregs used and defined by the IR instructions in
// [Begin, End), which are being lowered into MBB:
//   * a call with a swifterror argument uses it and then defines it;
//   * a load whose pointer operand is a swifterror value is a use;
//   * a store whose pointer operand is a swifterror value is a def;
//   * a return from a function with the SwiftError attribute anywhere uses
//     the swifterror parameter.
void SwiftErrorValueTracking::preassignVRegs(
    MachineBasicBlock *MBB, BasicBlock::const_iterator Begin,
    BasicBlock::const_iterator End) {
  if (!TLI->supportSwiftError() || SwiftErrorVals.empty())
    return;

  for (auto Cur = Begin; Cur != End; ++Cur) {
    const auto *Inst = &*Cur;

    if (const auto *Call = dyn_cast<CallBase>(Inst)) {
      // Register the use first, then the call's def of the same value.
      const Value *ErrorArg = nullptr;
      for (const auto &Operand : Call->args()) {
        if (!Operand->isSwiftError())
          continue;
        assert(!ErrorArg && "Cannot have multiple swifterror arguments");
        ErrorArg = Operand.get();
        assert(ErrorArg->isSwiftError() &&
               "Must have a swifterror value argument");
        getOrCreateVRegUseAt(Inst, MBB, ErrorArg);
      }
      if (ErrorArg)
        getOrCreateVRegDefAt(Inst, MBB, ErrorArg);
      continue;
    }

    if (const auto *Load = dyn_cast<LoadInst>(Inst)) {
      const Value *Ptr = Load->getPointerOperand();
      if (Ptr->isSwiftError())
        getOrCreateVRegUseAt(Load, MBB, Ptr);
      continue;
    }

    if (const auto *Store = dyn_cast<StoreInst>(Inst)) {
      const Value *Ptr = Store->getPointerOperand();
      if (Ptr->isSwiftError())
        getOrCreateVRegDefAt(Store, MBB, Ptr);
      continue;
    }

    if (const auto *Ret = dyn_cast<ReturnInst>(Inst)) {
      const Function *F = Ret->getFunction();
      if (F->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
        getOrCreateVRegUseAt(Ret, MBB, SwiftErrorArg);
    }
  }
}