#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/Statepoint.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
using namespace llvm;
// Identifier of this pass; it is also handed to INITIALIZE_PASS_BEGIN/END
// further down in this file.
#define DEBUG_TYPE "fixup-statepoint-caller-saved"
// Bumped once per register handled in StatepointState::spillRegisters().
STATISTIC(NumSpilledRegisters, "Number of spilled register");
// Bumped each time FrameIndexesCache::getFrameIndex() creates a new spill
// stack object.
STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
// Bumped each time FrameIndexesCache::getFrameIndex() enlarges an existing
// slot that is smaller than the register being spilled.
STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
// When enabled, FrameIndexesCache keeps every spill slot in a single bucket
// (key 0) regardless of size, enlarges a reused slot when it is too small for
// the register being spilled, and sorts the registers to spill by decreasing
// size. When disabled (the default), slots are bucketed by register size.
static cl::opt<bool> FixupSCSExtendSlotSize(
    "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
    cl::desc("Allow spill in spill slot of greater size than register size"));
// Initial value of the AllowGCPtrInCSR flag that runOnMachineFunction() passes
// to StatepointProcessor::process(). Off by default.
static cl::opt<bool> PassGCPtrInCSR(
"fixup-allow-gcptr-in-csr", cl::Hidden, cl::init(false),
cl::desc("Allow passing GC Pointer arguments in callee saved registers"));
// Gates the backward search for a defining copy in performCopyPropagation().
// On by default.
static cl::opt<bool> EnableCopyProp(
"fixup-scs-enable-copy-propagation", cl::Hidden, cl::init(true),
cl::desc("Enable simple copy propagation during register reloading"));
// Only consulted when given explicitly on the command line: once the running
// count of processed statepoints in a function reaches this value,
// runOnMachineFunction() switches AllowGCPtrInCSR off for that statepoint and
// all later ones.
static cl::opt<unsigned> MaxStatepointsWithRegs(
"fixup-max-csr-statepoints", cl::Hidden,
cl::desc("Max number of statepoints allowed to pass GC Ptrs in registers"));
namespace {
class FixupStatepointCallerSaved : public MachineFunctionPass {
public:
static char ID;
FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
MachineFunctionPass::getAnalysisUsage(AU);
}
StringRef getPassName() const override {
return "Fixup Statepoint Caller Saved";
}
bool runOnMachineFunction(MachineFunction &MF) override;
};
}
// Storage for the pass ID; the constructor hands it to MachineFunctionPass.
char FixupStatepointCallerSaved::ID = 0;
// Exposes the pass ID under the llvm namespace as a reference to the same
// storage.
char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;
// Pass registration. Nothing is listed between BEGIN and END, so no dependent
// passes are declared here.
INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
"Fixup Statepoint Caller Saved", false, false)
INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
"Fixup Statepoint Caller Saved", false, false)
static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
return TRI.getSpillSize(*RC);
}
/// Tries to spill the source of a copy instead of the copied-to register.
///
/// \param Reg    register about to be spilled.
/// \param RI     on entry, the instruction the spill would be inserted before
///               (the caller passes the statepoint); on successful
///               propagation it is moved to just after the defining copy.
/// \param IsKill updated to tell the caller whether the spill may kill the
///               register it stores; left untouched when no decision is made.
/// \returns the copy's source register when propagation succeeds, otherwise
///          \p Reg unchanged.
static Register performCopyPropagation(Register Reg,
                                       MachineBasicBlock::iterator &RI,
                                       bool &IsKill, const TargetInstrInfo &TII,
                                       const TargetRegisterInfo &TRI) {
  // If the instruction at RI itself uses Reg in an operand located before the
  // index returned by getNumDeoptArgsIdx(), the spill must not kill Reg and no
  // propagation is attempted.
  const int UseIdx = RI->findRegisterUseOperandIdx(Reg, false, &TRI);
  if (UseIdx >= 0) {
    if ((unsigned)UseIdx < StatepointOpers(&*RI).getNumDeoptArgsIdx()) {
      IsKill = false;
      return Reg;
    }
  }

  if (!EnableCopyProp)
    return Reg;

  // Walk backwards from the instruction preceding RI, remembering the first
  // reader of Reg encountered and stopping at the first instruction that
  // modifies Reg.
  MachineBasicBlock *Parent = RI->getParent();
  MachineInstr *CopyDef = nullptr;
  MachineInstr *InterveningUse = nullptr;
  MachineBasicBlock::reverse_iterator Walker = RI.getReverse();
  MachineBasicBlock::reverse_iterator Stop = Parent->rend();
  for (++Walker; Walker != Stop; ++Walker) {
    MachineInstr &Cur = *Walker;
    if (!InterveningUse && Cur.readsRegister(Reg, &TRI))
      InterveningUse = &Cur;
    if (!Cur.modifiesRegister(Reg, &TRI))
      continue;
    CopyDef = &Cur;
    break;
  }

  if (!CopyDef)
    return Reg;

  // Only a copy whose destination is exactly Reg qualifies.
  auto CopyInfo = TII.isCopyInstr(*CopyDef);
  if (!CopyInfo)
    return Reg;
  if (CopyInfo->Destination->getReg() != Reg)
    return Reg;

  // Source and destination must have the same spill size.
  Register CopySrc = CopyInfo->Source->getReg();
  if (getRegisterSize(TRI, Reg) != getRegisterSize(TRI, CopySrc))
    return Reg;

  LLVM_DEBUG(dbgs() << "spillRegisters: perform copy propagation "
                    << printReg(Reg, &TRI) << " -> " << printReg(CopySrc, &TRI)
                    << "\n");

  // The spill goes right after the copy. Both RI and IsKill are computed
  // before the copy is possibly erased below, because they are derived from
  // it.
  RI = std::next(MachineBasicBlock::iterator(CopyDef));
  IsKill = CopyInfo->Source->isKill();

  if (InterveningUse) {
    // The copy stays and the spill follows it, so the copy no longer carries
    // the kill of its source.
    if (IsKill)
      const_cast<MachineOperand *>(CopyInfo->Source)->setIsKill(false);
  } else {
    // Nothing reads Reg between the copy and RI's original position: drop the
    // copy.
    LLVM_DEBUG(dbgs() << "spillRegisters: removing dead copy " << *CopyDef);
    CopyDef->eraseFromParent();
  }

  return CopySrc;
}
namespace {
/// A register paired with the frame index of a stack slot.
using RegSlotPair = std::pair<Register, int>;

/// Remembers, per basic block, which (register, frame index) reloads have
/// already been recorded.
class RegReloadCache {
  using ReloadSet = SmallSet<RegSlotPair, 8>;
  DenseMap<const MachineBasicBlock *, ReloadSet> Reloads;

public:
  RegReloadCache() = default;

