#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
using namespace llvm;
#define DEBUG_TYPE "fastpretileconfig"
STATISTIC(NumStores, "Number of stores added");
STATISTIC(NumLoads, "Number of loads added");
namespace {
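/// Preconfigure the shapes of tile registers before register allocation.
/// The pass inserts a PLDTILECFGV (ldtilecfg) ahead of each group of tile
/// register definitions, spills and reloads tile values that live across a
/// config point or across basic blocks, and lowers tile PHIs into PHIs of
/// their row, column and spill-slot address followed by a tile load.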
class X86FastPreTileConfig : public MachineFunctionPass {
MachineFunction *MF = nullptr;
const X86Subtarget *ST = nullptr;
const TargetInstrInfo *TII = nullptr;
MachineRegisterInfo *MRI = nullptr;
X86MachineFunctionInfo *X86FI = nullptr;
MachineFrameInfo *MFI = nullptr;
const TargetRegisterInfo *TRI = nullptr;
MachineBasicBlock *MBB = nullptr;
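  /// Frame index of the stack slot that holds the tile configuration memory.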
int CfgSS = -1;
struct PHIInfo {
Register Row;
Register Col;
Register StackAddr;
};
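  /// Row, column and spill-slot address registers created for PHIs that have
  /// already been converted; used to resolve cyclic PHI references.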
DenseMap<MachineInstr *, struct PHIInfo> VisitedPHIs;
IndexedMap<int, VirtReg2IndexFunctor> StackSlotForVirtReg;
BitVector MayLiveAcrossBlocks;
int getStackSpaceFor(Register VirtReg);
void InitializeTileConfigStackSpace();
bool mayLiveOut(Register VirtReg, MachineInstr *CfgMI);
void spill(MachineBasicBlock::iterator Before, Register VirtReg, bool Kill);
  void reload(MachineBasicBlock::iterator UseMI, Register OrigReg,
              MachineOperand *RowMO, MachineOperand *ColMO);
void canonicalizePHIs(MachineBasicBlock &MBB);
void convertPHI(MachineBasicBlock *MBB, MachineInstr &PHI);
void convertPHIs(MachineBasicBlock &MBB);
bool configBasicBlock(MachineBasicBlock &MBB);
public:
X86FastPreTileConfig() : MachineFunctionPass(ID), StackSlotForVirtReg(-1) {}
StringRef getPassName() const override {
return "Fast Tile Register Preconfigure";
}
bool runOnMachineFunction(MachineFunction &MFunc) override;
static char ID;
};
} // end anonymous namespace
char X86FastPreTileConfig::ID = 0;
INITIALIZE_PASS_BEGIN(X86FastPreTileConfig, DEBUG_TYPE,
"Fast Tile Register Preconfigure", false, false)
INITIALIZE_PASS_END(X86FastPreTileConfig, DEBUG_TYPE,
"Fast Tile Register Preconfigure", false, false)
static bool dominates(MachineBasicBlock &MBB,
MachineBasicBlock::const_iterator A,
MachineBasicBlock::const_iterator B) {
auto MBBEnd = MBB.end();
if (B == MBBEnd)
return true;
MachineBasicBlock::const_iterator I = MBB.begin();
for (; &*I != A && &*I != B; ++I)
;
return &*I == A;
}
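/// Return the stack slot assigned to \p VirtReg, creating a spill slot with
/// the register class's spill size and alignment on first use.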
int X86FastPreTileConfig::getStackSpaceFor(Register VirtReg) {
int SS = StackSlotForVirtReg[VirtReg];
if (SS != -1)
return SS;
const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
unsigned Size = TRI->getSpillSize(RC);
Align Alignment = TRI->getSpillAlign(RC);
int FrameIdx = MFI->CreateSpillStackObject(Size, Alignment);
StackSlotForVirtReg[VirtReg] = FrameIdx;
return FrameIdx;
}
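/// Return true if \p VirtReg may live out of the current tile configuration:
/// it is used in another basic block, or it is used after the config
/// instruction \p CfgMI within the current block. The result is cached in
/// MayLiveAcrossBlocks.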
bool X86FastPreTileConfig::mayLiveOut(Register VirtReg, MachineInstr *CfgMI) {
if (MayLiveAcrossBlocks.test(Register::virtReg2Index(VirtReg)))
return true;
for (const MachineInstr &UseInst : MRI->use_nodbg_instructions(VirtReg)) {
if (UseInst.getParent() != MBB) {
MayLiveAcrossBlocks.set(Register::virtReg2Index(VirtReg));
return true;
}
if (CfgMI) {
if (dominates(*MBB, *CfgMI, UseInst)) {
MayLiveAcrossBlocks.set(Register::virtReg2Index(VirtReg));
return true;
}
}
}
return false;
}
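/// Zero the tile configuration memory in the entry block with the widest
/// available vector stores, then set the palette byte (offset 0) to 1.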
void X86FastPreTileConfig::InitializeTileConfigStackSpace() {
MachineBasicBlock &MBB = MF->front();
MachineInstr *MI = &*MBB.getFirstNonPHI();
DebugLoc DL;
if (ST->hasAVX512()) {
Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), CfgSS)
.addReg(Zmm);
} else if (ST->hasAVX2()) {
Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS)
.addReg(Ymm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), CfgSS,
32)
.addReg(Ymm);
} else {
assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
unsigned StoreOpc = ST->hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS)
.addReg(Xmm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 16)
.addReg(Xmm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 32)
.addReg(Xmm);
addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), CfgSS, 48)
.addReg(Xmm);
}
addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), CfgSS)
.addImm(1);
}
void X86FastPreTileConfig::spill(MachineBasicBlock::iterator Before,
Register VirtReg, bool Kill) {
LLVM_DEBUG(dbgs() << "Spilling " << printReg(VirtReg, TRI) << " \n");
int FI = getStackSpaceFor(VirtReg);
LLVM_DEBUG(dbgs() << " to stack slot #" << FI << '\n');
const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
TII->storeRegToStackSlot(*MBB, Before, VirtReg, Kill, FI, &RC, TRI);
++NumStores;
}
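/// Reload the tile register \p OrigReg from its stack slot before \p UseMI
/// using PTILELOADDV with the shape (\p RowMO, \p ColMO) and a 64-byte
/// stride. If the user is a copy, load directly into the copy's destination
/// and erase the copy; otherwise rewrite the user's operands to the freshly
/// loaded register.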
void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,
Register OrigReg, MachineOperand *RowMO,
MachineOperand *ColMO) {
int FI = getStackSpaceFor(OrigReg);
const TargetRegisterClass &RC = *MRI->getRegClass(OrigReg);
Register TileReg;
if (UseMI->isCopy())
TileReg = UseMI->getOperand(0).getReg();
else
TileReg = MRI->createVirtualRegister(&RC);
unsigned Opc = X86::PTILELOADDV;
Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
MachineInstr *NewMI = BuildMI(*UseMI->getParent(), UseMI, DebugLoc(),
TII->get(X86::MOV64ri), StrideReg)
.addImm(64);
NewMI = addFrameReference(
BuildMI(*UseMI->getParent(), UseMI, DebugLoc(), TII->get(Opc), TileReg)
.addReg(RowMO->getReg())
.addReg(ColMO->getReg()),
FI);
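  // Operand 5 is the index register of the frame reference; it carries the
  // 64-byte stride for the tile load.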
MachineOperand &MO = NewMI->getOperand(5);
MO.setReg(StrideReg);
MO.setIsKill(true);
RowMO->setIsKill(false);
ColMO->setIsKill(false);
if (UseMI->isCopy()) {
UseMI->eraseFromParent();
} else {
for (auto &MO : UseMI->operands()) {
if (MO.isReg() && MO.getReg() == OrigReg)
MO.setReg(TileReg);
}
}
++NumLoads;
LLVM_DEBUG(dbgs() << "Reloading " << printReg(OrigReg, TRI) << " into "
<< printReg(TileReg, TRI) << '\n');
}
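// Return true if MI is a pseudo instruction that defines a tile register,
// either a virtual register of the TILE class or a physical TMM0-TMM7. The
// instruction must have at least three operands: the tile def plus the row
// and column shape operands.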
static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
if (MI.isDebugInstr() || MI.getNumOperands() < 3 || !MI.isPseudo())
return false;
MachineOperand &MO = MI.getOperand(0);
if (MO.isReg()) {
Register Reg = MO.getReg();
if (Reg.isVirtual() &&
MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
return true;
if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
return true;
}
return false;
}
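// Return the shape (row and column operands) of TileReg by walking its def
// chain through copies to the defining tile pseudo.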
static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) {
MachineInstr *MI = MRI->getVRegDef(TileReg);
if (isTileDef(MRI, *MI)) {
MachineOperand *RowMO = &MI->getOperand(1);
MachineOperand *ColMO = &MI->getOperand(2);
return ShapeT(RowMO, ColMO, MRI);
} else if (MI->isCopy()) {
TileReg = MI->getOperand(1).getReg();
return getShape(MRI, TileReg);
}
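  // We should not reach a PHI here: tile PHIs are converted to PTILELOADDV
  // before shapes are queried, since blocks are visited in reverse post-order.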
assert(MI->isPHI() && "Unexpected PHI when get shape.");
llvm_unreachable("Unexpected MI when get shape.");
}
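/// Convert a tile PHI into PHIs of its row, column and spill-slot address,
/// followed by a tile load at the start of the block:
///   %tile = phi [%t0, %bb0], [%t1, %bb1]
/// becomes
///   %row  = phi [%r0, %bb0], [%r1, %bb1]
///   %col  = phi [%c0, %bb0], [%c1, %bb1]
///   %addr = phi [%a0, %bb0], [%a1, %bb1]
///   %tile = PTILELOADDV %row, %col, %addr
/// Each incoming tile is marked as live across blocks so it gets spilled, and
/// the address of its spill slot is materialized with LEA64r in the block
/// that defines it.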
void X86FastPreTileConfig::convertPHI(MachineBasicBlock *MBB,
MachineInstr &PHI) {
Register StackAddrReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
MachineInstrBuilder AddrPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
TII->get(X86::PHI), StackAddrReg);
Register RowReg = MRI->createVirtualRegister(&X86::GR16RegClass);
MachineInstrBuilder RowPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
TII->get(X86::PHI), RowReg);
Register ColReg = MRI->createVirtualRegister(&X86::GR16RegClass);
MachineInstrBuilder ColPHI = BuildMI(*MBB, ++PHI.getIterator(), DebugLoc(),
TII->get(X86::PHI), ColReg);
VisitedPHIs[&PHI] = {RowReg, ColReg, StackAddrReg};
for (unsigned I = 1, E = PHI.getNumOperands(); I != E; I += 2) {
Register InTileReg = PHI.getOperand(I).getReg();
MayLiveAcrossBlocks.set(Register::virtReg2Index(InTileReg));
MachineBasicBlock *InMBB = PHI.getOperand(I + 1).getMBB();
MachineInstr *TileDefMI = MRI->getVRegDef(InTileReg);
MachineBasicBlock::iterator InsertPos;
if (TileDefMI->isPHI()) {
InsertPos = TileDefMI->getParent()->getFirstNonPHI();
      if (VisitedPHIs.count(TileDefMI)) {
        Register InRowReg = VisitedPHIs[TileDefMI].Row;
        Register InColReg = VisitedPHIs[TileDefMI].Col;
        Register InStackAddrReg = VisitedPHIs[TileDefMI].StackAddr;
RowPHI.addReg(InRowReg).addMBB(InMBB);
ColPHI.addReg(InColReg).addMBB(InMBB);
AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
continue;
} else {
convertPHI(TileDefMI->getParent(), *TileDefMI);
MachineInstr *TileLoad = MRI->getVRegDef(InTileReg);
assert(TileLoad && TileLoad->getOpcode() == X86::PTILELOADDV);
Register InRowReg = TileLoad->getOperand(1).getReg();
Register InColReg = TileLoad->getOperand(2).getReg();
Register InStackAddrReg = TileLoad->getOperand(3).getReg();
RowPHI.addReg(InRowReg).addMBB(InMBB);
ColPHI.addReg(InColReg).addMBB(InMBB);
AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
}
} else {
InsertPos = TileDefMI->getIterator();
ShapeT Shape = getShape(MRI, InTileReg);
Shape.getRow()->setIsKill(false);
Shape.getCol()->setIsKill(false);
RowPHI.addReg(Shape.getRow()->getReg()).addMBB(InMBB);
ColPHI.addReg(Shape.getCol()->getReg()).addMBB(InMBB);
int FI = getStackSpaceFor(InTileReg);
Register InStackAddrReg =
MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
addOffset(BuildMI(*TileDefMI->getParent(), InsertPos, DebugLoc(),
TII->get(X86::LEA64r), InStackAddrReg)
.addFrameIndex(FI),
0);
AddrPHI.addReg(InStackAddrReg).addMBB(InMBB);
}
}
MachineBasicBlock::iterator InsertPos = MBB->getFirstNonPHI();
Register StrideReg = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::MOV64ri), StrideReg)
.addImm(64);
Register TileReg = PHI.getOperand(0).getReg();
MachineInstr *NewMI = addDirectMem(
BuildMI(*MBB, InsertPos, DebugLoc(), TII->get(X86::PTILELOADDV), TileReg)
.addReg(RowReg)
.addReg(ColReg),
StackAddrReg);
MachineOperand &MO = NewMI->getOperand(5);
MO.setReg(StrideReg);
MO.setIsKill(true);
PHI.eraseFromParent();
VisitedPHIs.erase(&PHI);
}
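// Return true if MI defines a virtual register of the TILE register class.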
static bool isTileRegDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
MachineOperand &MO = MI.getOperand(0);
if (MO.isReg() && MO.getReg().isVirtual() &&
MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
return true;
return false;
}
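// Canonicalize tile PHIs before conversion: if a PHI's incoming tile value
// from this block is itself defined by another PHI, rewrite that operand to
// refer to the other PHI's incoming tile register, so that convertPHI does
// not have to chase a chain of PHIs within the same block.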
void X86FastPreTileConfig::canonicalizePHIs(MachineBasicBlock &MBB) {
SmallVector<MachineInstr *, 8> PHIs;
for (MachineInstr &MI : MBB) {
if (!MI.isPHI())
break;
if (!isTileRegDef(MRI, MI))
continue;
PHIs.push_back(&MI);
}
while (!PHIs.empty()) {
MachineInstr *PHI = PHIs.pop_back_val();
MachineOperand *InMO = nullptr;
MachineInstr *DefMI = nullptr;
for (unsigned I = 1, E = PHI->getNumOperands(); I != E; I += 2) {
Register InTileReg = PHI->getOperand(I).getReg();
MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
DefMI = MRI->getVRegDef(InTileReg);
if (InMBB != &MBB || !DefMI->isPHI())
continue;
InMO = &PHI->getOperand(I);
break;
}
if (!InMO)
continue;
Register DefTileReg;
for (unsigned I = 1, E = DefMI->getNumOperands(); I != E; I += 2) {
MachineBasicBlock *InMBB = PHI->getOperand(I + 1).getMBB();
if (InMBB != &MBB)
continue;
DefTileReg = DefMI->getOperand(I).getReg();
InMO->setReg(DefTileReg);
break;
}
}
}
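// Collect the tile PHIs of MBB and convert each of them. The visited-PHI map
// is reset before every root conversion.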
void X86FastPreTileConfig::convertPHIs(MachineBasicBlock &MBB) {
SmallVector<MachineInstr *, 8> PHIs;
for (MachineInstr &MI : MBB) {
if (!MI.isPHI())
break;
if (!isTileRegDef(MRI, MI))
continue;
PHIs.push_back(&MI);
}
while (!PHIs.empty()) {
MachineInstr *MI = PHIs.pop_back_val();
VisitedPHIs.clear();
convertPHI(&MBB, *MI);
}
}
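// Insert tile configuration for one basic block. The block is walked in
// reverse: once an unconfigured tile use is seen, a PLDTILECFGV is inserted
// after the last shape definition (or after a call), and tile values that may
// live across a config point or across blocks are spilled after their def and
// reloaded before their users.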
bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
this->MBB = &MBB;
bool Change = false;
MachineInstr *LastShapeMI = nullptr;
MachineInstr *LastTileCfg = nullptr;
bool HasUnconfigTile = false;
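  // Insert a PLDTILECFGV that loads the tile configuration from the config
  // stack slot right before the given instruction, creating the slot on first
  // use.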
auto Config = [&](MachineInstr &Before) {
if (CfgSS == -1)
CfgSS = MFI->CreateStackObject(ST->getTileConfigSize(),
ST->getTileConfigAlignment(), false);
LastTileCfg = addFrameReference(
BuildMI(MBB, Before, DebugLoc(), TII->get(X86::PLDTILECFGV)), CfgSS);
LastShapeMI = nullptr;
Change = true;
};
auto HasTileOperand = [](MachineRegisterInfo *MRI, MachineInstr &MI) {
for (const MachineOperand &MO : MI.operands()) {
if (!MO.isReg())
continue;
Register Reg = MO.getReg();
if (Reg.isVirtual() &&
MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
return true;
}
return false;
};
for (MachineInstr &MI : reverse(MBB)) {
if (MI.isPHI())
break;
if (HasTileOperand(MRI, MI))
HasUnconfigTile = true;
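    // Tile instructions below a call must be configured after it, since the
    // call may clobber the tile configuration.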
if (MI.isCall() && HasUnconfigTile) {
MachineBasicBlock::iterator I;
if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
I = ++LastShapeMI->getIterator();
else
I = ++MI.getIterator();
Config(*I);
HasUnconfigTile = false;
continue;
}
if (!isTileDef(MRI, MI))
continue;
if (LastShapeMI && dominates(MBB, MI, LastShapeMI))
Config(*(++LastShapeMI->getIterator()));
MachineOperand *RowMO = &MI.getOperand(1);
MachineOperand *ColMO = &MI.getOperand(2);
MachineInstr *RowMI = MRI->getVRegDef(RowMO->getReg());
MachineInstr *ColMI = MRI->getVRegDef(ColMO->getReg());
if (RowMI->getParent() == &MBB) {
if (!LastShapeMI)
LastShapeMI = RowMI;
else if (dominates(MBB, LastShapeMI, RowMI))
LastShapeMI = RowMI;
}
if (ColMI->getParent() == &MBB) {
if (!LastShapeMI)
LastShapeMI = ColMI;
else if (dominates(MBB, LastShapeMI, ColMI))
LastShapeMI = ColMI;
}
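    // If the defined tile may live past a config point or out of the block,
    // spill it right after its def and reload it before each use that follows
    // a config or resides in another block (PHIs are handled separately).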
Register TileReg = MI.getOperand(0).getReg();
if (mayLiveOut(TileReg, LastTileCfg))
spill(++MI.getIterator(), TileReg, false);
for (MachineInstr &UseMI : MRI->use_instructions(TileReg)) {
if (UseMI.getParent() == &MBB) {
if (!LastTileCfg || !dominates(MBB, LastTileCfg, UseMI))
continue;
reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
} else {
if (!UseMI.isPHI())
reload(UseMI.getIterator(), TileReg, RowMO, ColMO);
}
}
}
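  // If tile defs near the top of the block are still unconfigured, place a
  // config at the first non-PHI position or right after the last shape
  // definition.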
if (HasUnconfigTile) {
MachineInstr *Before;
if (LastShapeMI == nullptr || LastShapeMI->isPHI())
Before = &*MBB.getFirstNonPHI();
else
Before = &*(++LastShapeMI->getIterator());
Config(*Before);
}
return Change;
}
bool X86FastPreTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
MF = &MFunc;
MRI = &MFunc.getRegInfo();
ST = &MFunc.getSubtarget<X86Subtarget>();
TII = ST->getInstrInfo();
X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
MFI = &MFunc.getFrameInfo();
TRI = ST->getRegisterInfo();
CfgSS = -1;
unsigned NumVirtRegs = MRI->getNumVirtRegs();
bool HasVirtTileReg = false;
for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
Register VirtReg = Register::index2VirtReg(I);
if (MRI->getRegClass(VirtReg)->getID() == X86::TILERegClassID) {
HasVirtTileReg = true;
break;
}
}
if (!HasVirtTileReg)
return false;
StackSlotForVirtReg.resize(NumVirtRegs);
MayLiveAcrossBlocks.clear();
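  // Registers are also created while configuring, so over-allocate the bit
  // vector to three times the current number of virtual registers.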
MayLiveAcrossBlocks.resize(NumVirtRegs * 3);
bool Change = false;
assert(MRI->isSSA());
for (MachineBasicBlock &MBB : MFunc)
canonicalizePHIs(MBB);
ReversePostOrderTraversal<MachineFunction *> RPOT(MF);
for (MachineBasicBlock *MBB : RPOT) {
convertPHIs(*MBB);
Change |= configBasicBlock(*MBB);
}
if (Change)
InitializeTileConfigStackSpace();
StackSlotForVirtReg.clear();
return Change;
}
FunctionPass *llvm::createX86FastPreTileConfigPass() {
return new X86FastPreTileConfig();
}