#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Target/TargetIntrinsicInfo.h"
#define DEBUG_TYPE "spirv-prelegalizer"
using namespace llvm;
namespace {
/// Machine-function pass for the SPIR-V target whose work is driven entirely
/// from runOnMachineFunction.
class SPIRVPreLegalizer : public MachineFunctionPass {
public:
  /// Pass ID.
  static char ID;

  /// Registers the pass with the global pass registry when constructed.
  SPIRVPreLegalizer() : MachineFunctionPass(ID) {
    initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
};
} // end anonymous namespace
/// Replaces every `spv_track_constant` intrinsic with the register that holds
/// the tracked constant, registering constants and global values in GR so a
/// value already known for this machine function is reused.
///
/// Operands read from each intrinsic: operand 0 is its result register (all
/// uses are redirected), operand 2 is the register defining the value, and
/// operand 3 is metadata wrapping the tracked Constant.
///
/// When the constant was already registered, the existing register is used
/// instead of operand 2, and an `spv_const_composite` that defined operand 2
/// (if any) is erased. For a newly tracked ConstantDataVector, each element is
/// registered as well; an element that is already registered is reused by
/// rewriting the matching G_BUILD_VECTOR operand.
static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // track_constant instruction -> register of an equivalent value that GR
  // already knew about for this function.
  DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
  SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
        continue;
      ToErase.push_back(&MI);
      auto *Const =
          cast<Constant>(cast<ConstantAsMetadata>(
                             MI.getOperand(3).getMetadata()->getOperand(0))
                             ->getValue());
      if (auto *GV = dyn_cast<GlobalValue>(Const)) {
        Register Reg = GR->find(GV, &MF);
        if (!Reg.isValid())
          GR->add(GV, &MF, MI.getOperand(2).getReg());
        else
          RegsAlreadyAddedToDT[&MI] = Reg;
      } else {
        Register Reg = GR->find(Const, &MF);
        if (!Reg.isValid()) {
          if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
            auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
            assert(BuildVec &&
                   BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
              // Reuse an element constant that is already registered for this
              // function instead of keeping a second definition of it.
              Constant *ElemConst = ConstVec->getElementAsConstant(i);
              Register ElemReg = GR->find(ElemConst, &MF);
              if (!ElemReg.isValid())
                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
              else
                BuildVec->getOperand(1 + i).setReg(ElemReg);
            }
          }
          GR->add(Const, &MF, MI.getOperand(2).getReg());
        } else {
          RegsAlreadyAddedToDT[&MI] = Reg;
          assert(MI.getOperand(2).isReg() && "Reg operand is expected");
          // The duplicate's composite definition becomes redundant.
          MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
          if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
            ToEraseComposites.push_back(SrcMI);
        }
      }
    }
  }
  // Erasure is deferred so the walk above is not invalidated.
  for (MachineInstr *MI : ToErase) {
    Register Reg = MI->getOperand(2).getReg();
    auto It = RegsAlreadyAddedToDT.find(MI);
    if (It != RegsAlreadyAddedToDT.end())
      Reg = It->second;
    MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
    MI->eraseFromParent();
  }
  for (MachineInstr *MI : ToEraseComposites)
    MI->eraseFromParent();
}
/// Rewrites the name operands of every `spv_assign_name` intrinsic from
/// registers defined by G_CONSTANT into immediates holding the constant's
/// zero-extended value, then erases constant definitions left without uses.
static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
  SmallVector<MachineInstr *, 10> ToErase;
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Number of operands, after the explicit defs, that precede the name
  // operands (presumably the intrinsic ID and the named value — confirm
  // against the intrinsic's definition).
  const unsigned AssignNameOperandShift = 2;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
        continue;
      unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
      // Each step removes the register operand at NumOp and appends an
      // immediate, so NumOp always designates the next unprocessed operand.
      // The loop stops at the first non-register operand, at the latest the
      // first appended immediate. The bound check protects intrinsics that
      // have no operand at NumOp at all.
      while (NumOp < MI.getNumOperands() && MI.getOperand(NumOp).isReg()) {
        MachineOperand &MOp = MI.getOperand(NumOp);
        MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
        assert(ConstMI && ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
        MI.removeOperand(NumOp);
        MI.addOperand(MachineOperand::CreateImm(
            ConstMI->getOperand(1).getCImm()->getZExtValue()));
        if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
          ToErase.push_back(ConstMI);
      }
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}
/// Replaces each `spv_bitcast` intrinsic with a generic G_BITCAST, placed
/// immediately before it, that defines the intrinsic's result (operand 0)
/// from its source register (operand 2). GR is not used.
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                           MachineIRBuilder MIB) {
  // Gather first so that building and erasing never disturbs the walk.
  SmallVector<MachineInstr *, 10> Bitcasts;
  for (MachineBasicBlock &MBB : MF)
    for (MachineInstr &MI : MBB)
      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
        Bitcasts.push_back(&MI);

  for (MachineInstr *Intr : Bitcasts) {
    const MachineOperand &Src = Intr->getOperand(2);
    assert(Src.isReg());
    MIB.setInsertPt(*Intr->getParent(), *Intr);
    MIB.buildBitcast(Intr->getOperand(0).getReg(), Src.getReg());
    Intr->eraseFromParent();
  }
}
/// Returns the SPIR-V type of the register defined by MI (operand 0).
///
/// If GR does not know a type yet, one is derived:
///  - G_CONSTANT / G_GLOBAL_VALUE: from the LLVM type of operand 1;
///  - G_TRUNC / G_ADDRSPACE_CAST / G_PTR_ADD / COPY: recursively from the
///    instruction defining operand 1, when that operand is a register with a
///    definition.
/// A derived type is recorded in GR. A register that had no type also gets
/// the ID register class if it has no class yet.
///
/// \returns nullptr when operand 0 is not a register or no type was derived.
static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                                     MachineRegisterInfo &MRI,
                                     MachineIRBuilder &MIB) {
  assert(MI && "Machine instr is expected");
  if (!MI->getOperand(0).isReg())
    return nullptr;
  const Register DefReg = MI->getOperand(0).getReg();
  if (SPIRVType *Known = GR->getSPIRVTypeForVReg(DefReg))
    return Known;

  SPIRVType *Derived = nullptr;
  const unsigned Opc = MI->getOpcode();
  if (Opc == TargetOpcode::G_CONSTANT) {
    MIB.setInsertPt(*MI->getParent(), MI);
    Type *ConstTy = MI->getOperand(1).getCImm()->getType();
    Derived = GR->getOrCreateSPIRVType(ConstTy, MIB);
  } else if (Opc == TargetOpcode::G_GLOBAL_VALUE) {
    MIB.setInsertPt(*MI->getParent(), MI);
    Type *GlobalTy = MI->getOperand(1).getGlobal()->getType();
    Derived = GR->getOrCreateSPIRVType(GlobalTy, MIB);
  } else if (Opc == TargetOpcode::G_TRUNC ||
             Opc == TargetOpcode::G_ADDRSPACE_CAST ||
             Opc == TargetOpcode::G_PTR_ADD || Opc == TargetOpcode::COPY) {
    // Take the type from whatever defines the source operand.
    MachineOperand &SrcOp = MI->getOperand(1);
    if (SrcOp.isReg()) {
      if (MachineInstr *SrcDef = MRI.getVRegDef(SrcOp.getReg()))
        Derived = propagateSPIRVType(SrcDef, GR, MRI, MIB);
    }
  }

  if (Derived)
    GR->assignSPIRVTypeToVReg(Derived, DefReg, MIB.getMF());
  if (!MRI.getRegClassOrNull(DefReg))
    MRI.setRegClass(DefReg, &SPIRV::IDRegClass);
  return Derived;
}
/// Routes the definition of Reg through an ASSIGN_TYPE pseudo:
///
///   NewReg = <original def>               ; Def's operand 0 retargeted
///   Reg    = ASSIGN_TYPE NewReg, <type id> ; inserted right after Def
///
/// The type is SpirvTy when non-null, otherwise it is created from the LLVM
/// type Ty. Both Reg and NewReg are recorded in GR with that type. NewReg
/// inherits Reg's register class (if any), and Reg is moved to ANYIDRegClass.
///
/// \returns NewReg, the register now defined by the original instruction.
static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIB,
                                  MachineRegisterInfo &MRI) {
  MachineInstr *Def = MRI.getVRegDef(Reg);
  assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
  MachineBasicBlock &DefMBB = *Def->getParent();
  MachineInstr *AfterDef = Def->getNextNode();
  MIB.setInsertPt(DefMBB, AfterDef ? AfterDef->getIterator() : DefMBB.end());

  Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
  if (const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg))
    MRI.setRegClass(NewReg, RC);

  SPIRVType *ResolvedTy = SpirvTy;
  if (!ResolvedTy)
    ResolvedTy = GR->getOrCreateSPIRVType(Ty, MIB);

  MachineFunction &CurMF = MIB.getMF();
  GR->assignSPIRVTypeToVReg(ResolvedTy, Reg, CurMF);
  GR->assignSPIRVTypeToVReg(ResolvedTy, NewReg, CurMF);
  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
      .addDef(Reg)
      .addUse(NewReg)
      .addUse(GR->getSPIRVTypeID(ResolvedTy));
  Def->getOperand(0).setReg(NewReg);
  MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
  return NewReg;
}
/// Gives virtual registers SPIR-V types. Reachable blocks are visited in
/// post-order, and each block is walked from its last instruction to its
/// first:
///  - `spv_assign_type` is lowered to an ASSIGN_TYPE (via insertAssignInstr)
///    unless its operand is defined by G_GLOBAL_VALUE; the intrinsic itself
///    is erased once the walk is done.
///  - G_CONSTANT / G_FCONSTANT / G_BUILD_VECTOR results get an ASSIGN_TYPE
///    built from the immediate's (or vector element's) LLVM type. This is
///    skipped when their only use is an `spv_assign_type` or `spv_assign_name`
///    intrinsic.
///  - G_TRUNC / G_GLOBAL_VALUE / COPY / G_ADDRSPACE_CAST results are typed
///    through propagateSPIRVType.
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                 MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  SmallVector<MachineInstr *, 10> ToErase;
  for (MachineBasicBlock *MBB : post_order(&MF)) {
    if (MBB->empty())
      continue;
    // Manual backward walk; Begin is captured once when the block's walk
    // starts. The iterator is stepped at the bottom of the loop body, so no
    // branch below may `continue` — that would revisit the same instruction
    // forever.
    bool ReachedBegin = false;
    for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
         !ReachedBegin;) {
      MachineInstr &MI = *MII;
      if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
        Register Reg = MI.getOperand(1).getReg();
        Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
        MachineInstr *Def = MRI.getVRegDef(Reg);
        assert(Def && "Expecting an instruction that defines the register");
        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
        ToErase.push_back(&MI);
      } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
                 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
                 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
        Register Reg = MI.getOperand(0).getReg();
        // Leave constants alone when their single user is a type/name
        // intrinsic. Use a flag, not `continue`, so the iterator still
        // advances.
        bool SkipConstant = false;
        if (MRI.hasOneUse(Reg)) {
          MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
          SkipConstant = isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
                         isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name);
        }
        if (!SkipConstant) {
          Type *Ty = nullptr;
          if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
            Ty = MI.getOperand(1).getCImm()->getType();
          else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
            Ty = MI.getOperand(1).getFPImm()->getType();
          else {
            assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
            // The vector's element type is taken from its first element,
            // which must itself be a scalar constant.
            Type *ElemTy = nullptr;
            MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
            assert(ElemMI);
            if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
              ElemTy = ElemMI->getOperand(1).getCImm()->getType();
            else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
              ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
            else
              llvm_unreachable("Unexpected opcode");
            unsigned NumElts =
                MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
            Ty = VectorType::get(ElemTy, NumElts, false);
          }
          insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
        }
      } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
                 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
                 MI.getOpcode() == TargetOpcode::COPY ||
                 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
        propagateSPIRVType(&MI, GR, MRI, MIB);
      }
      if (MII == Begin)
        ReachedBegin = true;
      else
        --MII;
    }
  }
  for (MachineInstr *MI : ToErase)
    MI->eraseFromParent();
}
/// Creates a fresh virtual register to carry the ID of ValReg and picks the
/// GET_* pseudo opcode that produces it. The choice depends on:
///  - ValReg's LLT: pointer -> p0 (32-bit), pID class, GET_pID;
///    vector -> <2 x s32>, vID/vfID class, GET_vID/GET_vfID;
///    otherwise -> s32, ID/fID class, GET_ID/GET_fID;
///  - whether ValReg's SPIR-V type is a float or a vector of floats, which
///    selects the "f" variants (not used for pointers).
/// The Opcode parameter is not used.
///
/// \returns {new register, GET_* opcode}.
static std::pair<Register, unsigned>
createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
               const SPIRVGlobalRegistry &GR) {
  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
  assert(SpvType && "VReg is expected to have SPIRV type");
  const unsigned TyOpc = SpvType->getOpcode();
  bool IsFloat = TyOpc == SPIRV::OpTypeFloat;
  if (TyOpc == SPIRV::OpTypeVector) {
    // Operand 1 of the vector type is its component type.
    SPIRVType *CompTy = GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
    IsFloat = CompTy->getOpcode() == SPIRV::OpTypeFloat;
  }

  const LLT ValTy = MRI.getType(ValReg);
  LLT IdTy;
  unsigned GetIdOpc;
  const TargetRegisterClass *IdClass;
  if (ValTy.isPointer()) {
    IdTy = LLT::pointer(0, 32);
    GetIdOpc = SPIRV::GET_pID;
    IdClass = &SPIRV::pIDRegClass;
  } else if (ValTy.isVector()) {
    IdTy = LLT::fixed_vector(2, LLT::scalar(32));
    GetIdOpc = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
    IdClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
  } else {
    IdTy = LLT::scalar(32);
    GetIdOpc = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
    IdClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
  }

  Register IdReg = MRI.createGenericVirtualRegister(IdTy);
  MRI.setRegClass(IdReg, IdClass);
  return {IdReg, GetIdOpc};
}
/// Rewrites MI so that its result and its register uses go through fresh ID
/// registers created by createNewIdReg.
///
/// MI must define a register with exactly one use. That user's operand 1 and
/// MI's operand 0 are both switched to the new result register. For every
/// register use of MI, a GET_* instruction is built right after MI that
/// defines a new ID register from the old value, and MI is made to use it.
static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
  const unsigned Opc = MI.getOpcode();
  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
  const Register OldDef = MI.getOperand(0).getReg();
  // The single user, presumably the ASSIGN_TYPE for this result.
  MachineInstr &AssignTypeInst = *MRI.use_instr_begin(OldDef);
  Register NewDef = createNewIdReg(OldDef, Opc, MRI, *GR).first;
  AssignTypeInst.getOperand(1).setReg(NewDef);
  MI.getOperand(0).setReg(NewDef);

  MachineBasicBlock &ParentMBB = *MI.getParent();
  MachineInstr *AfterMI = MI.getNextNode();
  MIB.setInsertPt(ParentMBB, AfterMI ? AfterMI->getIterator() : ParentMBB.end());

  for (MachineOperand &Use : MI.operands()) {
    if (Use.isReg() && !Use.isDef()) {
      std::pair<Register, unsigned> IdInfo =
          createNewIdReg(Use.getReg(), Opc, MRI, *GR);
      MIB.buildInstr(IdInfo.second).addDef(IdInfo.first).addUse(Use.getReg());
      Use.setReg(IdInfo.first);
    }
  }
}
// Only declared here; the definition is outside this file. Presumably it
// reports whether instructions with the given opcode take part in type
// folding.
extern bool isTypeFoldingSupported(unsigned Opcode);

