#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
using namespace llvm;
#define DEBUG_TYPE "tileconfig"
namespace {

/// Machine function pass that writes tile shape information into the stack
/// slot referenced by the function's PLDTILECFGV instruction.
///
/// The pass needs the virtual-to-physical register mapping (VirtRegMap) and
/// live interval information (LiveIntervals), and it only runs on functions
/// that no longer contain PHI instructions.
struct X86TileConfig : public MachineFunctionPass {
  /// Pass identification; its address is what identifies the pass.
  static char ID;

  X86TileConfig() : MachineFunctionPass(ID) {}

  /// Human readable name of the pass.
  StringRef getPassName() const override { return "Tile Register Configure"; }

  /// The pass must run after PHI elimination.
  MachineFunctionProperties getRequiredProperties() const override {
    MachineFunctionProperties Required;
    Required.set(MachineFunctionProperties::Property::NoPHIs);
    return Required;
  }

  /// Declare the analyses consumed by the pass; every analysis is reported
  /// as preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<VirtRegMap>();
    AU.addRequired<LiveIntervals>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Emit the shape stores for the given machine function.
  bool runOnMachineFunction(MachineFunction &mf) override;
};

} // end anonymous namespace
char X86TileConfig::ID = 0;

// Register the pass together with every analysis it requests in
// getAnalysisUsage(). Both VirtRegMap and LiveIntervals are required there, so
// both are listed as initialization dependencies to guarantee they are
// registered before this pass is.
INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
INITIALIZE_PASS_DEPENDENCY(LiveIntervals)
INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
                    false)
/// Store the row/column shape of every allocated tile register into the stack
/// slot used by the function's PLDTILECFGV instruction.
///
/// For each physical tile register that has a virtual register mapped onto it,
/// the row is written as an 8-bit store at offset 48 + I and the column as a
/// 16-bit store at offset 16 + 2 * I of the slot, where I is the register's
/// index relative to TMM0. (NOTE(review): these offsets appear to follow the
/// LDTILECFG memory layout; confirm against the Intel SDM.)
///
/// \returns false when the function has no tile shapes or no PLDTILECFGV
/// instruction, true once the shape stores have been emitted.
bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  LiveIntervals &LIS = getAnalysis<LiveIntervals>();
  VirtRegMap &VRM = getAnalysis<VirtRegMap>();

  // No virtual register carries a shape, so there is nothing to configure.
  if (VRM.isShapeMapEmpty())
    return false;

  // Find the frame index addressed by the first PLDTILECFGV in the function.
  // INT_MAX serves as the "not found" sentinel.
  int SS = INT_MAX;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (MI.getOpcode() == X86::PLDTILECFGV) {
        SS = MI.getOperand(0).getIndex();
        break;
      }
    }
    if (SS != INT_MAX)
      break;
  }
  if (SS == INT_MAX)
    return false;

  // Look in the entry block for an existing MOV8mi into that slot; stores of
  // constant shapes are inserted right after it. ConstPos counts the
  // instructions that precede it.
  unsigned ConstPos = 0;
  MachineInstr *ConstMI = nullptr;
  for (MachineInstr &MI : MF.front()) {
    if (MI.getOpcode() == X86::MOV8mi) {
      // An unrelated MOV8mi may be addressed through a register base, in
      // which case operand 0 is not a frame index and getIndex() must not be
      // called on it.
      const MachineOperand &Base = MI.getOperand(0);
      if (Base.isFI() && SS == Base.getIndex()) {
        ConstMI = &MI;
        break;
      }
    }
    ++ConstPos;
  }
  assert(ConstMI && "Cannot find an insertion point");

  // Pick, for every physical tile register, the first virtual register that
  // was assigned to it; its shape is the one written to the slot.
  unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
  SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    Register VirtReg = Register::index2VirtReg(I);
    if (MRI.reg_nodbg_empty(VirtReg))
      continue;
    if (MRI.getRegClass(VirtReg)->getID() != X86::TILERegClassID)
      continue;
    if (VRM.getPhys(VirtReg) == VirtRegMap::NO_PHYS_REG)
      continue;
    unsigned Index = VRM.getPhys(VirtReg) - X86::TMM0;
    if (!Phys2Virt[Index])
      Phys2Virt[Index] = VirtReg;
  }

  // Emit the row and column stores for every tile register in use.
  for (unsigned I = 0; I < AMXRegNum; ++I) {
    if (!Phys2Virt[I])
      continue;
    DebugLoc DL;
    bool IsRow = true;
    MachineInstr *NewMI = nullptr;
    ShapeT Shape = VRM.getShape(Phys2Virt[I]);
    // First iteration handles the row register, second the column register.
    for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
      // INT64_MAX marks "no immediate definition seen yet".
      int64_t Imm = INT64_MAX;
      int Offset = IsRow ? 48 + I : 16 + I * 2;
      for (auto &DefMI : MRI.def_instructions(R)) {
        MachineBasicBlock &MBB = *DefMI.getParent();
        if (DefMI.isMoveImmediate()) {
          // A constant shape needs a single store; further immediate
          // definitions only have to agree with the first one.
          if (Imm != INT64_MAX) {
            assert(Imm == DefMI.getOperand(1).getImm() &&
                   "Cannot initialize with different shapes");
            continue;
          }
          Imm = DefMI.getOperand(1).getImm();
          NewMI = addFrameReference(
                      BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
                              TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
                      SS, Offset)
                      .addImm(Imm);
          // Chain the next constant store after this one.
          ConstMI = NewMI;
          LIS.InsertMachineInstrInMaps(*NewMI);
        } else {
          // Store the register itself when its class already has the store
          // width, otherwise store the matching sub-register.
          unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
          unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
          if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
            SubIdx = 0;
          // The store normally goes right after the definition. A definition
          // in the entry block that precedes the constant insertion point is
          // stored after that point instead (presumably so it is not emitted
          // ahead of the slot's existing initialization -- TODO confirm).
          auto Iter = DefMI.getIterator();
          if (&MBB == &MF.front() &&
              (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
            Iter = ConstMI->getIterator();
          NewMI = addFrameReference(
                      BuildMI(MBB, ++Iter, DL,
                              TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
                      SS, Offset)
                      .addReg(R, 0, SubIdx);
          // Register the new instruction and extend R's live interval to
          // cover the new use.
          SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
          LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
        }
      }
      IsRow = false;
    }
  }
  return true;
}
/// Create a new instance of the tile register configuration pass. The caller
/// receives a heap-allocated pass object.
FunctionPass *llvm::createX86TileConfigPass() {
  FunctionPass *TileConfigPass = new X86TileConfig();
  return TileConfigPass;
}