  /// Records a reload of \p Reg from \p FI in \p MBB. Asserts that the same
  /// reload has not been recorded for that block before.
  void recordReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
    bool Inserted = Reloads[MBB].insert(RegSlotPair(Reg, FI)).second;
    (void)Inserted;
    assert(Inserted && "reload already exists");
  }

  /// Returns true if a reload of \p Reg from \p FI was recorded for \p MBB.
  /// Does not create an entry for a block that has none.
  bool hasReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
    auto Found = Reloads.find(MBB);
    if (Found == Reloads.end())
      return false;
    return Found->second.count(RegSlotPair(Reg, FI)) != 0;
  }
};
class FrameIndexesCache {
private:
struct FrameIndexesPerSize {
SmallVector<int, 8> Slots;
unsigned Index = 0;
};
MachineFrameInfo &MFI;
const TargetRegisterInfo &TRI;
DenseMap<unsigned, FrameIndexesPerSize> Cache;
SmallSet<int, 8> ReservedSlots;
DenseMap<const MachineBasicBlock *, SmallVector<RegSlotPair, 8>>
GlobalIndices;
FrameIndexesPerSize &getCacheBucket(unsigned Size) {
return Cache[FixupSCSExtendSlotSize ? 0 : Size];
}
public:
FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
: MFI(MFI), TRI(TRI) {}
void reset(const MachineBasicBlock *EHPad) {
for (auto &It : Cache)
It.second.Index = 0;
ReservedSlots.clear();
if (EHPad && GlobalIndices.count(EHPad))
for (auto &RSP : GlobalIndices[EHPad])
ReservedSlots.insert(RSP.second);
}
int getFrameIndex(Register Reg, MachineBasicBlock *EHPad) {
auto It = GlobalIndices.find(EHPad);
if (It != GlobalIndices.end()) {
auto &Vec = It->second;
auto Idx = llvm::find_if(
Vec, [Reg](RegSlotPair &RSP) { return Reg == RSP.first; });
if (Idx != Vec.end()) {
int FI = Idx->second;
LLVM_DEBUG(dbgs() << "Found global FI " << FI << " for register "
<< printReg(Reg, &TRI) << " at "
<< printMBBReference(*EHPad) << "\n");
assert(ReservedSlots.count(FI) && "using unreserved slot");
return FI;
}
}
unsigned Size = getRegisterSize(TRI, Reg);
FrameIndexesPerSize &Line = getCacheBucket(Size);
while (Line.Index < Line.Slots.size()) {
int FI = Line.Slots[Line.Index++];
if (ReservedSlots.count(FI))
continue;
if (MFI.getObjectSize(FI) < Size) {
MFI.setObjectSize(FI, Size);
MFI.setObjectAlignment(FI, Align(Size));
NumSpillSlotsExtended++;
}
return FI;
}
int FI = MFI.CreateSpillStackObject(Size, Align(Size));
NumSpillSlotsAllocated++;
Line.Slots.push_back(FI);
++Line.Index;
if (EHPad) {
GlobalIndices[EHPad].push_back(std::make_pair(Reg, FI));
LLVM_DEBUG(dbgs() << "Reserved FI " << FI << " for spilling reg "
<< printReg(Reg, &TRI) << " at landing pad "
<< printMBBReference(*EHPad) << "\n");
}
return FI;
}
void sortRegisters(SmallVectorImpl<Register> &Regs) {
if (!FixupSCSExtendSlotSize)
return;
llvm::sort(Regs, [&](Register &A, Register &B) {
return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
});
}
};
class StatepointState {
private:
MachineInstr &MI;
MachineFunction &MF;
MachineBasicBlock *EHPad;
const TargetRegisterInfo &TRI;
const TargetInstrInfo &TII;
MachineFrameInfo &MFI;
const uint32_t *Mask;
FrameIndexesCache &CacheFI;
bool AllowGCPtrInCSR;
SmallVector<unsigned, 8> OpsToSpill;
SmallVector<Register, 8> RegsToSpill;
SmallVector<Register, 8> RegsToReload;
DenseMap<Register, int> RegToSlotIdx;
public:
StatepointState(MachineInstr &MI, const uint32_t *Mask,
FrameIndexesCache &CacheFI, bool AllowGCPtrInCSR)
: MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
Mask(Mask), CacheFI(CacheFI), AllowGCPtrInCSR(AllowGCPtrInCSR) {
EHPad = nullptr;
MachineBasicBlock *MBB = MI.getParent();
bool Last = std::none_of(++MI.getIterator(), MBB->end().getInstrIterator(),
[](MachineInstr &I) {
return I.getOpcode() == TargetOpcode::STATEPOINT;
});
if (!Last)
return;
auto IsEHPad = [](MachineBasicBlock *B) { return B->isEHPad(); };
assert(llvm::count_if(MBB->successors(), IsEHPad) < 2 && "multiple EHPads");
auto It = llvm::find_if(MBB->successors(), IsEHPad);
if (It != MBB->succ_end())
EHPad = *It;
}
MachineBasicBlock *getEHPad() const { return EHPad; }
bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
bool findRegistersToSpill() {
SmallSet<Register, 8> GCRegs;
for (const auto &Def : MI.defs())
GCRegs.insert(Def.getReg());
SmallSet<Register, 8> VisitedRegs;
for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
EndIdx = MI.getNumOperands();
Idx < EndIdx; ++Idx) {
MachineOperand &MO = MI.getOperand(Idx);
if (!MO.isReg() || MO.isImplicit() || MO.isUndef())
continue;
Register Reg = MO.getReg();
assert(Reg.isPhysical() && "Only physical regs are expected");
if (isCalleeSaved(Reg) && (AllowGCPtrInCSR || !is_contained(GCRegs, Reg)))
continue;
LLVM_DEBUG(dbgs() << "Will spill " << printReg(Reg, &TRI) << " at index "
<< Idx << "\n");
if (VisitedRegs.insert(Reg).second)
RegsToSpill.push_back(Reg);
OpsToSpill.push_back(Idx);
}
CacheFI.sortRegisters(RegsToSpill);
return !RegsToSpill.empty();
}
void spillRegisters() {
for (Register Reg : RegsToSpill) {
int FI = CacheFI.getFrameIndex(Reg, EHPad);
const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
NumSpilledRegisters++;
RegToSlotIdx[Reg] = FI;
LLVM_DEBUG(dbgs() << "Spilling " << printReg(Reg, &TRI) << " to FI " << FI
<< "\n");
bool IsKill = true;
MachineBasicBlock::iterator InsertBefore(MI);
Reg = performCopyPropagation(Reg, InsertBefore, IsKill, TII, TRI);
LLVM_DEBUG(dbgs() << "Insert spill before " << *InsertBefore);
TII.storeRegToStackSlot(*MI.getParent(), InsertBefore, Reg, IsKill, FI,
RC, &TRI);
}
}
void insertReloadBefore(unsigned Reg, MachineBasicBlock::iterator It,
MachineBasicBlock *MBB) {
const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
int FI = RegToSlotIdx[Reg];
if (It != MBB->end()) {
TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
return;
}
assert(!MBB->empty() && "Empty block");
--It;
TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
MachineInstr *Reload = It->getPrevNode();
int Dummy = 0;
(void)Dummy;
assert(TII.isLoadFromStackSlot(*Reload, Dummy) == Reg);
assert(Dummy == FI);
MBB->remove(Reload);
MBB->insertAfter(It, Reload);
}
void insertReloads(MachineInstr *NewStatepoint, RegReloadCache &RC) {
MachineBasicBlock *MBB = NewStatepoint->getParent();
auto InsertPoint = std::next(NewStatepoint->getIterator());
for (auto Reg : RegsToReload) {
insertReloadBefore(Reg, InsertPoint, MBB);
LLVM_DEBUG(dbgs() << "Reloading " << printReg(Reg, &TRI) << " from FI "
<< RegToSlotIdx[Reg] << " after statepoint\n");
if (EHPad && !RC.hasReload(Reg, RegToSlotIdx[Reg], EHPad)) {
RC.recordReload(Reg, RegToSlotIdx[Reg], EHPad);
auto EHPadInsertPoint = EHPad->SkipPHIsLabelsAndDebug(EHPad->begin());
insertReloadBefore(Reg, EHPadInsertPoint, EHPad);
LLVM_DEBUG(dbgs() << "...also reload at EHPad "
<< printMBBReference(*EHPad) << "\n");
}
}
}
MachineInstr *rewriteStatepoint() {
MachineInstr *NewMI =
MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
MachineInstrBuilder MIB(MF, NewMI);
unsigned NumOps = MI.getNumOperands();
SmallVector<unsigned, 8> NewIndices;
unsigned NumDefs = MI.getNumDefs();
for (unsigned I = 0; I < NumDefs; ++I) {
MachineOperand &DefMO = MI.getOperand(I);
assert(DefMO.isReg() && DefMO.isDef() && "Expected Reg Def operand");
Register Reg = DefMO.getReg();
assert(DefMO.isTied() && "Def is expected to be tied");
if (MI.getOperand(MI.findTiedOperandIdx(I)).isUndef()) {
if (AllowGCPtrInCSR) {
NewIndices.push_back(NewMI->getNumOperands());
MIB.addReg(Reg, RegState::Define);
}
continue;
}
if (!AllowGCPtrInCSR) {
assert(is_contained(RegsToSpill, Reg));
RegsToReload.push_back(Reg);
} else {
if (isCalleeSaved(Reg)) {
NewIndices.push_back(NewMI->getNumOperands());
MIB.addReg(Reg, RegState::Define);
} else {
NewIndices.push_back(NumOps);
RegsToReload.push_back(Reg);
}
}
}
OpsToSpill.push_back(MI.getNumOperands());
unsigned CurOpIdx = 0;
for (unsigned I = NumDefs; I < MI.getNumOperands(); ++I) {
MachineOperand &MO = MI.getOperand(I);
if (I == OpsToSpill[CurOpIdx]) {
int FI = RegToSlotIdx[MO.getReg()];
MIB.addImm(StackMaps::IndirectMemRefOp);
MIB.addImm(getRegisterSize(TRI, MO.getReg()));
assert(MO.isReg() && "Should be register");
assert(MO.getReg().isPhysical() && "Should be physical register");
MIB.addFrameIndex(FI);
MIB.addImm(0);
++CurOpIdx;
} else {
MIB.add(MO);
unsigned OldDef;
if (AllowGCPtrInCSR && MI.isRegTiedToDefOperand(I, &OldDef)) {
assert(OldDef < NumDefs);
assert(NewIndices[OldDef] < NumOps);
MIB->tieOperands(NewIndices[OldDef], MIB->getNumOperands() - 1);
}
}
}
assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
NewMI->setMemRefs(MF, MI.memoperands());
for (auto It : RegToSlotIdx) {
Register R = It.first;
int FrameIndex = It.second;
auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
MachineMemOperand::Flags Flags = MachineMemOperand::MOLoad;
if (is_contained(RegsToReload, R))
Flags |= MachineMemOperand::MOStore;
auto *MMO =
MF.getMachineMemOperand(PtrInfo, Flags, getRegisterSize(TRI, R),
MFI.getObjectAlign(FrameIndex));
NewMI->addMemOperand(MF, MMO);
}
MI.getParent()->insert(MI, NewMI);
LLVM_DEBUG(dbgs() << "rewritten statepoint to : " << *NewMI << "\n");
MI.eraseFromParent();
return NewMI;
}
};
class StatepointProcessor {
private:
MachineFunction &MF;
const TargetRegisterInfo &TRI;
FrameIndexesCache CacheFI;
RegReloadCache ReloadCache;
public:
StatepointProcessor(MachineFunction &MF)
: MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
CacheFI(MF.getFrameInfo(), TRI) {}
bool process(MachineInstr &MI, bool AllowGCPtrInCSR) {
StatepointOpers SO(&MI);
uint64_t Flags = SO.getFlags();
if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
return false;
LLVM_DEBUG(dbgs() << "\nMBB " << MI.getParent()->getNumber() << " "
<< MI.getParent()->getName() << " : process statepoint "
<< MI);
CallingConv::ID CC = SO.getCallingConv();
const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
StatepointState SS(MI, Mask, CacheFI, AllowGCPtrInCSR);
CacheFI.reset(SS.getEHPad());
if (!SS.findRegistersToSpill())
return false;
SS.spillRegisters();
auto *NewStatepoint = SS.rewriteStatepoint();
SS.insertReloads(NewStatepoint, ReloadCache);
return true;
}
};
}
/// Processes every STATEPOINT instruction of \p MF.
/// \returns true if at least one statepoint was rewritten.
bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
  const Function &F = MF.getFunction();
  // The skip check comes first; functions without a GC are ignored as well.
  if (skipFunction(F) || !F.hasGC())
    return false;

  // Collect the statepoints up front: processing replaces the instructions,
  // so the blocks are not walked while they are being rewritten.
  SmallVector<MachineInstr *, 16> Statepoints;
  for (MachineBasicBlock &BB : MF) {
    for (MachineInstr &I : BB) {
      if (I.getOpcode() != TargetOpcode::STATEPOINT)
        continue;
      Statepoints.push_back(&I);
    }
  }

  if (Statepoints.empty())
    return false;

  StatepointProcessor SPP(MF);
  // The statepoint limit applies only when the option was given explicitly.
  const bool HasLimit = MaxStatepointsWithRegs.getNumOccurrences() != 0;
  bool AllowGCPtrInCSR = PassGCPtrInCSR;
  bool Changed = false;
  unsigned NumStatepoints = 0;
  for (MachineInstr *Statepoint : Statepoints) {
    ++NumStatepoints;
    // Once the running count reaches the limit, GC pointers are no longer
    // left in callee saved registers for this and all later statepoints.
    if (HasLimit && NumStatepoints >= MaxStatepointsWithRegs)
      AllowGCPtrInCSR = false;
    if (SPP.process(*Statepoint, AllowGCPtrInCSR))
      Changed = true;
  }
  return Changed;
}