#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
using namespace llvm;
#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
// Command-line knob to globally enable/disable the generation of MVE masked
// gathers and scatters (on by default; hidden from standard -help output).
cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));
namespace {
class MVEGatherScatterLowering : public FunctionPass {
public:
static char ID;
explicit MVEGatherScatterLowering() : FunctionPass(ID) {
initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override;
StringRef getPassName() const override {
return "MVE gather/scatter lowering";
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
AU.addRequired<LoopInfoWrapperPass>();
FunctionPass::getAnalysisUsage(AU);
}
private:
LoopInfo *LI = nullptr;
const DataLayout *DL;
bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
Align Alignment);
void lookThroughBitcast(Value *&Ptr);
Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
FixedVectorType *Ty, Type *MemoryTy,
IRBuilder<> &Builder);
Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
GetElementPtrInst *GEP, IRBuilder<> &Builder);
int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
Optional<int64_t> getIfConst(const Value *V);
std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
Instruction *lowerGather(IntrinsicInst *I);
Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
Instruction *&Root,
IRBuilder<> &Builder);
Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
IRBuilder<> &Builder,
int64_t Increment = 0);
Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
IRBuilder<> &Builder,
int64_t Increment = 0);
Instruction *lowerScatter(IntrinsicInst *I);
Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
IRBuilder<> &Builder);
Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
IRBuilder<> &Builder,
int64_t Increment = 0);
Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
IRBuilder<> &Builder,
int64_t Increment = 0);
Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
IRBuilder<> &Builder);
Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
Value *Ptr, unsigned TypeScale,
IRBuilder<> &Builder);
bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
IRBuilder<> &Builder);
bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
Value *OffsSecondOperand, unsigned LoopIncrement,
IRBuilder<> &Builder);
};
}
// Pass ID; its address identifies the pass to LLVM's pass infrastructure.
char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

// Factory used by the ARM target machine to add this pass to the pipeline.
Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}
// A legal MVE gather/scatter operates on a full 128-bit vector (v4 x 8/16/32,
// v8 x 8/16, or v16 x 8 bits) aligned at least to the element size.
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  bool ShapeOk = false;
  switch (NumElements) {
  case 4:
    ShapeOk = ElemSize == 32 || ElemSize == 16 || ElemSize == 8;
    break;
  case 8:
    ShapeOk = ElemSize == 16 || ElemSize == 8;
    break;
  case 16:
    ShapeOk = ElemSize == 8;
    break;
  default:
    break;
  }
  if (ShapeOk && Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}
// Check that the offsets of a gather/scatter are safe for the target element
// width (128 / TargetElemCount bits): either they are already 32-bit offsets
// matching a 4-element target, or they are compile-time constants that are
// non-negative and representable in the target element width.
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    // Non-matching offset widths can only be proven safe when they are
    // constants within the representable range.
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      // Keep the full 64-bit value: truncating to int would wrap a large
      // offset (e.g. 0x100000000) to a small in-range value and wrongly
      // accept it.
      int64_t SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      // Check every lane of a constant vector of offsets.
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}
// Decompose Ptr into a (base pointer, vector of offsets, scale) triple for
// the MVE offset-form intrinsics. A GEP is decomposed directly; otherwise
// the vector of pointers itself is reinterpreted as 32-bit offsets from a
// zero base. Returns the base, or nullptr if neither form works.
Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      // The offsets are in GEP-element units; derive the scale that converts
      // them to the memory element size (-1 if unsupported).
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  // Fallback: use a base of 0 and the pointer vector itself as unscaled
  // offsets. Only possible for 4-element vectors, and 32-bit memory types
  // are excluded here — NOTE(review): presumably those are better served by
  // the base form; confirm against the callers.
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0; // Unscaled: the pointers are already byte addresses.
  return BasePtr;
}
// Deduce a scalar base pointer and a vector of offsets from a GEP, returning
// the base (and setting Offsets) or nullptr if the GEP is not of a form the
// MVE offset-addressing intrinsics can use.
Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // Need a scalar base and a single vector index.
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  // (Offsets was already set to operand 1 above; the previous duplicate
  // assignment here was redundant.)
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  assert(Ty->getNumElements() == OffsetsElemCount &&
         "Offset and result vectors must have matching element counts");

  // Look through a zext of the offsets so narrower induction variables can
  // still be used.
  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already zext-ed to i32, overflow cannot occur;
  // otherwise the offset values must be proven to fit.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; fix up any width mismatch with a
  // trunc or zext now.
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}
// If Ptr is a bitcast between vectors with the same number of lanes, replace
// it with the bitcast's operand: the addresses used are unchanged.
void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  auto *BitCast = dyn_cast<BitCastInst>(Ptr);
  if (!BitCast)
    return;
  auto *DstVT = cast<FixedVectorType>(BitCast->getType());
  auto *SrcVT = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
  if (DstVT->getNumElements() != SrcVT->getNumElements())
    return;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                    << "bitcast\n");
  Ptr = BitCast->getOperand(0);
}
// Map (GEP element size, memory element size) to the intrinsic's offset
// scale (log2 of the element bytes), or -1 for unsupported combinations.
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // i8 GEP elements are byte offsets and never need scaling; wider elements
  // are only supported when GEP and memory element sizes agree.
  if (GEPElemSize == 8)
    return 0;
  if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}
