#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "aarch64-postlegalizer-combiner"
using namespace llvm;
using namespace MIPatternMatch;
/// Recognise a lane-0 extract of a pairwise floating-point add:
///
///   (extract (G_FADD Vec, (G_SHUFFLE_VECTOR Vec, _, <1, ...>)), 0)
///
/// The two G_FADD operands may appear in either order. On success \p MatchInfo
/// receives (G_FADD, the type of MI's result, the register defined by the
/// instruction that produces Vec).
///
/// \p MI is read as: operand 0 = result, operand 1 = source vector, operand 2 =
/// lane index. Its opcode is not checked here.
/// \returns true when the pattern was found and \p MatchInfo was filled in.
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  const LLT ResultTy = MRI.getType(MI.getOperand(0).getReg());
  const Register VecReg = MI.getOperand(1).getReg();
  const Register IdxReg = MI.getOperand(2).getReg();

  // The extracted lane has to be the constant 0.
  auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
  if (!IdxCst || IdxCst->Value != 0)
    return false;

  // The vector being indexed must be produced by a G_FADD.
  auto *AddMI = getOpcodeDef(TargetOpcode::G_FADD, VecReg, MRI);
  if (!AddMI)
    return false;

  // Only 16-, 32- and 64-bit results are accepted.
  const unsigned ResultBits = ResultTy.getSizeInBits();
  const bool SupportedWidth =
      ResultBits == 16 || ResultBits == 32 || ResultBits == 64;
  if (!SupportedWidth)
    return false;

  // Work out which G_FADD operand is the shuffle (the second operand is tried
  // first); the remaining operand is the candidate unshuffled vector.
  const Register AddLHS = AddMI->getOperand(1).getReg();
  const Register AddRHS = AddMI->getOperand(2).getReg();
  Register PlainReg = AddLHS;
  MachineInstr *ShuffleMI =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, AddRHS, MRI);
  if (!ShuffleMI) {
    ShuffleMI = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, AddLHS, MRI);
    PlainReg = AddRHS;
  }
  if (!ShuffleMI)
    return false;

  // The shuffle must move element 1 into lane 0 ...
  if (ShuffleMI->getOperand(3).getShuffleMask()[0] != 1)
    return false;

  // ... and its first input must be defined by the very same instruction that
  // defines the other G_FADD operand.
  MachineInstr *PlainMI = MRI.getVRegDef(PlainReg);
  if (PlainMI != MRI.getVRegDef(ShuffleMI->getOperand(1).getReg()))
    return false;

  std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
  std::get<1>(MatchInfo) = ResultTy;
  std::get<2>(MatchInfo) = PlainMI->getOperand(0).getReg();
  return true;
}
/// Replace \p MI with the scalar form of a pairwise add described by
/// \p MatchInfo (opcode, element type, source vector register):
///
///   Dst = Opc (extract Src, 0), (extract Src, 1)
///
/// New instructions are inserted at \p MI, which is then erased.
/// \returns true always.
bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  const LLT EltTy = std::get<1>(MatchInfo);
  const Register VecReg = std::get<2>(MatchInfo);
  const LLT IdxTy = LLT::scalar(64);

  B.setInstrAndDebugLoc(MI);

  // Pull out lanes 0 and 1 (each index constant is emitted right before the
  // extract that uses it), then combine them with the matched opcode.
  auto Idx0 = B.buildConstant(IdxTy, 0);
  auto Lane0 = B.buildExtractVectorElement(EltTy, VecReg, Idx0);
  auto Idx1 = B.buildConstant(IdxTy, 1);
  auto Lane1 = B.buildExtractVectorElement(EltTy, VecReg, Idx1);
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Lane0, Lane1});

  MI.eraseFromParent();
  return true;
}
/// \returns true if the instruction defining \p R is a G_SEXT or a
/// G_SEXT_INREG. No look-through of copies is performed.
static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  switch (MRI.getVRegDef(R)->getOpcode()) {
  case TargetOpcode::G_SEXT:
  case TargetOpcode::G_SEXT_INREG:
    return true;
  default:
    return false;
  }
}
/// \returns true if the instruction defining \p R is a G_ZEXT. No look-through
/// of copies is performed.
static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  const auto *DefMI = MRI.getVRegDef(R);
  return DefMI->getOpcode() == TargetOpcode::G_ZEXT;
}
/// Match a G_MUL by a constant that can be rewritten as shift + add/sub,
/// optionally followed by one more shift (or a negation).
///
/// Recognised multipliers C (sign-extended to the width of the LHS type):
///   C =  2^N + 1          -> (add (shl x, N), x)
///   C =  2^N - 1          -> (sub (shl x, N), x)
///   C = (2^N + 1) * 2^M   -> (shl (add (shl x, N), x), M)
///   C = -(2^N - 1)        -> (sub x, (shl x, N))
///   C = -(2^N + 1)        -> (sub 0, (add (shl x, N), x))
///
/// \param MI      the G_MUL (asserted).
/// \param MRI     used for type, def and use queries.
/// \param ApplyFn on success, set to a callback that emits the replacement
///                sequence into the register passed as DstReg.
/// \returns true if a rewrite was found and \p ApplyFn was set.
bool matchAArch64MulConstCombine(
MachineInstr &MI, MachineRegisterInfo &MRI,
std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
assert(MI.getOpcode() == TargetOpcode::G_MUL);
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
Register Dst = MI.getOperand(0).getReg();
const LLT Ty = MRI.getType(LHS);
// Everything below needs the RHS to be a known integer constant.
auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
if (!Const)
return false;
APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
// A non-zero count of trailing zero bits means the constant has a 2^M factor,
// i.e. the rewrite would need a trailing shift (shift + add + shift).
// NOTE(review): for ConstValue == 0 this count equals the full bit width and
// the code below appears to select the G_SUB path with ShiftAmt == 0 followed
// by a shift by the bit width -- presumably mul-by-zero is folded before this
// combine runs; confirm.
unsigned TrailingZeroes = ConstValue.countTrailingZeros();
if (TrailingZeroes) {
// Bail out when the LHS is a single-use sign/zero extension -- presumably so
// the multiply can still be selected as a widening multiply; confirm.
if (MRI.hasOneNonDBGUse(LHS) &&
(isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
return false;
// Bail out when the only user of the product is an add/ptr-add/sub --
// presumably so the pair can still become a multiply-add/sub; confirm.
if (MRI.hasOneNonDBGUse(Dst)) {
MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
unsigned UseOpc = UseMI.getOpcode();
if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
UseOpc == TargetOpcode::G_SUB)
return false;
}
}
// The constant with its 2^M factor stripped; equals ConstValue when there are
// no trailing zeroes.
APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
// ShiftAmt and AddSubOpc are assigned on every path that does not return
// false below.
unsigned ShiftAmt, AddSubOpc;
// Whether the shifted value is the left operand of the add/sub.
bool ShiftValUseIsLHS = true;
// Whether the add/sub result must be negated to obtain the product.
bool NegateResult = false;
if (ConstValue.isNonNegative()) {
// (mul x, 2^N + 1)         => (add (shl x, N), x)
// (mul x, 2^N - 1)         => (sub (shl x, N), x)
// (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
APInt SCVMinus1 = ShiftedConstValue - 1;
APInt CVPlus1 = ConstValue + 1;
if (SCVMinus1.isPowerOf2()) {
ShiftAmt = SCVMinus1.logBase2();
AddSubOpc = TargetOpcode::G_ADD;
} else if (CVPlus1.isPowerOf2()) {
ShiftAmt = CVPlus1.logBase2();
AddSubOpc = TargetOpcode::G_SUB;
} else
return false;
} else {
// (mul x, -(2^N - 1)) => (sub x, (shl x, N))
// (mul x, -(2^N + 1)) => (sub 0, (add (shl x, N), x))
APInt CVNegPlus1 = -ConstValue + 1;
APInt CVNegMinus1 = -ConstValue - 1;
if (CVNegPlus1.isPowerOf2()) {
ShiftAmt = CVNegPlus1.logBase2();
AddSubOpc = TargetOpcode::G_SUB;
ShiftValUseIsLHS = false;
} else if (CVNegMinus1.isPowerOf2()) {
ShiftAmt = CVNegMinus1.logBase2();
AddSubOpc = TargetOpcode::G_ADD;
NegateResult = true;
} else
return false;
}
// Negation combined with a trailing shift is not supported.
if (NegateResult && TrailingZeroes)
return false;
// All decisions are captured by value so the callback stays valid after this
// function returns.
ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
assert(!(NegateResult && TrailingZeroes) &&
"NegateResult and TrailingZeroes cannot both be true for now.");
// Negate the result: DstReg = 0 - Res.
if (NegateResult) {
B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
return;
}
// Re-apply the 2^M factor: DstReg = Res << TrailingZeroes.
if (TrailingZeroes) {
B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
return;
}
// Otherwise the add/sub result is the product; copy it into DstReg.
B.buildCopy(DstReg, Res.getReg(0));
};
return true;
}
/// Emit the replacement sequence produced by matchAArch64MulConstCombine.
///
/// Positions the builder at \p MI, invokes \p ApplyFn with MI's operand-0
/// register as the destination, then erases \p MI.
/// \returns true always.
bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  const Register DstReg = MI.getOperand(0).getReg();
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, DstReg);
  MI.eraseFromParent();
  return true;
}
/// \returns true if \p MI (a GMerge) has exactly two sources, the first of type
/// s32, and the second is the integer constant 0 -- i.e. the merge is
/// equivalent to zero-extending its first source.
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  return MRI.getType(Merge.getSourceReg(0)) == LLT::scalar(32) &&
         Merge.getNumSources() == 2 &&
         mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}
