#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMBaseRegisterInfo.h"
#include "ARMBasicBlockInfo.h"
#include "ARMSubtarget.h"
#include "MVETailPredUtils.h"
#include "Thumb2InstrInfo.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineLoopUtils.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ReachingDefAnalysis.h"
#include "llvm/MC/MCInstrDesc.h"
using namespace llvm;
#define DEBUG_TYPE "arm-low-overhead-loops"
#define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
// Hidden boolean command-line option "arm-loloops-disable-tailpred",
// initialised to false. LowOverheadLoop::ValidateTailPredicate() returns false
// when it is set.
static cl::opt<bool>
DisableTailPredication("arm-loloops-disable-tailpred", cl::Hidden,
cl::desc("Disable tail-predication in the ARM LowOverheadLoop pass"),
cl::init(false));
/// Returns true when \p MI has a VPT predicate operand and the operand that
/// directly follows the first such operand is the VPR register.
static bool isVectorPredicated(MachineInstr *MI) {
  const int FirstPredIdx = llvm::findFirstVPTPredOperandIdx(*MI);
  if (FirstPredIdx == -1)
    return false;
  return MI->getOperand(FirstPredIdx + 1).getReg() == ARM::VPR;
}
/// Returns true when \p MI has an operand that defines the VPR register.
static bool isVectorPredicate(MachineInstr *MI) {
  const int VPRDefIdx = MI->findRegisterDefOperandIdx(ARM::VPR);
  return VPRDefIdx != -1;
}
/// Returns true when \p MI has an operand that reads the VPR register.
static bool hasVPRUse(MachineInstr &MI) {
  const int VPRUseIdx = MI.findRegisterUseOperandIdx(ARM::VPR);
  return VPRUseIdx != -1;
}
/// Returns true when the domain bits in the TSFlags of \p MI's descriptor
/// select the MVE domain.
static bool isDomainMVE(MachineInstr *MI) {
  const uint64_t DomainBits = MI->getDesc().TSFlags & ARMII::DomainMask;
  if (DomainBits != ARMII::DomainMVE)
    return false;
  return true;
}
static int getVecSize(const MachineInstr &MI) {
const MCInstrDesc &MCID = MI.getDesc();
uint64_t Flags = MCID.TSFlags;
return (Flags & ARMII::VecSize) >> ARMII::VecSizeShift;
}
/// Decides whether \p MI is relevant to the predication analysis: debug
/// instructions never are; otherwise any MVE-domain instruction, VPR definer
/// or VPR user is.
static bool shouldInspect(MachineInstr &MI) {
  if (MI.isDebugInstr())
    return false;
  if (isDomainMVE(&MI))
    return true;
  if (isVectorPredicate(&MI))
    return true;
  return hasVPRUse(MI);
}
namespace {
// Size-agnostic set-of-instructions interface used for the parameters of the
// helpers in this file.
using InstSet = SmallPtrSetImpl<MachineInstr *>;
class PostOrderLoopTraversal {
MachineLoop &ML;
MachineLoopInfo &MLI;
SmallPtrSet<MachineBasicBlock*, 4> Visited;
SmallVector<MachineBasicBlock*, 4> Order;
public:
PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
: ML(ML), MLI(MLI) { }
const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
return Order;
}
void ProcessLoop() {
std::function<void(MachineBasicBlock*)> Search = [this, &Search]
(MachineBasicBlock *MBB) -> void {
if (Visited.count(MBB))
return;
Visited.insert(MBB);
for (auto *Succ : MBB->successors()) {
if (!ML.contains(Succ))
continue;
Search(Succ);
}
Order.push_back(MBB);
};
SmallVector<MachineBasicBlock*, 2> ExitBlocks;
ML.getExitBlocks(ExitBlocks);
append_range(Order, ExitBlocks);
Search(ML.getHeader());
std::function<void(MachineBasicBlock*)> GetPredecessor =
[this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
Order.push_back(MBB);
if (MBB->pred_size() == 1)
GetPredecessor(*MBB->pred_begin());
};
if (auto *Preheader = ML.getLoopPreheader())
GetPredecessor(Preheader);
else if (auto *Preheader = MLI.findLoopPreheader(&ML, true, true))
GetPredecessor(Preheader);
}
};
/// Pairs an instruction with a private copy of the predicate-forming
/// instructions that were supplied when it was recorded.
struct PredicatedMI {
  MachineInstr *MI = nullptr;
  SetVector<MachineInstr*> Predicates;