// If V is a splat constant, or derived from such constants via add/or/mul/
// shl, return its numeric value; otherwise return an empty Optional.
Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  if (const Constant *C = dyn_cast<Constant>(V))
    if (C->getSplatValue())
      return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};

  const auto *Inst = dyn_cast<Instruction>(V);
  if (!Inst)
    return Optional<int64_t>{};

  switch (Inst->getOpcode()) {
  case Instruction::Add:
  case Instruction::Or:
  case Instruction::Mul:
  case Instruction::Shl: {
    // Recursively fold both operands, then combine them.
    Optional<int64_t> LHS = getIfConst(Inst->getOperand(0));
    Optional<int64_t> RHS = getIfConst(Inst->getOperand(1));
    if (!LHS || !RHS)
      return Optional<int64_t>{};
    switch (Inst->getOpcode()) {
    case Instruction::Add:
      return Optional<int64_t>{LHS.value() + RHS.value()};
    case Instruction::Mul:
      return Optional<int64_t>{LHS.value() * RHS.value()};
    case Instruction::Shl:
      return Optional<int64_t>{LHS.value() << RHS.value()};
    default: // Instruction::Or
      return Optional<int64_t>{LHS.value() | RHS.value()};
    }
  }
  default:
    return Optional<int64_t>{};
  }
}
// Returns true if I is an `or` whose operands have no common set bits, in
// which case the `or` computes the same value as an `add`.
static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  return I->getOpcode() == Instruction::Or &&
         haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
}
// If Inst is an add (or an add-like or) with one foldable-constant summand,
// return {other summand, constant << TypeScale}, provided the scaled
// immediate is a multiple of 4 within [-512, 512] (the range accepted by the
// incrementing gather/scatter forms). Returns {nullptr, 0} on failure.
std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // The instruction must be an add or an or that behaves like one.
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr ||
      (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
    return ReturnFalse;

  // Find which operand folds to a constant; the other is the variable part.
  Value *Summand;
  Optional<int64_t> Const;
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Scale the constant and check it is a legal increment immediate.
  int64_t Immediate = *Const << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}