/// Turn \p MI into a G_ZEXT in place: its descriptor is swapped for G_ZEXT's
/// and operand 2 is removed. \p Observer is notified before and after the
/// mutation.
void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  const auto &ZExtDesc = B.getTII().get(TargetOpcode::G_ZEXT);

  Observer.changingInstr(MI);
  MI.setDesc(ZExtDesc);
  MI.removeOperand(2);
  Observer.changedInstr(MI);
}
/// \returns true if the G_ANYEXT \p MI (asserted) has a scalar destination and
/// its source is produced by a G_ICMP or a G_FCMP, making it a candidate for
/// mutation into a G_ZEXT.
static bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);

  // Only scalar results qualify.
  if (!MRI.getType(MI.getOperand(0).getReg()).isScalar())
    return false;

  // The extended value must come from an integer or floating-point compare.
  const Register Src = MI.getOperand(1).getReg();
  return mi_match(Src, MRI, m_GICmp(m_Pred(), m_Reg(), m_Reg())) ||
         mi_match(Src, MRI, m_GFCmp(m_Pred(), m_Reg(), m_Reg()));
}
/// Mutate \p MI into a G_ZEXT in place by swapping its descriptor; operands are
/// left untouched. \p Observer is notified before and after the change.
static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                                    MachineIRBuilder &B,
                                    GISelChangeObserver &Observer) {
  const auto &ZExtDesc = B.getTII().get(TargetOpcode::G_ZEXT);

  Observer.changingInstr(MI);
  MI.setDesc(ZExtDesc);
  Observer.changedInstr(MI);
}
/// \returns true if \p MI (a GStore) is a simple, non-truncating store of a
/// 128-bit vector whose value register has a single non-debug use and is a
/// constant (or constant splat) equal to zero.
static bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;

  const Register ValReg = Store.getValueReg();
  const LLT ValTy = MRI.getType(ValReg);

  // Only 128-bit vector values are of interest.
  const bool Is128BitVector = ValTy.isVector() && ValTy.getSizeInBits() == 128;
  if (!Is128BitVector)
    return false;

  // The memory size must equal the value size (no truncating stores).
  if (ValTy.getSizeInBits() != Store.getMemSizeInBits())
    return false;

  // The stored value must have no other non-debug user.
  if (!MRI.hasOneNonDBGUse(ValReg))
    return false;

  auto MaybeCst =
      isConstantOrConstantSplatVector(*MRI.getVRegDef(ValReg), MRI);
  return MaybeCst && MaybeCst->isZero();
}
/// Replace the vector store \p MI with two s64 stores of zero: one at the
/// original pointer and one at pointer + 8. Each new store gets a memory
/// operand derived from the original one at offset 0 and 8 respectively. The
/// original store is erased.
static void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   MachineIRBuilder &B,
                                   GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");

  const LLT HalfTy = LLT::scalar(64);
  const Register LowPtr = Store.getPointerReg();

  // Emission order: the zero value, the byte offset, then the high address.
  auto Zero = B.buildConstant(HalfTy, 0);
  auto HighOffset = B.buildConstant(LLT::scalar(64), 8);
  auto HighPtr = B.buildPtrAdd(MRI.getType(LowPtr), LowPtr, HighOffset);

  // Derive one memory operand per half from the original store's operand.
  auto &MF = *MI.getMF();
  auto &OrigMMO = Store.getMMO();
  auto *LowMMO = MF.getMachineMemOperand(&OrigMMO, 0, HalfTy);
  auto *HighMMO = MF.getMachineMemOperand(&OrigMMO, 8, HalfTy);

  B.buildStore(Zero, LowPtr, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
/// CombinerInfo for the AArch64 post-legalizer combiner. Holds the analyses
/// handed to CombinerHelper and the rule configuration of the generated
/// combiner.
class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  /// Rule enable/disable state, populated from the command line in the
  /// constructor.
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  /// \p KB and \p MDT are stored as given; \p MDT may be null (the pass passes
  /// nullptr when constructed with IsOptNone).
  ///
  /// The first three base-class arguments are fixed to (true, false, nullptr);
  /// NOTE(review): these presumably map to AllowIllegalOps,
  /// ShouldLegalizeIllegal and LegalizerInfo -- confirm against CombinerInfo.h.
  ///
  /// Aborts via report_fatal_error if the rule options fail to parse.
  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(true, false, nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    const bool ParsedOk = GeneratedRuleCfg.parseCommandLineOption();
    if (!ParsedOk)
      report_fatal_error("Invalid rule identifier");
  }

  /// Try every enabled generated rule on \p MI.
  bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
               MachineIRBuilder &B) const override;
};
/// Build a CombinerHelper (with the stored KnownBits / dominator tree and the
/// subtarget's LegalizerInfo) plus the generated helper, and let the generated
/// helper try all of its rules on \p MI.
/// \returns whatever the generated helper's tryCombineAll returns.
bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LegalInfo = MI.getMF()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LegalInfo);
  AArch64GenPostLegalizerCombinerHelper RuleSet(GeneratedRuleCfg);
  const bool Changed = RuleSet.tryCombineAll(Observer, MI, B, Helper);
  return Changed;
}
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
/// MachineFunctionPass that runs the AArch64 post-legalizer combiner over a
/// function.
class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  /// Pass identification.
  static char ID;

  /// \p IsOptNone controls which analyses the pass requests: when true, the
  /// dominator tree and CSE analyses are not required.
  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  /// Value passed to the constructor.
  bool IsOptNone;
};
}
/// Declare the analyses this pass uses. TargetPassConfig and
/// GISelKnownBitsAnalysis are always required; MachineDominatorTree and
/// GISelCSEAnalysisWrapperPass are required (and preserved) only when the pass
/// was not created with IsOptNone. The CFG is preserved.
void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  // Unconditional requirements.
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();

  // Requested only when the pass was not created with IsOptNone.
  const bool WantsOptAnalyses = !IsOptNone;
  if (WantsOptAnalyses) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }

  MachineFunctionPass::getAnalysisUsage(AU);
}
/// Record \p IsOptNone and make sure the pass is registered with the global
/// PassRegistry.
AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  auto &Registry = *PassRegistry::getPassRegistry();
  initializeAArch64PostLegalizerCombinerPass(Registry);
}
/// Run the post-legalizer combiner over \p MF.
///
/// Does nothing for functions on which instruction selection has already
/// failed. Otherwise the function must be legalized (asserted).
///
/// The dominator tree and the CSE info are fetched only when the pass was not
/// created with IsOptNone, matching what getAnalysisUsage() requests: asking
/// for an analysis that was never declared as required is invalid. When they
/// are unavailable, null is passed on instead.
///
/// \returns the result of Combiner::combineMachineInstrs, or false for
/// functions that failed instruction selection.
bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");

  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);

  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);

  // GISelCSEAnalysisWrapperPass is only declared as required when !IsOptNone,
  // so it must not be queried otherwise; the combiner is then run without CSE.
  auto *CSEInfo = IsOptNone
                      ? nullptr
                      : &getAnalysis<GISelCSEAnalysisWrapperPass>()
                             .getCSEWrapper()
                             .get(TPC->getCSEConfig());

  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}
// Storage for the pass ID declared in the class.
char AArch64PostLegalizerCombiner::ID = 0;
// Pass registration under the DEBUG_TYPE name, with TargetPassConfig and
// GISelKnownBitsAnalysis declared as initialization dependencies.
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
"Combine AArch64 MachineInstrs after legalization", false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
"Combine AArch64 MachineInstrs after legalization", false,
false)
namespace llvm {

/// Factory for the AArch64 post-legalizer combiner pass. Returns a newly
/// allocated pass constructed with \p IsOptNone.
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}

} // end namespace llvm