  /// \p I must be non-null. The contents of \p Preds are copied, in order,
  /// into Predicates.
  PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
    assert(I && "Instruction must not be null!");
    for (MachineInstr *Pred : Preds)
      Predicates.insert(Pred);
  }
};
// Tracks the VPT blocks discovered while a loop is being analysed and, for
// every recorded instruction, the set of VPR-defining instructions it is
// predicated upon. The bookkeeping (Blocks, CurrentPredicates,
// PredicatedInsts) is held in static members, so it is shared by all
// instances and has to be cleared with reset() before a new loop is scanned.
class VPTState {
friend struct LowOverheadLoop;
// Instructions of this block in insertion order; the constructor stores the
// block-opening instruction first.
SmallVector<MachineInstr *, 4> Insts;
// Every block created since the last reset().
static SmallVector<VPTState, 4> Blocks;
// VPR-defining instructions that make up the predicate currently in effect.
static SetVector<MachineInstr *> CurrentPredicates;
// For each recorded instruction, a snapshot of CurrentPredicates taken at the
// moment it was recorded.
static std::map<MachineInstr *,
std::unique_ptr<PredicatedMI>> PredicatedInsts;
// Opens a new block headed by MI and records the predicates in effect for it.
// Asserts that a predicate exists or that VPR is live-in to MI's block.
static void CreateVPTBlock(MachineInstr *MI) {
assert((CurrentPredicates.size() || MI->getParent()->isLiveIn(ARM::VPR))
&& "Can't begin VPT without predicate");
Blocks.emplace_back(MI);
PredicatedInsts.emplace(
MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
}
// Drops all shared state.
static void reset() {
Blocks.clear();
PredicatedInsts.clear();
CurrentPredicates.clear();
}
// Appends MI to the most recently created block and snapshots its predicates.
// Requires that at least one block exists (Blocks.back() is used unchecked).
static void addInst(MachineInstr *MI) {
Blocks.back().insert(MI);
PredicatedInsts.emplace(
MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
}
// Adds MI to the predicate currently in effect.
static void addPredicate(MachineInstr *MI) {
LLVM_DEBUG(dbgs() << "ARM Loops: Adding VPT Predicate: " << *MI);
CurrentPredicates.insert(MI);
}
// Replaces the predicate currently in effect with MI alone.
static void resetPredicate(MachineInstr *MI) {
LLVM_DEBUG(dbgs() << "ARM Loops: Resetting VPT Predicate: " << *MI);
CurrentPredicates.clear();
CurrentPredicates.insert(MI);
}
public:
// True when no instruction after the first one in Block defines VPR.
static bool hasUniformPredicate(VPTState &Block) {
return getDivergent(Block) == nullptr;
}
// Returns the first instruction after the block's first entry that defines
// VPR, or nullptr when there is none.
static MachineInstr *getDivergent(VPTState &Block) {
SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
// Index 0 is the block-opening instruction, so the search starts at 1.
for (unsigned i = 1; i < Insts.size(); ++i) {
MachineInstr *Next = Insts[i];
if (isVectorPredicate(Next))
return Next; }
return nullptr;
}
// True when at least one of the predicates recorded for MI is a VCTP. With
// Exclusive set, MI must additionally have exactly one recorded predicate.
// NOTE(review): operator[] default-inserts a null entry for an instruction
// that was never recorded, which is then dereferenced; callers appear to be
// expected to pass only recorded instructions - confirm.
static bool isPredicatedOnVCTP(MachineInstr *MI, bool Exclusive = false) {
SetVector<MachineInstr *> &Predicates = PredicatedInsts[MI]->Predicates;
if (Exclusive && Predicates.size() != 1)
return false;
return llvm::any_of(Predicates, isVCTP);
}
// Applies isPredicatedOnVCTP() to the first instruction of Block.
static bool isEntryPredicatedOnVCTP(VPTState &Block,
bool Exclusive = false) {
SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
return isPredicatedOnVCTP(Insts.front(), Exclusive);
}
// For a block opened by a VPT (never a VPST): true when at least one of
// operands 1 and 2 of the VPT is produced by an instruction predicated on a
// VCTP, and each of those two operands is either so predicated or has no
// reaching definition inside the VPT's own basic block.
static bool hasImplicitlyValidVPT(VPTState &Block,
ReachingDefAnalysis &RDA) {
SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
MachineInstr *VPT = Insts.front();
assert(isVPTOpcode(VPT->getOpcode()) &&
"Expected VPT block to begin with VPT/VPST");
if (VPT->getOpcode() == ARM::MVE_VPST)
return false;
// The defining instruction of operand Idx was recorded and is predicated on
// a VCTP.
auto IsOperandPredicated = [&](MachineInstr *MI, unsigned Idx) {
MachineInstr *Op = RDA.getMIOperand(MI, MI->getOperand(Idx));
return Op && PredicatedInsts.count(Op) && isPredicatedOnVCTP(Op);
};
// Operand Idx is not a (non-zero) register, has no reaching definitions, or
// none of its reaching definitions are in the VPT's basic block.
auto IsOperandInvariant = [&](MachineInstr *MI, unsigned Idx) {
MachineOperand &MO = MI->getOperand(Idx);
if (!MO.isReg() || !MO.getReg())
return true;
SmallPtrSet<MachineInstr *, 2> Defs;
RDA.getGlobalReachingDefs(MI, MO.getReg(), Defs);
if (Defs.empty())
return true;
for (auto *Def : Defs)
if (Def->getParent() == VPT->getParent())
return false;
return true;
};
return (IsOperandPredicated(VPT, 1) || IsOperandPredicated(VPT, 2)) &&
(IsOperandPredicated(VPT, 1) || IsOperandInvariant(VPT, 1)) &&
(IsOperandPredicated(VPT, 2) || IsOperandInvariant(VPT, 2));
}
// Checks every recorded block. A block passes outright when its entry is
// predicated on a VCTP or its VPT is implicitly valid; otherwise each of its
// instructions is examined individually.
static bool isValid(ReachingDefAnalysis &RDA) {
for (auto &Block : Blocks) {
if (isEntryPredicatedOnVCTP(Block, false) ||
hasImplicitlyValidVPT(Block, RDA))
continue;
SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
assert(isVPTOpcode(Insts.front()->getOpcode()) &&
"Expected VPT block to start with a VPST or VPT!");
// A two-instruction block consisting of a non-VPST opener followed by a
// VCTP is rejected.
if (Insts.size() == 2 && Insts.front()->getOpcode() != ARM::MVE_VPST &&
isVCTP(Insts.back()))
return false;
for (auto *MI : Insts) {
// A VCTP inside the block must carry the 'Then' predicate.
if (isVCTP(MI) && getVPTInstrPredicate(*MI) != ARMVCC::Then)
return false;
// VPSTs and VPR-defining instructions are not subject to the check below.
if (MI->getOpcode() == ARM::MVE_VPST || isVectorPredicate(MI))
continue;
// Everything else must be predicated on a VCTP.
if (!isPredicatedOnVCTP(MI)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *MI);
return false;
}
}
}
return true;
}
// Creates a block whose first instruction is MI.
VPTState(MachineInstr *MI) { Insts.push_back(MI); }
// Appends MI; asserts that the block never holds more than five instructions.
void insert(MachineInstr *MI) {
Insts.push_back(MI);
assert(Insts.size() <= 5 && "Too many instructions in VPT block!");
}
// True when any instruction of this block is a VCTP.
bool containsVCTP() const {
return llvm::any_of(Insts, isVCTP);
}
// Number of instructions in this block, including the opener.
unsigned size() const { return Insts.size(); }
// Mutable access to the block's instruction list.
SmallVectorImpl<MachineInstr *> &getInsts() { return Insts; }
};
/// Collects everything the pass learns about one candidate loop: the loop
/// start / decrement / end instructions, any VCTPs, where the final loop-start
/// instruction will be inserted, and the flags recording whether the loop must
/// be reverted or cannot be tail-predicated. Constructing an instance resets
/// the shared VPTState.
struct LowOverheadLoop {
  MachineLoop &ML;
  /// Preheader from MachineLoop, or failing that from
  /// MachineLoopInfo::findLoopPreheader; null when neither yields one.
  MachineBasicBlock *Preheader = nullptr;
  MachineLoopInfo &MLI;
  ReachingDefAnalysis &RDA;
  const TargetRegisterInfo &TRI;
  const ARMBaseInstrInfo &TII;
  MachineFunction *MF = nullptr;
  /// Insertion point (and its block) for the expanded loop-start instruction;
  /// both are assigned in Validate().
  MachineBasicBlock::iterator StartInsertPt;
  MachineBasicBlock *StartInsertBB = nullptr;
  MachineInstr *Start = nullptr;
  MachineInstr *Dec = nullptr;
  MachineInstr *End = nullptr;
  /// Element-count operand handed to the loop start when tail-predicating;
  /// initialised to an immediate 0 placeholder.
  MachineOperand TPNumElements;
  SmallVector<MachineInstr *, 4> VCTPs;
  SmallPtrSet<MachineInstr *, 4> ToRemove;
  SmallPtrSet<MachineInstr *, 4> BlockMasksToRecompute;
  SmallPtrSet<MachineInstr *, 4> DoubleWidthResultInstrs;
  SmallPtrSet<MachineInstr *, 4> VMOVCopies;
  bool Revert = false;
  bool CannotTailPredicate = false;

  LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
                  ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
                  const ARMBaseInstrInfo &TII)
      : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
        TPNumElements(MachineOperand::CreateImm(0)) {
    MF = ML.getHeader()->getParent();
    if (auto *MBB = ML.getLoopPreheader())
      Preheader = MBB;
    else if (auto *MBB = MLI.findLoopPreheader(&ML, true, true))
      Preheader = MBB;
    // VPTState keeps its bookkeeping in statics; start every loop clean.
    VPTState::reset();
  }

  bool ValidateMVEInst(MachineInstr *MI);

  /// Feeds \p MI through ValidateMVEInst() and latches the outcome in
  /// CannotTailPredicate.
  void AnalyseMVEInst(MachineInstr *MI) {
    CannotTailPredicate = !ValidateMVEInst(MI);
  }

  /// Tail-predication needs: no revert, all three loop components, at least
  /// one VCTP, no earlier rejection, and a single-block loop.
  bool IsTailPredicationLegal() const {
    return !Revert && FoundAllComponents() && !VCTPs.empty() &&
           !CannotTailPredicate && ML.getNumBlocks() == 1;
  }

  bool AddVCTP(MachineInstr *MI);
  bool ValidateTailPredicate();
  bool ValidateLiveOuts();
  MachineInstr *isSafeToDefineLR();
  void Validate(ARMBasicBlockUtils *BBUtils);

  /// True once Start, Dec and End have all been located.
  bool FoundAllComponents() const {
    return Start && Dec && End;
  }

  SmallVectorImpl<VPTState> &getVPTBlocks() {
    return VPTState::Blocks;
  }

  /// Operand for the expanded loop start: the element count when
  /// tail-predicating, otherwise operand 1 of the original Start.
  MachineOperand &getLoopStartOperand() {
    if (IsTailPredicationLegal())
      return TPNumElements;
    return Start->getOperand(1);
  }

  /// Opcode for the expanded loop start: t2DLS / t2WLS for a plain loop, or
  /// the tail-predicated form derived from the last VCTP's opcode.
  unsigned getStartOpcode() const {
    bool IsDo = isDoLoopStart(*Start);
    if (!IsTailPredicationLegal())
      return IsDo ? ARM::t2DLS : ARM::t2WLS;
    return VCTPOpcodeToLSTP(VCTPs.back()->getOpcode(), IsDo);
  }

  /// Prints the components found so far to dbgs().
  void dump() const {
    if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
    if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
    if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
    if (!VCTPs.empty()) {
      dbgs() << "ARM Loops: Found VCTP(s):\n";
      for (auto *MI : VCTPs)
        dbgs() << " - " << *MI;
    }
    // Distinguish "nothing found" from "only some components found". The
    // second test used to repeat the first one and could never fire.
    if (!Start && !Dec && !End)
      dbgs() << "ARM Loops: Not a low-overhead loop.\n";
    else if (!FoundAllComponents())
      dbgs() << "ARM Loops: Failed to find all loop components.\n";
  }
};
/// Machine-function pass declaration. The per-function analysis handles and
/// target hooks are cached in members by runOnMachineFunction().
class ARMLowOverheadLoops : public MachineFunctionPass {
  MachineFunction *MF = nullptr;
  MachineLoopInfo *MLI = nullptr;
  ReachingDefAnalysis *RDA = nullptr;
  const ARMBaseInstrInfo *TII = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  std::unique_ptr<ARMBasicBlockUtils> BBUtils;

public:
  static char ID;

  ARMLowOverheadLoops() : MachineFunctionPass(ID) { }

  /// Requires MachineLoopInfo and ReachingDefAnalysis; declares that the CFG
  /// is preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineLoopInfo>();
    AU.addRequired<ReachingDefAnalysis>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  /// The pass needs the NoVRegs and TracksLiveness properties.
  MachineFunctionProperties getRequiredProperties() const override {
    MachineFunctionProperties Required;
    Required.set(MachineFunctionProperties::Property::NoVRegs);
    Required.set(MachineFunctionProperties::Property::TracksLiveness);
    return Required;
  }