// Lower a call to @llvm.masked.gather to an MVE gather intrinsic, trying (in
// order) the incrementing form, the base+offsets form, and the vector-of-
// pointers base form. Returns the replacement instruction, or nullptr if the
// gather should be expanded generically instead.
Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  // Root may be redirected by tryCreateMaskedGatherOffset if a following
  // sext/zext is folded into the gather; in that case the extend (not I) is
  // what gets replaced below.
  Instruction *Root = I;
  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    // The MVE intrinsics have no passthru operand, so blend the passthru
    // lanes back in with a select on the original mask.
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // An extend was folded in: the gather itself (whose only user was the
    // just-erased extend) must be removed as well.
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}
// Build a vldr.gather.base (optionally predicated) from a vector of
// pointers; only a v4i32 result can be loaded this way.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  Value *Inc = Builder.getInt32(Increment);
  // An all-ones mask selects the unpredicated form.
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()}, {Ptr, Inc, Mask});
  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                 {Ty, Ptr->getType()}, {Ptr, Inc});
}
// Build a writeback vldr.gather.base (optionally predicated): the intrinsic
// returns both the gathered values and the incremented pointer vector.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  Value *Inc = Builder.getInt32(Increment);
  // An all-ones mask selects the unpredicated form.
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()}, {Ptr, Inc, Mask});
  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                 {Ty, Ptr->getType()}, {Ptr, Inc});
}
// Lower a masked gather to a vldr.gather.offset with a scalar base and a
// vector of offsets. If the sub-128-bit result feeds a single sext/zext to
// 128 bits, the extend is folded into the gather and Root is redirected to
// the extend so the caller replaces that instead. A small result with no
// extend is gathered at 128 bits and truncated back afterwards.
Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather's only user is a 128-bit sext/zext, fold the extend
      // into the gather (sext selects the signed variant).
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        // Print the extend instruction, as the SExt branch does. (This
        // previously printed ResultTy, which at this point still held the
        // pre-extend type.)
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    // No extend found: widen the gather to 128 bits and truncate afterwards.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
                        << *ResultTy << "\n");
    }

    // The final result type must be a full 128-bit vector.
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Root = Extend;
  Value *Mask = I->getArgOperand(2);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    // Narrow the widened gather back to the original type.
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}
// Lower a call to @llvm.masked.scatter to an MVE scatter intrinsic, trying
// the incrementing, offset and base forms in turn. Returns the replacement
// store, or nullptr to fall back to generic expansion.
Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(Data, Ptrs, alignment, Mask)
  Value *StoredVal = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *VecTy = cast<FixedVectorType>(StoredVal->getType());

  if (!isLegalTypeAndAlignment(VecTy->getNumElements(),
                               VecTy->getScalarSizeInBits(), Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  // Try the lowering strategies from most to least specific.
  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}
// Build a vstr.scatter.base (optionally predicated) to a vector of
// pointers; only a v4i32 payload can be stored this way.
Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  Value *Inc = Builder.getInt32(Increment);
  // An all-ones mask selects the unpredicated form.
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Inc, Input, Mask});
  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                 {Ptr->getType(), Input->getType()},
                                 {Ptr, Inc, Input});
}
// Build a writeback vstr.scatter.base (optionally predicated): the
// intrinsic returns the incremented pointer vector.
Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  Value *Inc = Builder.getInt32(Increment);
  // An all-ones mask selects the unpredicated form.
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Inc, Input, Mask});
  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                 {Ptr->getType(), Input->getType()},
                                 {Ptr, Inc, Input});
}
// Lower a masked scatter to a vstr.scatter.offset with a scalar base and a
// vector of offsets. A truncate feeding the scatter is folded into the
// (narrowing) store; sub-128-bit integer inputs are zero-extended to a full
// vector, with the intrinsic told the original memory element size.
Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the stored value is a trunc from 128 bits, incorporate the trunc into
  // the scatter and store the pre-truncated value instead.
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    // No trunc to fold: create an implicit zext below so the intrinsic sees
    // a full 128-bit vector; MemoryTy keeps the original element size.
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  // An all-ones mask selects the unpredicated form.
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}
// Try to lower a gather/scatter inside a loop, whose offsets are an
// induction variable plus a constant, to the incrementing base form where
// the constant becomes the intrinsic's immediate increment. Returns nullptr
// so the caller can try the plain offset/base forms instead.
Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  // For a gather the vector type is the result; for a scatter it is the
  // stored operand.
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  // Incrementing gathers only exist for v4i32.
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  // Incrementing gathers are not beneficial outside of a loop.
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  // Decompose the GEP into a base and vector of offsets.
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The GEP was in charge of scaling the offsets; compute that factor so it
  // can be applied by hand below.
  int TypeScale =
      computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DL->getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only build a writeback gather if the GEP has a single use: the
    // writeback form takes over the offset phi entirely.
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  // Split the offsets into variable + constant; the constant becomes the
  // intrinsic's immediate.
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Scale the offsets to byte units.
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the (splatted, int-cast) base pointer to form absolute addresses.
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}
// Build a writeback incrementing gather/scatter: the loop's offset phi is
// rewritten so the intrinsic itself produces the next iteration's offsets,
// replacing the separate add inside the loop.
Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  Loop *L = LI->getLoopFor(I->getParent());
  // The offsets must be a loop-header phi with exactly two incoming values
  // and exactly two uses (this gather/scatter and the increment); any other
  // user would observe the rewritten value.
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    return nullptr;

  // Identify which incoming value comes from the latch (the increment) and
  // split it into variable + constant.
  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  Offsets = Phi->getIncomingValue(IncrementIndex);
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  // The increment must have the exact form Phi + Constant for the
  // writeback intrinsic to be able to replace it.
  if (OffsetsIncoming != Phi)
    return nullptr;

  // Build the (scaled, base-adjusted) start value in the preheader block.
  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Scale the start offsets to byte units.
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the (splatted, int-cast) base pointer.
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Subtract the immediate once up front, since the writeback intrinsic
  // adds it on every iteration (including the first).
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // The writeback gather returns {gathered values, incremented offsets};
    // split the aggregate into its two results.
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    // The writeback scatter directly returns the incremented offsets.
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }

  // The old loop increment is now redundant: the intrinsic's writeback
  // result feeds the phi instead.
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}
// Hoist an add of a loop-invariant operand out of the loop: the operand is
// added once to the phi's incoming start value (in the block at StartIndex)
// instead of being re-added on every iteration.
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialise the phi with the sum of its original start value and the
  // add's second operand.
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Re-append both incoming edges (start value first) and then drop the two
  // originals, so the phi's operand order has the start value leading.
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}
// Hoist a mul/shl by a loop-invariant operand out of the loop: the phi's
// start value is scaled once before the loop, and the per-round increment
// is pre-scaled so the loop only adds the already-scaled step.
void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
                                             Value *IncrementPerRound,
                                             Value *OffsSecondOperand,
                                             unsigned LoopIncrement,
                                             IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Insert the hoisted computations at the end of the preheader block (the
  // incoming block that is not the loop increment).
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Scaled start value: start <op> invariant-operand.
  Value *StartIndex =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode,
                             Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
                             OffsSecondOperand, "PushedOutMul", InsertionPoint);

  // Pre-scaled step: increment <op> invariant-operand.
  Instruction *Product =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // In the latch, increment the phi by the pre-scaled step instead of
  // performing the multiplication each round.
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  // Re-append the new incoming edges, then drop the two original ones.
  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}