/// Runs two passes over MF:
///  1. every instruction whose opcode passes isTypeFoldingSupported is
///     rewritten by processInstr;
///  2. for every ASSIGN_TYPE whose source is defined by such an instruction,
///     the destination register is retyped to s32. If it was a vector, it is
///     first moved to the ID register class.
static void processInstrsWithTypeFolding(MachineFunction &MF,
                                         SPIRVGlobalRegistry *GR,
                                         MachineIRBuilder MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();

  for (MachineBasicBlock &MBB : MF)
    for (MachineInstr &MI : MBB)
      if (isTypeFoldingSupported(MI.getOpcode()))
        processInstr(MI, MIB, MRI, GR);

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      const bool IsFoldedAssign =
          MI.getOpcode() == SPIRV::ASSIGN_TYPE &&
          isTypeFoldingSupported(
              MRI.getVRegDef(MI.getOperand(1).getReg())->getOpcode());
      if (!IsFoldedAssign)
        continue;
      const Register Dst = MI.getOperand(0).getReg();
      const bool WasVector = MRI.getType(Dst).isVector();
      if (WasVector)
        MRI.setRegClass(Dst, &SPIRV::IDRegClass);
      MRI.setType(Dst, LLT::scalar(32));
    }
  }
}
/// Fills in the destination blocks of each `spv_switch` intrinsic.
///
/// First walk (in layout order):
///  - For each `spv_switch`, its operand 1 is recorded as a switch register.
///    The first successor of its block becomes the tentative default.
///  - Each later G_ICMP whose operand 2 is a recorded switch register is
///    treated as one case test. The code expects (asserts) these properties:
///      - the predicate is ICMP_EQ;
///      - the compare result has a single use, a G_BRCOND;
///      - the G_BRCOND is directly followed by a G_BR.
///    The G_BRCOND target is recorded as the block for the constant compared
///    against (operand 3). The G_BR target becomes the default, unless that
///    block starts with a G_ICMP whose operand 2 is the same register (or is
///    not a register).
///    NOTE(review): this presumably matches the compare/branch chain that
///    IRTranslator produces for a switch — confirm.
///  - If the compare result has no SPIR-V type yet, GR/MIB give it an i1
///    SPIR-V type and the ID register class.
///
/// Second walk: the case-value operands (2..end) of each `spv_switch` are
/// replaced by the default block, followed by one (ConstantInt, block) pair
/// per original case value.
///
/// Caveats visible in the code:
///  - `*MBB.succ_begin()` assumes the switch's block has a successor.
///  - `NextMBB->front()` assumes the G_BR target is non-empty.
///  - A case value never seen in a compare maps to a null block, because
///    SwitchRegToMBB's operator[] default-inserts.
static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
// Switch register -> (case value -> destination block).
DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
SwitchRegToMBB;
// Switch register -> default destination block.
DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
// Registers used as the selector of some spv_switch seen so far.
DenseSet<Register> SwitchRegs;
MachineRegisterInfo &MRI = MF.getRegInfo();
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
assert(MI.getOperand(1).isReg());
Register Reg = MI.getOperand(1).getReg();
SwitchRegs.insert(Reg);
// Tentative default; may be overwritten by the G_BR handling below.
DefaultMBBs[Reg] = *MBB.succ_begin();
}
// A compare of a switch register is one case test of that switch.
if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
SwitchRegs.contains(MI.getOperand(2).getReg())) {
assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
MI.getOperand(3).isReg());
Register Dst = MI.getOperand(0).getReg();
// Make sure the compare result has a SPIR-V (i1) type and a register class.
if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
MIB.setInsertPt(*MI.getParent(), MI);
Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
MRI.setRegClass(Dst, &SPIRV::IDRegClass);
GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
}
Register CmpReg = MI.getOperand(2).getReg();
MachineOperand &PredOp = MI.getOperand(1);
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
assert(CC == CmpInst::ICMP_EQ && MRI.hasOneUse(Dst) &&
MRI.hasOneDef(CmpReg));
// Case value being tested and the block taken when it matches.
uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
CBr->getOperand(1).isMBB());
SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
// The fall-through G_BR leads either to the next case test or to the
// default block.
MachineInstr *NextMI = CBr->getNextNode();
assert(NextMI->getOpcode() == SPIRV::G_BR &&
NextMI->getOperand(0).isMBB());
MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
assert(NextMBB != nullptr);
if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
(NextMBB->front().getOperand(2).isReg() &&
NextMBB->front().getOperand(2).getReg() != CmpReg))
DefaultMBBs[CmpReg] = NextMBB;
}
}
}
// Rewrite each spv_switch as: selector, default block, (value, block)...
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
continue;
assert(MI.getOperand(1).isReg());
Register Reg = MI.getOperand(1).getReg();
unsigned NumOp = MI.getNumExplicitOperands();
SmallVector<const ConstantInt *, 3> Vals;
SmallVector<MachineBasicBlock *, 3> MBBs;
// Operands 2..NumOp-1 are registers holding the case constants.
for (unsigned i = 2; i < NumOp; i++) {
Register CReg = MI.getOperand(i).getReg();
uint64_t Val = getIConstVal(CReg, &MRI);
MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
Vals.push_back(ConstInstr->getOperand(1).getCImm());
MBBs.push_back(SwitchRegToMBB[Reg][Val]);
}
// Drop the old case operands (everything after the selector).
for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
MI.removeOperand(i);
MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
for (unsigned i = 0; i < Vals.size(); i++) {
MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
}
}
}
}
/// Points the subtarget's SPIR-V global registry at MF, then runs the
/// pre-legalization steps in a fixed order. Always reports that MF changed.
bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
  SPIRVGlobalRegistry *GR =
      MF.getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
  GR->setCurrentFunc(MF);
  MachineIRBuilder MIB(MF);

  // Constant/global tracking and intrinsic clean-up.
  addConstantsToTrack(MF, GR);
  foldConstantsIntoIntrinsics(MF);
  insertBitcasts(MF, GR, MIB);

  // Type assignment, ID-register rewriting and switch lowering.
  generateAssignInstrs(MF, GR, MIB);
  processInstrsWithTypeFolding(MF, GR, MIB);
  processSwitches(MF, GR, MIB);

  return true;
}
INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
false)
char SPIRVPreLegalizer::ID = 0;
FunctionPass *llvm::createSPIRVPreLegalizerPass() {
return new SPIRVPreLegalizer();
}