  StringRef getPassName() const override {
    return ARM_LOW_OVERHEAD_LOOPS_NAME;
  }

private:
  bool ProcessLoop(MachineLoop *ML);
  bool RevertNonLoops();
  void RevertWhile(MachineInstr *MI) const;
  void RevertDo(MachineInstr *MI) const;
  bool RevertLoopDec(MachineInstr *MI) const;
  void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
  void RevertLoopEndDec(MachineInstr *MI) const;
  void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
  MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
  void Expand(LowOverheadLoop &LoLoop);
  void IterationCountDCE(LowOverheadLoop &LoLoop);
};
}
// Out-of-class definition of the pass ID declared in ARMLowOverheadLoops.
char ARMLowOverheadLoops::ID = 0;
// Out-of-class definitions of VPTState's static bookkeeping. Being statics,
// they persist between loops until VPTState::reset() is called.
SmallVector<VPTState, 4> VPTState::Blocks;
SetVector<MachineInstr *> VPTState::CurrentPredicates;
std::map<MachineInstr *,
std::unique_ptr<PredicatedMI>> VPTState::PredicatedInsts;
// Registers the pass under DEBUG_TYPE with the human-readable pass name.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
false, false)
// Tries to schedule MI, together with what RDA.isSafeToRemove() reports must go
// with it, for removal. Instructions in Ignore are passed through to
// isSafeToRemove(). On success the removable instructions are added to
// ToRemove and true is returned; nothing is erased here. Returns false when
// removal is unsafe or would leave a partially emptied IT block behind.
static bool TryRemove(MachineInstr *MI, ReachingDefAnalysis &RDA,
InstSet &ToRemove, InstSet &Ignore) {
// Returns true when every IT block touched by Killed would be emptied
// completely. In that case the IT instructions themselves are added to
// Killed. Returns false when some IT block would keep part of its body.
auto WontCorruptITs = [](InstSet &Killed, ReachingDefAnalysis &RDA) {
// Blocks that contain at least one instruction slated for removal.
SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
for (auto *Dead : Killed)
BasicBlocks.insert(Dead->getParent());
// For every t2IT in those blocks, the local users of the ITSTATE it defines.
std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
for (auto *MBB : BasicBlocks) {
for (auto &IT : *MBB) {
if (IT.getOpcode() != ARM::t2IT)
continue;
RDA.getReachingLocalUses(&IT, MCRegister::from(ARM::ITSTATE),
ITBlocks[&IT]);
}
}
// ModifiedITs ends up holding the ITs whose user set is still non-empty
// after the dead instructions have been taken out.
SmallPtrSet<MachineInstr *, 2> ModifiedITs;
SmallPtrSet<MachineInstr *, 2> RemoveITs;
for (auto *Dead : Killed) {
if (MachineOperand *MO = Dead->findRegisterUseOperand(ARM::ITSTATE)) {
MachineInstr *IT = RDA.getMIOperand(Dead, *MO);
RemoveITs.insert(IT);
auto &CurrentBlock = ITBlocks[IT];
CurrentBlock.erase(Dead);
if (CurrentBlock.empty())
ModifiedITs.erase(IT);
else
ModifiedITs.insert(IT);
}
}
if (!ModifiedITs.empty())
return false;
Killed.insert(RemoveITs.begin(), RemoveITs.end());
return true;
};
SmallPtrSet<MachineInstr *, 2> Uses;
if (!RDA.isSafeToRemove(MI, Uses, Ignore))
return false;
if (WontCorruptITs(Uses, RDA)) {
ToRemove.insert(Uses.begin(), Uses.end());
LLVM_DEBUG(dbgs() << "ARM Loops: Able to remove: " << *MI
<< " - can also remove:\n";
for (auto *Use : Uses)
dbgs() << " - " << *Use);
// Also schedule the operands killed by MI, but only when doing so keeps the
// IT blocks intact; failing this second check does not fail the call.
SmallPtrSet<MachineInstr*, 4> Killed;
RDA.collectKilledOperands(MI, Killed);
if (WontCorruptITs(Killed, RDA)) {
ToRemove.insert(Killed.begin(), Killed.end());
LLVM_DEBUG(for (auto *Dead : Killed)
dbgs() << " - " << *Dead);
}
return true;
}
return false;
}
// Decides whether the loop can be tail-predicated. On success TPNumElements
// holds the element-count operand for the loop start, and the instructions
// that only maintained the element count have been added to ToRemove.
// May move the element-count definition or the loop-start insertion point
// inside StartInsertBB. Returns false when any precondition is not met.
bool LowOverheadLoop::ValidateTailPredicate() {
if (!IsTailPredicationLegal()) {
LLVM_DEBUG(if (VCTPs.empty())
dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
dbgs() << "ARM Loops: Tail-predication is not valid.\n");
return false;
}
assert(!VCTPs.empty() && "VCTP instruction expected but is not set");
assert(ML.getBlocks().size() == 1 &&
"Shouldn't be processing a loop with more than one block");
// Honour the command-line kill switch.
if (DisableTailPredication) {
LLVM_DEBUG(dbgs() << "ARM Loops: tail-predication is disabled\n");
return false;
}
if (!VPTState::isValid(RDA)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Invalid VPT state.\n");
return false;
}
if (!ValidateLiveOuts()) {
LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
return false;
}
MachineInstr *VCTP = VCTPs.back();
// The *TP loop-start pseudos carry the element count as operand 2, so it can
// be taken from there and the start instruction stays where it is.
if (Start->getOpcode() == ARM::t2DoLoopStartTP ||
Start->getOpcode() == ARM::t2WhileLoopStartTP) {
TPNumElements = Start->getOperand(2);
StartInsertPt = Start;
StartInsertBB = Start->getParent();
} else {
// Otherwise the element count is operand 1 of the VCTP and it has to be
// shown to be available at the insertion point.
TPNumElements = VCTP->getOperand(1);
MCRegister NumElements = TPNumElements.getReg().asMCReg();
// A definition of the register in the VCTP's block ahead of the VCTP rules
// tail-predication out.
if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
return false;
}
// The definition reaching the insertion point is not the one that is
// live-out of the block: try, in order, moving the def forwards to the
// insertion point, moving the start after the def, or using the source of a
// register move instead.
if (StartInsertPt != StartInsertBB->end() &&
!RDA.isReachingDefLiveOut(&*StartInsertPt, NumElements)) {
if (auto *ElemDef =
RDA.getLocalLiveOutMIDef(StartInsertBB, NumElements)) {
if (RDA.isSafeToMoveForwards(ElemDef, &*StartInsertPt)) {
ElemDef->removeFromParent();
StartInsertBB->insert(StartInsertPt, ElemDef);
LLVM_DEBUG(dbgs()
<< "ARM Loops: Moved element count def: " << *ElemDef);
} else if (RDA.isSafeToMoveBackwards(&*StartInsertPt, ElemDef)) {
StartInsertPt->removeFromParent();
StartInsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
&*StartInsertPt);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
} else {
// ElemDef is a register move whose source has the same unique reaching
// definition at ElemDef and at the insertion point: use the source.
MachineOperand Operand = ElemDef->getOperand(1);
if (isMovRegOpcode(ElemDef->getOpcode()) &&
RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg().asMCReg()) ==
RDA.getUniqueReachingMIDef(&*StartInsertPt,
Operand.getReg().asMCReg())) {
TPNumElements = Operand;
NumElements = TPNumElements.getReg();
} else {
LLVM_DEBUG(dbgs()
<< "ARM Loops: Unable to move element count to loop "
<< "start instruction.\n");
return false;
}
}
}
}
// True when MBB cannot pass the element count through unchanged: the
// register is redefined before the block's last instruction, or the block
// has more than one predecessor. Empty blocks are treated as transparent.
auto CannotProvideElements = [this](MachineBasicBlock *MBB,
MCRegister NumElements) {
if (MBB->empty())
return false;
if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
return true;
if (MBB->pred_size() > 1)
return true;
return false;
};
// Walk from the preheader up through first predecessors until the insertion
// block is reached.
// NOTE(review): pred_begin() is dereferenced without checking that the
// block has a predecessor - confirm the walk always reaches StartInsertBB.
MachineBasicBlock *MBB = Preheader;
while (MBB && MBB != StartInsertBB) {
if (CannotProvideElements(MBB, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
return false;
}
MBB = *MBB->pred_begin();
}
}
// No instruction from the insertion point to the end of its block may be one
// that shouldInspect() flags (MVE-domain, VPR def or VPR use).
if (std::any_of(StartInsertPt, StartInsertBB->end(), shouldInspect)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Instruction blocks [W|D]LSTP\n");
return false;
}
// Every recorded double-width-result instruction must have a VecSize no
// larger than the VCTP's.
unsigned VCTPVecSize = getVecSize(*VCTP);
for (MachineInstr *MI : DoubleWidthResultInstrs) {
unsigned InstrVecSize = getVecSize(*MI);
if (InstrVecSize > VCTPVecSize) {
LLVM_DEBUG(dbgs() << "ARM Loops: Double width result larger than VCTP "
<< "VecSize:\n" << *MI);
return false;
}
}
// A sub is acceptable when the negation of getAddSubImmediate() equals the
// expected vector width.
auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
return -getAddSubImmediate(*MI) == ExpectedVecWidth;
};
MachineBasicBlock *MBB = VCTP->getParent();
// Inspect the chain that updates the VCTP's element-count register inside the
// loop. If it is removable it may contain only register moves and exactly one
// valid sub-immediate; it is then scheduled for removal.
if (auto *Def = RDA.getUniqueReachingMIDef(
&MBB->back(), VCTP->getOperand(1).getReg().asMCReg())) {
SmallPtrSet<MachineInstr*, 2> ElementChain;
SmallPtrSet<MachineInstr*, 2> Ignore;
unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
Ignore.insert(VCTPs.begin(), VCTPs.end());
if (TryRemove(Def, RDA, ElementChain, Ignore)) {
bool FoundSub = false;
for (auto *MI : ElementChain) {
if (isMovRegOpcode(MI->getOpcode()))
continue;
if (isSubImmOpcode(MI->getOpcode())) {
if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unexpected instruction in element"
" count: " << *MI);
return false;
}
FoundSub = true;
} else {
LLVM_DEBUG(dbgs() << "ARM Loops: Unexpected instruction in element"
" count: " << *MI);
return false;
}
}
ToRemove.insert(ElementChain.begin(), ElementChain.end());
}
}
// With a *TP loop start, additionally try to remove the preheader's
// definition of the VCTP operand register; failure here is not an error.
if ((Start->getOpcode() == ARM::t2DoLoopStartTP ||
Start->getOpcode() == ARM::t2WhileLoopStartTP) &&
Preheader && !Preheader->empty() &&
!RDA.hasLocalDefBefore(VCTP, VCTP->getOperand(1).getReg())) {
if (auto *Def = RDA.getUniqueReachingMIDef(
&Preheader->back(), VCTP->getOperand(1).getReg().asMCReg())) {
SmallPtrSet<MachineInstr*, 2> Ignore;
Ignore.insert(VCTPs.begin(), VCTPs.end());
TryRemove(Def, RDA, ToRemove, Ignore);
}
}
return true;
}
/// Returns true when \p MO is a register operand holding a non-zero register
/// that is a member of \p Class.
static bool isRegInClass(const MachineOperand &MO,
                         const TargetRegisterClass *Class) {
  if (!MO.isReg())
    return false;
  if (!MO.getReg())
    return false;
  return Class->contains(MO.getReg());
}
static bool retainsPreviousHalfElement(const MachineInstr &MI) {
const MCInstrDesc &MCID = MI.getDesc();
uint64_t Flags = MCID.TSFlags;
return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
}
static bool producesDoubleWidthResult(const MachineInstr &MI) {
const MCInstrDesc &MCID = MI.getDesc();
uint64_t Flags = MCID.TSFlags;
return (Flags & ARMII::DoubleWidthResult) != 0;
}
static bool isHorizontalReduction(const MachineInstr &MI) {
const MCInstrDesc &MCID = MI.getDesc();
uint64_t Flags = MCID.TSFlags;
return (Flags & ARMII::HorizontalReduction) != 0;
}
/// Returns true for instructions on the "may turn zero inputs into non-zero
/// outputs" list: anything flagged as producing a double-width result, plus
/// VMVN, VORN and the three VCLZ variants.
static bool canGenerateNonZeros(const MachineInstr &MI) {
  if (producesDoubleWidthResult(MI))
    return true;

  const unsigned Opc = MI.getOpcode();
  return Opc == ARM::MVE_VMVN || Opc == ARM::MVE_VORN ||
         Opc == ARM::MVE_VCLZs8 || Opc == ARM::MVE_VCLZs16 ||
         Opc == ARM::MVE_VCLZs32;
}
/// Determines whether \p MI is known to leave zeros in the lanes that are
/// inactive under predication.
///
/// \param MI              instruction being classified.
/// \param QPRs            register class used to tell vector operands apart
///                        from scalar ones.
/// \param RDA             reaching-definition analysis used to find operand
///                        producers.
/// \param FalseLanesZero  instructions already classified as producing zeros
///                        in their false lanes.
/// \returns true only when every relevant operand is fed exclusively by MI
///          itself, by a member of \p FalseLanesZero, or by an unpredicated
///          "vmov #0"; for a predicated MI, use operands are exempt.
static bool producesFalseLanesZero(MachineInstr &MI,
                                   const TargetRegisterClass *QPRs,
                                   const ReachingDefAnalysis &RDA,
                                   InstSet &FalseLanesZero) {
  if (canGenerateNonZeros(MI))
    return false;

  bool isPredicated = isVectorPredicated(&MI);
  // For loads the answer is simply whether the load is predicated.
  if (MI.mayLoad())
    return isPredicated;

  auto IsZeroInit = [](MachineInstr *Def) {
    return !isVectorPredicated(Def) &&
           Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
           Def->getOperand(1).getImm() == 0;
  };

  bool AllowScalars = isHorizontalReduction(MI);
  // Index of the first VPT predicate operand; it does not depend on the
  // operand being visited, so look it up once.
  const int PIdx = llvm::findFirstVPTPredOperandIdx(MI);

  for (auto &MO : MI.operands()) {
    if (!MO.isReg() || !MO.getReg())
      continue;
    // Horizontal reductions are allowed to have non-vector register operands.
    if (!isRegInClass(MO, QPRs) && AllowScalars)
      continue;
    // The operand two slots after the first predicate operand is not checked.
    if (PIdx != -1 && (int)MI.getOperandNo(&MO) == PIdx + 2)
      continue;

    SmallPtrSet<MachineInstr *, 2> Defs;
    RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
    // Nothing is known about a register with no reaching definition (for
    // instance a value live into the function), so its false lanes cannot be
    // assumed to be zero.
    if (Defs.empty())
      return false;

    for (auto *Def : Defs) {
      if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
        continue;
      if (MO.isUse() && isPredicated)
        continue;
      return false;
    }
  }
  LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
  return true;
}
// Checks the vector values that are live out of the (single-block) loop.
// Returns false when VPR is live into the exit block, when an unpredicated
// instruction retains a previous half element or is a horizontal reduction
// (and is not known to zero its false lanes), or when a live-out value comes
// from an instruction whose false lanes are unknown and whose users are not
// all predicated. MQPRCopy instructions met on the way are collected in
// VMOVCopies; the set is cleared again on failure.
bool LowOverheadLoop::ValidateLiveOuts() {
const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
// Unpredicated instructions whose false-lane contents could not be proven
// zero, kept in program order.
SetVector<MachineInstr *> FalseLanesUnknown;
SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
SmallPtrSet<MachineInstr *, 4> Predicated;
MachineBasicBlock *Header = ML.getHeader();
LLVM_DEBUG(dbgs() << "ARM Loops: Validating Live outs\n");
// First pass: classify every inspected instruction in the header.
for (auto &MI : *Header) {
if (!shouldInspect(MI))
continue;
if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
continue;
bool isPredicated = isVectorPredicated(&MI);
bool retainsOrReduces =
retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);
if (isPredicated)
Predicated.insert(&MI);
if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
FalseLanesZero.insert(&MI);
else if (MI.getNumDefs() == 0)
continue;
else if (!isPredicated && retainsOrReduces) {
LLVM_DEBUG(dbgs() << " Unpredicated instruction that retainsOrReduces: " << MI);
return false;
} else if (!isPredicated && MI.getOpcode() != ARM::MQPRCopy)
FalseLanesUnknown.insert(&MI);
}
LLVM_DEBUG({
dbgs() << " Predicated:\n";
for (auto *I : Predicated)
dbgs() << " " << *I;
dbgs() << " FalseLanesZero:\n";
for (auto *I : FalseLanesZero)
dbgs() << " " << *I;
dbgs() << " FalseLanesUnknown:\n";
for (auto *I : FalseLanesUnknown)
dbgs() << " " << *I;
});
// True when every user of MO's register as defined by MI, other than MI
// itself, is in Predicated.
auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
SmallPtrSetImpl<MachineInstr *> &Predicated) {
SmallPtrSet<MachineInstr *, 2> Uses;
RDA.getGlobalUses(MI, MO.getReg().asMCReg(), Uses);
for (auto *Use : Uses) {
if (Use != MI && !Predicated.count(Use))
return false;
}
return true;
};
// Second pass, in reverse program order: an unknown instruction whose vector
// defs are consumed only by predicated users is promoted to Predicated, which
// lets instructions visited later (earlier in the block) benefit from it. The
// rest are remembered in NonPredicated.
SmallPtrSet<MachineInstr*, 2> NonPredicated;
for (auto *MI : reverse(FalseLanesUnknown)) {
for (auto &MO : MI->operands()) {
if (!isRegInClass(MO, QPRs) || !MO.isDef())
continue;
if (!HasPredicatedUsers(MI, MO, Predicated)) {
LLVM_DEBUG(dbgs() << " Found an unknown def of : "
<< TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
NonPredicated.insert(MI);
break;
}
}
if (!NonPredicated.contains(MI))
Predicated.insert(MI);
}
// Collect the header instructions that define the vector registers live into
// the single exit block.
SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
SmallVector<MachineBasicBlock *, 2> ExitBlocks;
ML.getExitBlocks(ExitBlocks);
assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
assert(ExitBlocks.size() == 1 && "Expected a single exit block");
MachineBasicBlock *ExitBB = ExitBlocks.front();
for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
if (RegMask.PhysReg == ARM::VPR) {
LLVM_DEBUG(dbgs() << " VPR is live in to the exit block.");
return false;
}
if (QPRs->contains(RegMask.PhysReg))
if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
LiveOutMIs.insert(MI);
}
// Validate each live-out definition. A copy is looked through to the unique
// definition of its source; any other instruction is rejected when it is
// both non-predicated and of unknown false-lane content.
SmallVector<MachineInstr *> Worklist(LiveOutMIs.begin(), LiveOutMIs.end());
while (!Worklist.empty()) {
MachineInstr *MI = Worklist.pop_back_val();
if (MI->getOpcode() == ARM::MQPRCopy) {
VMOVCopies.insert(MI);
MachineInstr *CopySrc =
RDA.getUniqueReachingMIDef(MI, MI->getOperand(1).getReg());
if (CopySrc)
Worklist.push_back(CopySrc);
} else if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
LLVM_DEBUG(dbgs() << " Unable to handle live out: " << *MI);
VMOVCopies.clear();
return false;
}
}
return true;
}
/// Final validation of a loop whose components have all been found. Does
/// nothing when the loop is already marked for reverting. Otherwise it
/// records the loop-start insertion point, sets Revert when the branch
/// targets or offsets are unsuitable, and then sets CannotTailPredicate from
/// ValidateTailPredicate().
void LowOverheadLoop::Validate(ARMBasicBlockUtils *BBUtils) {
  if (Revert)
    return;

  // Checks that the loop end branches to the header from at or after it and
  // within 4094 bytes, and that a while-loop start branches to a target at
  // or after it, also within 4094 bytes.
  auto BranchRangesOk = [&]() -> bool {
    const unsigned TargetOpIdx = End->getOpcode() == ARM::t2LoopEnd ? 1 : 2;
    MachineBasicBlock *EndTarget = End->getOperand(TargetOpIdx).getMBB();
    MachineBasicBlock *Header = ML.getHeader();

    if (EndTarget != Header) {
      LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targeting header.\n");
      return false;
    }

    if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(Header) ||
        !BBUtils->isBBInRange(End, Header, 4094)) {
      LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
      return false;
    }

    if (!isWhileLoopStart(*Start))
      return true;

    MachineBasicBlock *WhileTarget = getWhileLoopStartTargetBB(*Start);
    if (BBUtils->getOffsetOf(Start) > BBUtils->getOffsetOf(WhileTarget) ||
        !BBUtils->isBBInRange(Start, WhileTarget, 4094)) {
      LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
      return false;
    }
    return true;
  };

  // By default the new loop start replaces the existing one in place.
  StartInsertBB = Start->getParent();
  StartInsertPt = MachineBasicBlock::iterator(Start);
  LLVM_DEBUG(dbgs() << "ARM Loops: Will insert LoopStart at "
                    << *StartInsertPt);

  // Revert has to be settled first: ValidateTailPredicate() consults it.
  const bool RangesOk = BranchRangesOk();
  Revert = !RangesOk;
  const bool CanTailPredicate = ValidateTailPredicate();
  CannotTailPredicate = !CanTailPredicate;
}
/// Records \p MI as a VCTP of this loop. The first VCTP is always accepted.
/// Each later one is accepted only when its operand 1 is identical to that of
/// the previously recorded VCTP and has the same reaching definition at both
/// instructions. Returns false (without recording) otherwise.
bool LowOverheadLoop::AddVCTP(MachineInstr *MI) {
  LLVM_DEBUG(dbgs() << "ARM Loops: Adding VCTP: " << *MI);

  if (!VCTPs.empty()) {
    MachineInstr *Last = VCTPs.back();
    const MachineOperand &Count = MI->getOperand(1);
    const bool SameOperand = Last->getOperand(1).isIdenticalTo(Count);
    if (!SameOperand ||
        !RDA.hasSameReachingDef(Last, MI, Count.getReg().asMCReg())) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
                           "definition from the main VCTP");
      return false;
    }
  }

  VCTPs.push_back(MI);
  return true;
}
/// Returns true when \p MI is an MVE_VSTRWU32 to a spill slot addressed off
/// SP and no MVE_VLDRWU32 from that slot can be reached after the loop
/// without first passing another MVE_VSTRWU32 to it. The search starts at the
/// exit blocks of \p ML and follows successors; the block holding \p MI is
/// never expanded.
static bool ValidateMVEStore(MachineInstr *MI, MachineLoop *ML) {
  // Frame index of a fixed-stack memory operand, or -1 for anything else.
  auto FrameIndexOf = [](MachineMemOperand *MemOp) -> int {
    const PseudoSourceValue *PSV = MemOp->getPseudoValue();
    if (!PSV || PSV->kind() != PseudoSourceValue::FixedStack)
      return -1;
    const auto *Fixed = dyn_cast<FixedStackPseudoSourceValue>(PSV);
    return Fixed ? Fixed->getFrameIndex() : -1;
  };

  // A VSTRW/VLDRW whose operand 1 is SP and which has exactly one memory
  // operand, that operand naming a non-negative fixed-stack frame index.
  auto IsVectorStackAccess = [FrameIndexOf](MachineInstr *I) -> bool {
    const unsigned Opc = I->getOpcode();
    if (Opc != ARM::MVE_VSTRWU32 && Opc != ARM::MVE_VLDRWU32)
      return false;
    if (I->getOperand(1).getReg() != ARM::SP)
      return false;
    if (I->memoperands().size() != 1)
      return false;
    return FrameIndexOf(I->memoperands().front()) >= 0;
  };

  if (MI->getOpcode() != ARM::MVE_VSTRWU32 || !IsVectorStackAccess(MI))
    return false;
  if (MI->memoperands().empty())
    return false;

  const int SpillFI = FrameIndexOf(MI->memoperands().front());
  const MachineFrameInfo &MFI =
      MI->getParent()->getParent()->getFrameInfo();
  if (SpillFI == -1 || !MFI.isSpillSlotObjectIndex(SpillFI))
    return false;

  // Breadth-first walk over everything reachable from the exit blocks. The
  // worklist only ever grows, so an index is used instead of iterators.
  SmallVector<MachineBasicBlock *> Worklist;
  ML->getExitBlocks(Worklist);
  SmallPtrSet<MachineBasicBlock *, 4> Done{MI->getParent()};

  for (unsigned Pos = 0; Pos < Worklist.size(); ++Pos) {
    MachineBasicBlock *BB = Worklist[Pos];
    bool SlotOverwritten = false;

    for (auto &I : *BB) {
      if (!IsVectorStackAccess(&I) || I.memoperands().empty())
        continue;
      if (FrameIndexOf(I.memoperands().front()) != SpillFI)
        continue;
      // A store to the slot ends the search along this path.
      if (I.getOpcode() == ARM::MVE_VSTRWU32) {
        SlotOverwritten = true;
        break;
      }
      // A load from the slot means the spilled value is still needed.
      if (I.getOpcode() == ARM::MVE_VLDRWU32)
        return false;
    }

    if (!SlotOverwritten)
      for (auto Succ : BB->successors())
        if (!Done.contains(Succ) && !is_contained(Worklist, Succ))
          Worklist.push_back(Succ);

    Done.insert(BB);
  }

  return true;
}
// Examines one loop instruction, updates the shared VPTState and the VCTP /
// double-width bookkeeping, and returns whether the instruction is compatible
// with tail-predication. Once CannotTailPredicate is set every call returns
// false without looking at MI.
bool LowOverheadLoop::ValidateMVEInst(MachineInstr *MI) {
if (CannotTailPredicate)
return false;
// Instructions that are not MVE-domain and neither define nor use VPR are
// always acceptable.
if (!shouldInspect(*MI))
return true;
// VPSEL and VPNOT are rejected outright.
if (MI->getOpcode() == ARM::MVE_VPSEL ||
MI->getOpcode() == ARM::MVE_VPNOT) {
return false;
}
if (isVCTP(MI) && !AddVCTP(MI))
return false;
const MCInstrDesc &MCID = MI->getDesc();
bool IsUse = false;
// Walk the descriptor's operands from last to first, pairing each with the
// instruction operand at the mirrored position counted from the end.
unsigned LastOpIdx = MI->getNumOperands() - 1;
for (auto &Op : enumerate(reverse(MCID.operands()))) {
const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
continue;
// A VPR use through a vpred operand means MI belongs to the current VPT
// block; any other VPR use is only tolerated on a VPST.
if (ARM::isVpred(Op.value().OperandType)) {
VPTState::addInst(MI);
IsUse = true;
} else if (MI->getOpcode() != ARM::MVE_VPST) {
LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
return false;
}
}
// MVE instructions without the ValidForTailPredication flag are accepted only
// if they are an MQPRCopy, an unpredicated double-width-result instruction
// (recorded for a later size check), or predicated on VPR.
bool RequiresExplicitPredication =
(MCID.TSFlags & ARMII::ValidForTailPredication) == 0;
if (isDomainMVE(MI) && RequiresExplicitPredication) {
if (MI->getOpcode() == ARM::MQPRCopy)
return true;
if (!IsUse && producesDoubleWidthResult(*MI)) {
DoubleWidthResultInstrs.insert(MI);
return true;
}
LLVM_DEBUG(if (!IsUse) dbgs()
<< "ARM Loops: Can't tail predicate: " << *MI);
return IsUse;
}
// A store that ValidateMVEStore() does not accept is fine only when it is
// predicated.
if (MI->mayStore() && !ValidateMVEStore(MI, &ML))
return IsUse;
// Update the running predicate after the uses above have been recorded: an
// unpredicated VPR definer starts a fresh predicate, a predicated one adds to
// the existing predicate.
if (isVectorPredicate(MI)) {
if (!isVectorPredicated(MI))
VPTState::resetPredicate(MI);
else
VPTState::addPredicate(MI);
}
// A VPT/VPST opens a new block, using the predicate as just updated.
if (isVPTOpcode(MI->getOpcode()))
VPTState::CreateVPTBlock(MI);
return true;
}
/// Pass entry point. Returns false immediately when the subtarget lacks the
/// LOB feature. Otherwise caches the analyses and target hooks, computes
/// basic-block sizes and offsets, processes every outermost loop, and finally
/// calls RevertNonLoops(). Returns whether anything reported a change.
bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
  const ARMSubtarget &Subtarget = mf.getSubtarget<ARMSubtarget>();
  if (!Subtarget.hasLOB())
    return false;

  MF = &mf;
  LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");

  MLI = &getAnalysis<MachineLoopInfo>();
  RDA = &getAnalysis<ReachingDefAnalysis>();
  MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
  MRI = &MF->getRegInfo();
  TII = static_cast<const ARMBaseInstrInfo*>(Subtarget.getInstrInfo());
  TRI = Subtarget.getRegisterInfo();

  // Block offsets are needed for the branch-range checks.
  BBUtils = std::make_unique<ARMBasicBlockUtils>(*MF);
  BBUtils->computeAllBlockSizes();
  BBUtils->adjustBBOffsetsAfter(&MF->front());

  bool Changed = false;
  // ProcessLoop() recurses into inner loops itself.
  for (auto TopLevel : *MLI) {
    if (!TopLevel->isOutermost())
      continue;
    if (ProcessLoop(TopLevel))
      Changed = true;
  }

  if (RevertNonLoops())
    Changed = true;
  return Changed;
}
// Handles ML after first recursing into its sub-loops. Locates the loop
// start, decrement and end instructions, feeds every other loop instruction
// through the MVE analysis, then validates and expands the loop. Returns the
// accumulated result from the sub-loops when this loop has no preheader or is
// missing a component, and true once Expand() has been called.
bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
bool Changed = false;
// Inner loops first.
for (MachineLoop *L : *ML)
Changed |= ProcessLoop(L);
LLVM_DEBUG({
dbgs() << "ARM Loops: Processing loop containing:\n";
if (auto *Preheader = ML->getLoopPreheader())
dbgs() << " - Preheader: " << printMBBReference(*Preheader) << "\n";
else if (auto *Preheader = MLI->findLoopPreheader(ML, true, true))
dbgs() << " - Preheader: " << printMBBReference(*Preheader) << "\n";
for (auto *MBB : ML->getBlocks())
dbgs() << " - Block: " << printMBBReference(*MBB) << "\n";
});
// Looks for a loop-start instruction in MBB and, failing that, keeps climbing
// through predecessors for as long as each block has exactly one.
std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
[&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
for (auto &MI : *MBB) {
if (isLoopStart(MI))
return &MI;
}
if (MBB->pred_size() == 1)
return SearchForStart(*MBB->pred_begin());
return nullptr;
};
// Constructing LoLoop also resets the shared VPTState.
LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
if (LoLoop.Preheader)
LoLoop.Start = SearchForStart(LoLoop.Preheader);
else
return Changed;
// Scan the loop's blocks in reverse order. A loop start found inside the
// loop overrides the one found from the preheader; a call forces a revert;
// everything else goes through the MVE analysis.
for (auto *MBB : reverse(ML->getBlocks())) {
for (auto &MI : *MBB) {
if (MI.isDebugValue())
continue;
else if (MI.getOpcode() == ARM::t2LoopDec)
LoLoop.Dec = &MI;
else if (MI.getOpcode() == ARM::t2LoopEnd)
LoLoop.End = &MI;
else if (MI.getOpcode() == ARM::t2LoopEndDec)
LoLoop.End = LoLoop.Dec = &MI;
else if (isLoopStart(MI))
LoLoop.Start = &MI;
else if (MI.getDesc().isCall()) {
LoLoop.Revert = true;
LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
} else {
LoLoop.AnalyseMVEInst(&MI);
}
}
}
LLVM_DEBUG(LoLoop.dump());
if (!LoLoop.FoundAllComponents()) {
LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
return Changed;
}
assert(LoLoop.Start->getOpcode() != ARM::t2WhileLoopStart &&
"Expected t2WhileLoopStart to be removed before regalloc!");
// With a separate decrement and end, the end must be the one and only local
// user of the LR value the decrement produces; otherwise revert.
if (LoLoop.Dec != LoLoop.End) {
SmallPtrSet<MachineInstr *, 2> Uses;
RDA->getReachingLocalUses(LoLoop.Dec, MCRegister::from(ARM::LR), Uses);
if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
LoLoop.Revert = true;
}
}
LoLoop.Validate(BBUtils.get());
Expand(LoLoop);
return true;
}
/// Reverts a while-loop start through RevertWhileLoopStartLR(), choosing tBcc
/// when the branch target is within 254 bytes and t2Bcc otherwise.
void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);

  MachineBasicBlock *Target = getWhileLoopStartTargetBB(*MI);
  unsigned BranchOpcode = ARM::t2Bcc;
  if (BBUtils->isBBInRange(MI, Target, 254))
    BranchOpcode = ARM::tBcc;

  RevertWhileLoopStartLR(MI, TII, BranchOpcode);
}
/// Reverts a do-loop start by delegating to RevertDoLoopStart().
void ARMLowOverheadLoops::RevertDo(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to mov: " << *MI);

  MachineInstr *DoStart = MI;
  RevertDoLoopStart(DoStart, TII);
}
/// Reverts a loop decrement through llvm::RevertLoopDec(). The replacement is
/// allowed to set flags when RDA reports that CPSR can safely be defined at
/// \p MI, ignoring the first t2LoopEnd found from \p MI to the end of its
/// block. Returns whether flags are set.
bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);

  // Find the first loop end at or after MI in the same block, if there is one.
  SmallPtrSet<MachineInstr*, 1> Ignore;
  MachineBasicBlock::iterator It(MI);
  const MachineBasicBlock::iterator BlockEnd = MI->getParent()->end();
  while (It != BlockEnd) {
    if (It->getOpcode() == ARM::t2LoopEnd) {
      Ignore.insert(&*It);
      break;
    }
    ++It;
  }

  const bool SetFlags =
      RDA->isSafeToDefRegAt(MI, MCRegister::from(ARM::CPSR), Ignore);
  llvm::RevertLoopDec(MI, TII, SetFlags);
  return SetFlags;
}
/// Reverts a loop end through llvm::RevertLoopEnd(), choosing tBcc when the
/// target in operand 1 is within 254 bytes and t2Bcc otherwise. \p SkipCmp is
/// forwarded unchanged.
void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);

  MachineBasicBlock *Target = MI->getOperand(1).getMBB();
  unsigned BranchOpcode = ARM::t2Bcc;
  if (BBUtils->isBBInRange(MI, Target, 254))
    BranchOpcode = ARM::tBcc;

  llvm::RevertLoopEnd(MI, TII, BranchOpcode, SkipCmp);
}
/// Replace the t2LoopEndDec pseudo \p MI with a subtract that defines CPSR
/// followed by a conditional branch ("subs, br"), then erase \p MI.
///
/// Operand 1 of \p MI supplies the value being decremented and operand 2 the
/// branch destination.
void ARMLowOverheadLoops::RevertLoopEndDec(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to subs, br: " << *MI);
  assert(MI->getOpcode() == ARM::t2LoopEndDec && "Expected a t2LoopEndDec!");

  MachineBasicBlock &Block = *MI->getParent();
  const auto &DL = MI->getDebugLoc();

  // t2SUBri: LR = <operand 1> - 1, predicate AL with no predicate register,
  // and CPSR appended as the final operand.
  MachineInstrBuilder Sub = BuildMI(Block, MI, DL, TII->get(ARM::t2SUBri))
                                .addDef(ARM::LR)
                                .add(MI->getOperand(1))
                                .addImm(1)
                                .addImm(ARMCC::AL)
                                .addReg(ARM::NoRegister)
                                .addReg(ARM::CPSR);
  // The CPSR operand (index 5) was added as a use; turn it into a def.
  Sub->getOperand(5).setIsDef(true);

  // The range query is deliberately made after the subtract has been
  // inserted, matching the original statement order.
  MachineOperand &Target = MI->getOperand(2);
  const bool InRange = BBUtils->isBBInRange(MI, Target.getMBB(), 254);
  const unsigned BranchOpcode = InRange ? ARM::tBcc : ARM::t2Bcc;

  // Conditional branch: destination, NE condition code, CPSR.
  BuildMI(Block, MI, DL, TII->get(BranchOpcode))
      .add(Target)
      .addImm(ARMCC::NE)
      .addReg(ARM::CPSR);

  MI->eraseFromParent();
}
/// Attempt dead-code elimination of the loop iteration count.
///
/// Does nothing unless tail-predication is legal for \p LoLoop. Otherwise the
/// instruction that the reaching-definition analysis returns for operand 1 of
/// the loop start is handed to TryRemove(), which records removable
/// instructions in LoLoop.ToRemove.
void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
  if (!LoLoop.IsTailPredicationLegal())
    return;

  LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");

  MachineInstr *CountDef = RDA->getMIOperand(LoLoop.Start, 1);
  if (CountDef == nullptr) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
    return;
  }

  // The loop start, decrement and end are passed to TryRemove() as its final
  // set argument. Dec and End may be the same instruction; the set dedups.
  SmallPtrSet<MachineInstr *, 4> LoopInsts;
  LoopInsts.insert(LoLoop.Start);
  LoopInsts.insert(LoLoop.Dec);
  LoopInsts.insert(LoLoop.End);

  const bool Removed = TryRemove(CountDef, *RDA, LoLoop.ToRemove, LoopInsts);
  if (!Removed) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
  }
}
/// Expand the loop-start pseudo of \p LoLoop into its final instruction.
///
/// First runs IterationCountDCE(), then builds the instruction with opcode
/// LoLoop.getStartOpcode() at LoLoop.StartInsertPt in LoLoop.StartInsertBB.
/// The original start pseudo is always queued in LoLoop.ToRemove.
///
/// \returns the newly inserted instruction, or nullptr when nothing was
/// inserted because the expansion would have been "DLS lr, lr".
MachineInstr *ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
  LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
  IterationCountDCE(LoLoop);

  MachineBasicBlock::iterator InsertPt = LoLoop.StartInsertPt;
  MachineInstr *Start = LoLoop.Start;
  MachineBasicBlock *MBB = LoLoop.StartInsertBB;
  unsigned Opc = LoLoop.getStartOpcode();
  MachineOperand &Count = LoLoop.getLoopStartOperand();

  MachineInstr *NewStart = nullptr;
  if (Opc == ARM::t2DLS && Count.isReg() && Count.getReg() == ARM::LR) {
    // A t2DLS whose count operand is already LR is not emitted.
    // The message is newline-terminated so the next debug line is not glued
    // onto it (no instruction is streamed here to supply one).
    LLVM_DEBUG(dbgs() << "ARM Loops: Didn't insert start: DLS lr, lr\n");
  } else {
    MachineInstrBuilder MIB =
        BuildMI(*MBB, InsertPt, Start->getDebugLoc(), TII->get(Opc));
    MIB.addDef(ARM::LR);
    MIB.add(Count);
    // While-loop starts additionally carry their target block.
    if (isWhileLoopStart(*Start))
      MIB.addMBB(getWhileLoopStartTargetBB(*Start));

    LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
    NewStart = &*MIB;
  }

  // The pseudo itself is removed later, together with the rest of ToRemove.
  LoLoop.ToRemove.insert(Start);
  return NewStart;
}
// Rewrite the VPT blocks of LoLoop and queue its VCTPs for removal.
//
// This function strips predicates in place and inserts new VPST/VPT
// instructions, but it never erases anything itself: instructions to delete
// are added to LoLoop.ToRemove, and instructions whose mask must be rebuilt
// are added to LoLoop.BlockMasksToRecompute.
void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
// Unpredicate MI: set its VPT predicate immediate to ARMVCC::None and the
// register operand that follows it to 0. Debug instructions are left alone.
// Asserts that MI currently carries a 'Then' predicate.
auto RemovePredicate = [](MachineInstr *MI) {
if (MI->isDebugInstr())
return;
LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
assert(PIdx >= 1 && "Trying to unpredicate a non-predicated instruction");
assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
"Expected Then predicate!");
MI->getOperand(PIdx).setImm(ARMVCC::None);
MI->getOperand(PIdx + 1).setReg(0);
};
for (auto &Block : LoLoop.getVPTBlocks()) {
SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
// Insert, before At, the VPT opcode that VCMPOpcodeToVPT() maps TheVCMP's
// opcode to. The new instruction gets a 'Then' immediate followed by copies
// of TheVCMP's operands 1, 2 and 3. It is queued for mask recomputation,
// TheVCMP is queued for removal, and the caller's pointer is nulled.
auto ReplaceVCMPWithVPT = [&](MachineInstr *&TheVCMP, MachineInstr *At) {
assert(TheVCMP && "Replacing a removed or non-existent VCMP");
MachineInstrBuilder MIB =
BuildMI(*At->getParent(), At, At->getDebugLoc(),
TII->get(VCMPOpcodeToVPT(TheVCMP->getOpcode())));
MIB.addImm(ARMVCC::Then);
MIB.add(TheVCMP->getOperand(1));
MIB.add(TheVCMP->getOperand(2));
MIB.add(TheVCMP->getOperand(3));
LLVM_DEBUG(dbgs() << "ARM Loops: Combining with VCMP to VPT: " << *MIB);
LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
LoLoop.ToRemove.insert(TheVCMP);
TheVCMP = nullptr;
};
// Case 1: isEntryPredicatedOnVCTP(Block, true) holds. The first instruction
// of the block (named VPST here) is always queued for removal.
if (VPTState::isEntryPredicatedOnVCTP(Block, true)) {
MachineInstr *VPST = Insts.front();
if (VPTState::hasUniformPredicate(Block)) {
// Uniform predicate: unpredicate every instruction after the first.
LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
for (unsigned i = 1; i < Insts.size(); ++i)
RemovePredicate(Insts[i]);
} else {
// Non-uniform predicate: locate the divergent instruction and the first
// non-debug instruction after it.
MachineInstr *Divergent = VPTState::getDivergent(Block);
MachineBasicBlock *MBB = Divergent->getParent();
auto DivergentNext = ++MachineBasicBlock::iterator(Divergent);
while (DivergentNext != MBB->end() && DivergentNext->isDebugInstr())
++DivergentNext;
// Recorded now, before any predicates are stripped below.
bool DivergentNextIsPredicated =
DivergentNext != MBB->end() &&
getVPTInstrPredicate(*DivergentNext) != ARMVCC::None;
// Unpredicate everything after the VPST up to, but not including,
// DivergentNext (so the divergent instruction itself is included).
for (auto I = ++MachineBasicBlock::iterator(VPST), E = DivergentNext;
I != E; ++I)
RemovePredicate(&*I);
// Treat the divergent instruction as a VCMP only if its opcode has a VPT
// equivalent according to VCMPOpcodeToVPT().
MachineInstr *VCMP =
VCMPOpcodeToVPT(Divergent->getOpcode()) != 0 ? Divergent : nullptr;
// A new predicating instruction is only needed when the instruction after
// the divergent one is itself predicated.
if (DivergentNextIsPredicated) {
if (!VCMP) {
// No VCMP to fold: insert a fresh MVE_VPST before the divergent
// instruction with a placeholder mask of 0, recomputed later.
MachineInstrBuilder MIB =
BuildMI(*Divergent->getParent(), Divergent,
Divergent->getDebugLoc(), TII->get(ARM::MVE_VPST));
MIB.addImm(0);
LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
} else {
// Fold the VCMP into a VPT inserted at the VCMP's own position.
ReplaceVCMPWithVPT(VCMP, VCMP);
}
}
}
LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
LoLoop.ToRemove.insert(VPST);
// Case 2: the block contains a VCTP (all VCTPs are queued for removal at the
// end of this function). A block of size 2 has its first instruction removed;
// larger blocks only get their mask recomputed.
} else if (Block.containsVCTP()) {
MachineInstr *VPST = Insts.front();
if (Block.size() == 2) {
assert(VPST->getOpcode() == ARM::MVE_VPST &&
"Found a VPST in an otherwise empty vpt block");
LoLoop.ToRemove.insert(VPST);
} else
LoLoop.BlockMasksToRecompute.insert(VPST);
// Case 3: the block starts with an MVE_VPST. Try to merge it with the unique
// reaching VPR definition when that definition is a VCMP not already queued
// for removal.
} else if (Insts.front()->getOpcode() == ARM::MVE_VPST) {
MachineInstr *VPST = Insts.front();
auto Next = ++MachineBasicBlock::iterator(VPST);
assert(getVPTInstrPredicate(*Next) != ARMVCC::None &&
"The instruction after a VPST must be predicated");
// Next is only read by the assert above; silence unused warnings otherwise.
(void)Next;
MachineInstr *VprDef = RDA->getUniqueReachingMIDef(VPST, ARM::VPR);
if (VprDef && VCMPOpcodeToVPT(VprDef->getOpcode()) &&
!LoLoop.ToRemove.contains(VprDef)) {
MachineInstr *VCMP = VprDef;
// Merge only if no instruction strictly between the VCMP and the VPST uses
// VPR, and the registers in VCMP operands 1 and 2 have the same reaching
// definition at both the VCMP and the VPST.
if (std::none_of(++MachineBasicBlock::iterator(VCMP),
MachineBasicBlock::iterator(VPST), hasVPRUse) &&
RDA->hasSameReachingDef(VCMP, VPST, VCMP->getOperand(1).getReg()) &&
RDA->hasSameReachingDef(VCMP, VPST, VCMP->getOperand(2).getReg())) {
// The VPT is inserted at the VPST's position, which is then removed.
ReplaceVCMPWithVPT(VCMP, VPST);
LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *VPST);
LoLoop.ToRemove.insert(VPST);
}
}
}
}
// Finally, queue every recorded VCTP for removal.
LoLoop.ToRemove.insert(LoLoop.VCTPs.begin(), LoLoop.VCTPs.end());
}
// Finalise one analysed loop: either revert its pseudo instructions
// (LoLoop.Revert set) or expand them into the final loop start / loop end
// instructions. In both cases live-ins and liveness flags are then recomputed
// for the blocks visited by PostOrderLoopTraversal, and the
// reaching-definition analysis is reset.
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
// Build the loop-end instruction before LoLoop.End: MVE_LETP when
// tail-predication is legal, t2LEUpdate otherwise. The old Dec and End are
// queued for removal and the new instruction is returned.
auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
MachineInstr *End = LoLoop.End;
MachineBasicBlock *MBB = End->getParent();
unsigned Opc = LoLoop.IsTailPredicationLegal() ?
ARM::MVE_LETP : ARM::t2LEUpdate;
MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
TII->get(Opc));
MIB.addDef(ARM::LR);
// When Dec and End are the same instruction, the two operands to copy start
// one position later in End's operand list.
unsigned Off = LoLoop.Dec == LoLoop.End ? 1 : 0;
MIB.add(End->getOperand(Off + 0));
MIB.add(End->getOperand(Off + 1));
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
LoLoop.ToRemove.insert(LoLoop.Dec);
LoLoop.ToRemove.insert(End);
return &*MIB;
};
// If the last instruction of I's block is an unconditional branch other than
// I itself, and its target is the block's layout successor, erase it.
auto RemoveDeadBranch = [](MachineInstr *I) {
MachineBasicBlock *BB = I->getParent();
MachineInstr *Terminator = &BB->instr_back();
if (Terminator->isUnconditionalBranch() && I != Terminator) {
MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
if (BB->isLayoutSuccessor(Succ)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
Terminator->eraseFromParent();
}
}
};
// Replace each MQPRCopy in VMOVCopies with two VMOVD instructions and erase
// the copy. The D-register numbers are derived from the Q-register index as
// D0 + 2*(Q - Q0) and that value plus one -- presumably the two D halves of
// the Q register (NOTE(review): confirm against the register definitions).
auto ExpandVMOVCopies = [this](SmallPtrSet<MachineInstr *, 4> &VMOVCopies) {
for (auto *MI : VMOVCopies) {
LLVM_DEBUG(dbgs() << "Converting copy to VMOVD: " << *MI);
assert(MI->getOpcode() == ARM::MQPRCopy && "Only expected MQPRCOPY!");
MachineBasicBlock *MBB = MI->getParent();
Register Dst = MI->getOperand(0).getReg();
Register Src = MI->getOperand(1).getReg();
auto MIB1 = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::VMOVD),
ARM::D0 + (Dst - ARM::Q0) * 2)
.addReg(ARM::D0 + (Src - ARM::Q0) * 2)
.add(predOps(ARMCC::AL));
// MIB1/MIB2 are only read by the debug output; the void casts silence
// unused-variable warnings when LLVM_DEBUG compiles to nothing.
(void)MIB1;
LLVM_DEBUG(dbgs() << " into " << *MIB1);
auto MIB2 = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::VMOVD),
ARM::D0 + (Dst - ARM::Q0) * 2 + 1)
.addReg(ARM::D0 + (Src - ARM::Q0) * 2 + 1)
.add(predOps(ARMCC::AL));
LLVM_DEBUG(dbgs() << " and " << *MIB2);
(void)MIB2;
MI->eraseFromParent();
}
};
if (LoLoop.Revert) {
// Revert path: undo the start pseudo, then the end. For a separate Dec/End
// pair, the result of RevertLoopDec() is passed on as RevertLoopEnd()'s
// SkipCmp argument.
if (isWhileLoopStart(*LoLoop.Start))
RevertWhile(LoLoop.Start);
else
RevertDo(LoLoop.Start);
if (LoLoop.Dec == LoLoop.End)
RevertLoopEndDec(LoLoop.End);
else
RevertLoopEnd(LoLoop.End, RevertLoopDec(LoLoop.Dec));
} else {
// Expand path. LoLoop.Start and LoLoop.End are overwritten with the newly
// built instructions; Start may become null when no start was emitted.
ExpandVMOVCopies(LoLoop.VMOVCopies);
LoLoop.Start = ExpandLoopStart(LoLoop);
if (LoLoop.Start)
RemoveDeadBranch(LoLoop.Start);
LoLoop.End = ExpandLoopEnd(LoLoop);
RemoveDeadBranch(LoLoop.End);
if (LoLoop.IsTailPredicationLegal())
ConvertVPTBlocks(LoLoop);
// Erase everything queued by the steps above...
for (auto *I : LoLoop.ToRemove) {
LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
I->eraseFromParent();
}
// ...and only then rebuild the masks of the queued VPT/VPST instructions.
for (auto *I : LoLoop.BlockMasksToRecompute) {
LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
recomputeVPTBlockMask(*I);
LLVM_DEBUG(dbgs() << " ... done: " << *I);
}
}
// Instructions were inserted and erased, so refresh liveness information for
// the blocks produced by the traversal: live-ins in traversal order (sorted
// and de-duplicated per block), then liveness flags in the reverse order.
PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
DFS.ProcessLoop();
const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
for (auto *MBB : PostOrder) {
recomputeLiveIns(*MBB);
MBB->sortUniqueLiveIns();
}
for (auto *MBB : reverse(PostOrder))
recomputeLivenessFlags(*MBB);
// Reset the reaching-definition analysis now that the code has changed.
RDA->reset();
}
bool ARMLowOverheadLoops::RevertNonLoops() {
LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
bool Changed = false;
for (auto &MBB : *MF) {
SmallVector<MachineInstr*, 4> Starts;
SmallVector<MachineInstr*, 4> Decs;
SmallVector<MachineInstr*, 4> Ends;
SmallVector<MachineInstr *, 4> EndDecs;
SmallVector<MachineInstr *, 4> MQPRCopies;
for (auto &I : MBB) {
if (isLoopStart(I))
Starts.push_back(&I);
else if (I.getOpcode() == ARM::t2LoopDec)
Decs.push_back(&I);
else if (I.getOpcode() == ARM::t2LoopEnd)
Ends.push_back(&I);
else if (I.getOpcode() == ARM::t2LoopEndDec)
EndDecs.push_back(&I);
else if (I.getOpcode() == ARM::MQPRCopy)
MQPRCopies.push_back(&I);
}
if (Starts.empty() && Decs.empty() && Ends.empty() && EndDecs.empty() &&
MQPRCopies.empty())
continue;
Changed = true;
for (auto *Start : Starts) {
if (isWhileLoopStart(*Start))
RevertWhile(Start);
else
RevertDo(Start);
}
for (auto *Dec : Decs)
RevertLoopDec(Dec);
for (auto *End : Ends)
RevertLoopEnd(End);
for (auto *End : EndDecs)
RevertLoopEndDec(End);
for (auto *MI : MQPRCopies) {
LLVM_DEBUG(dbgs() << "Converting copy to VORR: " << *MI);
assert(MI->getOpcode() == ARM::MQPRCopy && "Only expected MQPRCOPY!");
MachineBasicBlock *MBB = MI->getParent();
auto MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::MVE_VORR),
MI->getOperand(0).getReg())
.add(MI->getOperand(1))
.add(MI->getOperand(1));
addUnpredicatedMveVpredROp(MIB, MI->getOperand(0).getReg());
MI->eraseFromParent();
}
}
return Changed;
}
/// Factory for this pass: allocates a new ARMLowOverheadLoops instance and
/// returns it through a FunctionPass pointer.
FunctionPass *llvm::createARMLowOverheadLoopsPass() {
  FunctionPass *Pass = new ARMLowOverheadLoops();
  return Pass;
}