// Returns true if I feeds (possibly through chains of add/or/mul/shl) only
// GEPs or gather/scatter intrinsics, so rewriting it cannot change any other
// observable value. Note: a GEP/gather-scatter user short-circuits the scan,
// mirroring the original early return.
static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
  // An unused value cannot feed a gather/scatter.
  if (I->hasNUses(0))
    return false;
  for (User *U : I->users()) {
    auto *UserInst = dyn_cast<Instruction>(U);
    if (!UserInst)
      return false;
    // A direct address/gather-scatter user: accept immediately.
    if (isa<GetElementPtrInst>(UserInst) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(UserInst)))
      return true;
    // Otherwise the user must itself be add/or/mul/shl arithmetic whose
    // users (recursively) are all gathers/scatters.
    unsigned OpCode = UserInst->getOpcode();
    bool IsFoldableArith =
        OpCode == Instruction::Add || OpCode == Instruction::Mul ||
        OpCode == Instruction::Shl || isAddLikeOr(UserInst, DL);
    if (!IsFoldableArith || !hasAllGatScatUsers(UserInst, DL))
      return false;
  }
  return true;
}
// Optimise a loop-carried offset computation feeding gathers/scatters by
// pushing the loop-invariant part of an add/or/mul/shl into the phi's start
// value (and, for mul/shl, pre-scaling the increment), so it is computed
// once instead of every iteration. Returns true if a rewrite happened.
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
                    << *Offsets << "\n");
  // Only instructions of the supported arithmetic shapes can be hoisted.
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
      Offs->getOpcode() != Instruction::Mul &&
      Offs->getOpcode() != Instruction::Shl)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    // With multiple users, every (transitive) user must be a
    // gather/scatter, otherwise the rewrite would change observed values.
    if (!hasAllGatScatUsers(Offs, *DL))
      return false;
  }

  // Find the loop-carried phi among the operands; if neither operand is a
  // phi, try to optimise the operands recursively first and re-check.
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed)
      return false;
    // Recursion may have replaced an operand with a phi; retry the match.
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }
  // The phi must be the loop-header induction.
  if (Phi->getParent() != L->getHeader())
    return false;

  // Match the simple recurrence Phi = phi(Start, Phi + IncrementPerRound).
  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  // The value added to/multiplied with the phi.
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  // It must be type-compatible and loop-invariant to be hoisted.
  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    return false;

  // The increment itself must be a constant or defined outside the loop.
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  PHINode *NewPhi;
  if (Phi->getNumUses() == 2) {
    // Only the offset computation and the increment use the phi, so it can
    // be rewritten in place. If the increment has other users, give the phi
    // its own dedicated increment first.
    if (IncInstruction->getNumUses() != 1) {
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // The phi has other users: clone it (with its own increment) and leave
    // the original untouched.
    NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  // Push the invariant operation out of the loop.
  switch (Offs->getOpcode()) {
  case Instruction::Add:
  case Instruction::Or:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
  case Instruction::Shl:
    pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
                  OffsSecondOperand, IncrementingBlock, Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The offset instruction has been absorbed into the phi; clean up any
  // now-dead instructions.
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}
// Build the offset vector X*ScaleX + Y*ScaleY, splatting scalar summands to
// vectors first. Returns nullptr if the sum cannot be proven to fit the
// target offset element size (overflow-safe merging of two GEP levels).
static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
                                      unsigned ScaleY, IRBuilder<> &Builder) {
  // Splat a non-vector summand to the vector type; a small constant is
  // rebuilt at the vector's element width so the operand types match.
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      // Only rebuild constants that fit in the signed element range.
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If exactly one summand is scalar, splat it to match the other.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // The summands must end up with identical vector types.
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // For narrow elements, only constant summands can be proven not to
    // overflow: check each lane of the scaled sum against the signed limit.
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() * ScaleX +
                  ConstYEl->getZExtValue() * ScaleY >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  // Scale both summands and add them.
  Value *XScale = Builder.CreateVectorSplat(
      XElType->getNumElements(),
      Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
  Value *YScale = Builder.CreateVectorSplat(
      YElType->getNumElements(),
      Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
  Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
                                 Builder.CreateMul(Y, YScale));

  // Final safety check on the combined offsets.
  if (checkOffsetSize(Add, XElType->getNumElements()))
    return Add;
  else
    return nullptr;
}
/// Recursively fold a chain of single-index, constant-offset GEPs into one
/// base pointer plus a combined byte offset.  On success the innermost base
/// is returned and \p Offsets / \p Scale describe the accumulated offset;
/// returns nullptr when the chain cannot be folded safely.
Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets, unsigned &Scale,
                                         IRBuilder<> &Builder) {
  // Publish this GEP's own index and element-size scale via the out-params.
  Offsets = GEP->getOperand(1);
  Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
  // Only single-index GEPs with constant offsets are merged; for anything
  // else we cannot prove the combined offset does not overflow.
  if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
    return nullptr;

  Value *BasePtr = GEP->getPointerOperand();
  auto *ParentGEP = dyn_cast<GetElementPtrInst>(BasePtr);
  if (!ParentGEP)
    // The base is not itself a GEP: the recursion bottoms out here.
    return BasePtr;

  // Fold the parent chain first; it rewrites Offsets/Scale for everything
  // above this GEP.
  Value *FoldedBase = foldGEP(ParentGEP, Offsets, Scale, Builder);
  if (!FoldedBase)
    return nullptr;

  // Merge the parent offsets with this GEP's index, each pre-scaled by its
  // own element size so the result is expressed in bytes.
  Offsets = CheckAndCreateOffsetAdd(
      Offsets, Scale, GEP->getOperand(1),
      DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
  if (!Offsets)
    return nullptr;

  Scale = 1; // Offsets are now byte offsets.
  return FoldedBase;
}
/// Optimise the address operand of a masked gather/scatter: try to fold a
/// chain of constant-offset GEPs into a single i8-typed GEP, then run the
/// offset optimisation on the (possibly new) GEP's index.  Returns true if
/// any IR was changed.
/// NOTE(review): the BB parameter is unused here — the offset optimisation
/// uses GEP->getParent() instead; confirm against callers.
bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  // Only merge when this GEP is the sole user of a base that is itself a
  // GEP; otherwise the base must stay live for its other users anyway.
  if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    unsigned Scale;
    // foldGEP fills Offsets/Scale and returns the innermost base pointer,
    // or nullptr/GEP itself when there was nothing to fold.
    Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
    if (Offsets && Base && Base != GEP) {
      assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
      // Rebuild the address as an i8* GEP so the byte offsets apply
      // directly; keep the base vector-of-pointers shape if it had one.
      Type *BaseTy = Builder.getInt8PtrTy();
      if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
        BaseTy = FixedVectorType::get(BaseTy, VecTy);
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
          "gep.merged", GEP);
      LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
                        << "\n     new : " << *NewAddress << "\n");
      // Cast back to the original pointer type so existing users are
      // unaffected, then continue optimising the merged GEP.
      GEP->replaceAllUsesWith(
          Builder.CreateBitCast(NewAddress, GEP->getType()));
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}
/// Pass entry point: collect the masked gather/scatter intrinsics in F that
/// operate on fixed-width vectors, optimise their address computations, then
/// lower them to MVE gather/scatter forms.  Returns true if IR changed.
bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;

  auto &PassConfig = getAnalysis<TargetPassConfig>();
  auto &TM = PassConfig.getTM<TargetMachine>();
  auto *Subtarget = &TM.getSubtarget<ARMSubtarget>(F);
  // Only cores with MVE integer support have these gather/scatter forms.
  if (!Subtarget->hasMVEIntegerOps())
    return false;

  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  DL = &F.getParent()->getDataLayout();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  bool Changed = false;

  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      auto *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      // Only fixed-width vector gathers/scatters can map onto MVE.
      if (II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }

  // Index-based loops re-evaluate size() each iteration — kept that way in
  // case lowering appends further work to the lists (TODO confirm).
  for (unsigned Idx = 0; Idx < Gathers.size(); ++Idx) {
    if (Instruction *NewInst = lowerGather(Gathers[Idx])) {
      // Lowering succeeded; fold away any now-dead instructions it left.
      SimplifyInstructionsInBlock(NewInst->getParent());
      Changed = true;
    }
  }

  for (unsigned Idx = 0; Idx < Scatters.size(); ++Idx) {
    if (Instruction *NewInst = lowerScatter(Scatters[Idx])) {
      SimplifyInstructionsInBlock(NewInst->getParent());
      Changed = true;
    }
  }

  return Changed;
}