#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
#define DEBUG_TYPE "riscv-gather-scatter-lowering"
namespace {
/// Rewrites llvm.masked.gather / llvm.masked.scatter calls on fixed-length
/// vectors into llvm.riscv.masked.strided.load / .store calls when the vector
/// of pointers can be expressed as a scalar base pointer plus a stride derived
/// from a strided loop recurrence.
class RISCVGatherScatterLowering : public FunctionPass {
// Per-function target and analysis state; (re)initialized in runOnFunction.
const RISCVSubtarget *ST = nullptr;
const RISCVTargetLowering *TLI = nullptr;
LoopInfo *LI = nullptr;
const DataLayout *DL = nullptr;
// Original vector PHIs that were replaced by scalar recurrences. They are
// removed at the end of runOnFunction if dead. Weak handles are used because
// an entry may already have been erased (e.g. the same PHI queued twice), in
// which case the handle is null and skipped.
SmallVector<WeakTrackingVH> MaybeDeadPHIs;
// Cache of GEP -> (scalar base pointer, stride) produced by
// determineBaseAndStride, so a GEP shared by several gathers/scatters is only
// rewritten once. Cleared at the start of each runOnFunction.
DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
public:
static char ID;
RISCVGatherScatterLowering() : FunctionPass(ID) {}
bool runOnFunction(Function &F) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
// Only instructions are created/replaced; no blocks or edges are changed.
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
AU.addRequired<LoopInfoWrapperPass>();
}
StringRef getPassName() const override {
return "RISCV gather/scatter lowering";
}
private:
// True if DataType/AlignOp are acceptable for an RVV strided access.
bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);
// Replace gather/scatter II with a strided load/store if Ptr is strided.
bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
Value *AlignOp);
// Compute (scalar base pointer, stride) for a vector GEP, or (null, null).
std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
IRBuilder<> &Builder);
// Match Index as a strided recurrence in L and build its scalar equivalent.
bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
PHINode *&BasePtr, BinaryOperator *&Inc,
IRBuilder<> &Builder);
};
}
// Pass ID object; handed to the FunctionPass base in the constructor.
char RISCVGatherScatterLowering::ID = 0;
// Register the pass with the legacy pass registry under DEBUG_TYPE. The two
// trailing `false` arguments mark it as neither CFG-only nor an analysis.
INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
"RISCV gather/scatter lowering pass", false, false)
/// Create a fresh instance of the RISC-V gather/scatter lowering pass.
/// Ownership of the returned pass passes to the caller.
FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  FunctionPass *Pass = new RISCVGatherScatterLowering();
  return Pass;
}
/// Decide whether a gather/scatter of \p DataType with alignment operand
/// \p AlignOp (a ConstantInt) can be turned into an RVV strided access.
/// Requires:
///  - an element type that RVV supports;
///  - an alignment that, when present, is at least the element's store size;
///  - a vector type that is legal for the target.
bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *EltTy = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(EltTy))
    return false;

  // An absent alignment (empty MaybeAlign) skips the under-alignment check.
  MaybeAlign Alignment = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (Alignment) {
    const uint64_t EltBytes = DL->getTypeStoreSize(EltTy).getFixedSize();
    if (Alignment->value() < EltBytes)
      return false;
  }

  return TLI->isTypeLegal(TLI->getValueType(*DL, DataType));
}
/// Match a vector constant whose integer lanes form an arithmetic sequence
/// <C, C+S, C+2S, ...>.
///
/// \param StartC  the constant to inspect.
/// \returns {C, S} as scalar ConstantInts of the element type on success.
///          A single-lane vector yields a stride of 0. Returns
///          {nullptr, nullptr} if:
///           - \p StartC is not a fixed-length vector;
///           - any lane is not a ConstantInt;
///           - the lane-to-lane difference is not uniform.
///          Lane differences use wrapping APInt arithmetic.
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  // Only fixed-length vectors can be walked lane by lane. Reject anything else
  // (scalable vectors, non-vector constants) instead of tripping a cast<>.
  auto *VecTy = dyn_cast<FixedVectorType>(StartC->getType());
  if (!VecTy)
    return std::make_pair(nullptr, nullptr);
  unsigned NumElts = VecTy->getNumElements();

  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);

  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    // The first difference defines the stride; every later one must match it.
    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
  return std::make_pair(StartVal, Stride);
}
/// Express the vector value \p Start as a scalar start plus a per-lane stride.
///
/// Accepted forms:
///  - a strided vector constant (see matchStridedConstant);
///  - recursively, `add` of an accepted value and a splat, in either operand
///    order.
///
/// For the `add` form, the splatted scalar is added to the inner scalar start
/// with a new `add` placed immediately before the vector add (no debug
/// location). Returns {ScalarStart, Stride} or {nullptr, nullptr}.
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  const std::pair<Value *, Value *> NoMatch(nullptr, nullptr);

  // Base case: the start vector is itself a strided constant.
  if (auto *StartC = dyn_cast<Constant>(Start))
    return matchStridedConstant(StartC);

  auto *VecAdd = dyn_cast<BinaryOperator>(Start);
  if (!VecAdd || VecAdd->getOpcode() != Instruction::Add)
    return NoMatch;

  // Prefer operand 0 as the splat; otherwise try operand 1.
  Value *SplatScalar = getSplatValue(VecAdd->getOperand(0));
  Value *StridedOp = VecAdd->getOperand(1);
  if (!SplatScalar) {
    SplatScalar = getSplatValue(VecAdd->getOperand(1));
    StridedOp = VecAdd->getOperand(0);
    if (!SplatScalar)
      return NoMatch;
  }

  Value *InnerStart;
  Value *Stride;
  std::tie(InnerStart, Stride) = matchStridedStart(StridedOp, Builder);
  if (!InnerStart)
    return NoMatch;

  Builder.SetInsertPoint(VecAdd);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Value *ScalarStart = Builder.CreateAdd(InnerStart, SplatScalar);
  return std::make_pair(ScalarStart, Stride);
}
/// Recursively match the vector index \p Index against a strided recurrence
/// in loop \p L and build an equivalent scalar recurrence.
///
/// Recognized forms:
///  - a vector PHI in the header of \p L, matched by matchSimpleRecurrence as
///    an `add` recurrence, whose step is a loop-invariant splat and whose
///    start value is accepted by matchStridedStart;
///  - Add / Or / Mul / Shl of such a value (defined inside \p L) with a
///    loop-invariant splat. Or requires operands with no common set bits;
///    Shl requires a constant shift amount.
///
/// On success returns true and sets:
///  - \p BasePtr: a new scalar PHI, inserted before the vector PHI, tracking
///    the lane-0 value of \p Index. Despite the name, this is an integer
///    index, not a pointer.
///  - \p Inc: the scalar add feeding \p BasePtr from the incrementing block.
///  - \p Stride: the scalar difference between adjacent lanes of \p Index.
///
/// For the binary-operator forms, the scalar start/step produced by the
/// recursive match are adjusted in place. New start computations go at the
/// end of the block supplying the start value. The original vector PHI is
/// queued in MaybeDeadPHIs. Returns false if \p Index does not match.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
Value *&Stride,
PHINode *&BasePtr,
BinaryOperator *&Inc,
IRBuilder<> &Builder) {
// Base case: a vector induction PHI in the loop header.
if (auto *Phi = dyn_cast<PHINode>(Index)) {
if (Phi->getParent() != L->getHeader())
return false;
Value *Step, *Start;
if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
Inc->getOpcode() != Instruction::Add)
return false;
assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
// Which incoming edge carries the increment; the other carries Start.
unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
"Expected one operand of phi to be Inc");
// The vector step must be a loop-invariant splat; use its scalar value.
if (!L->isLoopInvariant(Step))
return false;
Step = getSplatValue(Step);
if (!Step)
return false;
// The vector start must decompose into scalar start + per-lane stride.
std::tie(Start, Stride) = matchStridedStart(Start, Builder);
if (!Start)
return false;
assert(Stride != nullptr);
// Build the scalar recurrence next to the vector one:
//   BasePtr = phi [Start, start-edge], [Inc, incrementing-edge]
//   Inc     = BasePtr + Step
BasePtr =
PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
Inc);
BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
// The vector PHI may become dead once its users are rewritten.
MaybeDeadPHIs.push_back(Phi);
return true;
}
// Recursive case: a binary operator applied to a strided recurrence.
auto *BO = dyn_cast<BinaryOperator>(Index);
if (!BO)
return false;
if (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Or &&
BO->getOpcode() != Instruction::Mul &&
BO->getOpcode() != Instruction::Shl)
return false;
// Only constant shift amounts are handled.
if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
return false;
// An `or` is only equivalent to an `add` when no bits overlap.
if (BO->getOpcode() == Instruction::Or &&
!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
return false;
// Pick the operand defined inside the loop as the recurrence to recurse on;
// the other operand is the loop-invariant modifier.
Value *OtherOp;
if (isa<Instruction>(BO->getOperand(0)) &&
L->contains(cast<Instruction>(BO->getOperand(0)))) {
Index = cast<Instruction>(BO->getOperand(0));
OtherOp = BO->getOperand(1);
} else if (isa<Instruction>(BO->getOperand(1)) &&
L->contains(cast<Instruction>(BO->getOperand(1)))) {
Index = cast<Instruction>(BO->getOperand(1));
OtherOp = BO->getOperand(0);
} else {
return false;
}
if (!L->isLoopInvariant(OtherOp))
return false;
Value *SplatOp = getSplatValue(OtherOp);
if (!SplatOp)
return false;
if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
return false;
// Recover the scalar step and start of the recurrence built by the
// recursive call so they can be adjusted for this operator.
unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
Value *Step = Inc->getOperand(StepIndex);
Value *Start = BasePtr->getOperand(StartBlock);
// Adjustments are emitted at the end of the block providing the start value.
Builder.SetInsertPoint(
BasePtr->getIncomingBlock(StartBlock)->getTerminator());
Builder.SetCurrentDebugLocation(DebugLoc());
switch (BO->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode!");
case Instruction::Add:
case Instruction::Or: {
// (x + c): only the start shifts; step and stride are unchanged.
if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
Start = SplatOp;
else
Start = Builder.CreateAdd(Start, SplatOp, "start");
BasePtr->setIncomingValue(StartBlock, Start);
break;
}
case Instruction::Mul: {
// (x * c): start, step and stride all scale by c (a zero start stays 0).
if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
Start = Builder.CreateMul(Start, SplatOp, "start");
Step = Builder.CreateMul(Step, SplatOp, "step");
if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
Stride = SplatOp;
else
Stride = Builder.CreateMul(Stride, SplatOp, "stride");
Inc->setOperand(StepIndex, Step);
BasePtr->setIncomingValue(StartBlock, Start);
break;
}
case Instruction::Shl: {
// (x << c): start, step and stride are all shifted by c.
if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
Start = Builder.CreateShl(Start, SplatOp, "start");
Step = Builder.CreateShl(Step, SplatOp, "step");
Stride = Builder.CreateShl(Stride, SplatOp, "stride");
Inc->setOperand(StepIndex, Step);
BasePtr->setIncomingValue(StartBlock, Start);
break;
}
}
return true;
}
/// Try to express the vector-of-pointers \p GEP as a scalar base pointer plus
/// a stride between lanes.
///
/// Requirements:
///  - \p GEP is in a loop that has a preheader and a latch;
///  - its pointer operand is scalar;
///  - it has exactly one vector index, of the pointer-sized integer vector
///    type;
///  - the indexed type at that position has a fixed allocation size;
///  - the vector index is accepted by matchStridedRecurrence.
///
/// On success a scalar GEP is created at \p GEP's position with the scalar
/// recurrence substituted for the vector index. Returns {BasePtr, Stride},
/// where Stride is the index stride multiplied by the indexed type's
/// allocation size. Otherwise returns {nullptr, nullptr}. Successful results
/// are cached in StridedAddrs.
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
IRBuilder<> &Builder) {
auto I = StridedAddrs.find(GEP);
if (I != StridedAddrs.end())
return I->second;
SmallVector<Value *, 2> Ops(GEP->operands());
// The base pointer operand must be scalar.
if (Ops[0]->getType()->isVectorTy())
return std::make_pair(nullptr, nullptr);
Loop *L = LI->getLoopFor(GEP->getParent());
if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
return std::make_pair(nullptr, nullptr);
// Find the single vector index and the allocation size of the type it
// steps over.
Optional<unsigned> VecOperand;
unsigned TypeScale = 0;
gep_type_iterator GTI = gep_type_begin(GEP);
for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
if (!Ops[i]->getType()->isVectorTy())
continue;
if (VecOperand)
return std::make_pair(nullptr, nullptr);
VecOperand = i;
TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
if (TS.isScalable())
return std::make_pair(nullptr, nullptr);
TypeScale = TS.getFixedSize();
}
if (!VecOperand)
return std::make_pair(nullptr, nullptr);
// Only pointer-width indices are handled, so the resulting stride has the
// pointer-sized integer type (asserted below).
Value *VecIndex = Ops[*VecOperand];
Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
if (VecIndex->getType() != VecIntPtrTy)
return std::make_pair(nullptr, nullptr);
Value *Stride;
BinaryOperator *Inc;
PHINode *BasePhi;
if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
return std::make_pair(nullptr, nullptr);
assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
"Expected one operand of phi to be Inc");
// Scalar GEP at the original position, using the scalar recurrence in place
// of the vector index.
Builder.SetInsertPoint(GEP);
Ops[*VecOperand] = BasePhi;
Type *SourceTy = GEP->getSourceElementType();
Value *BasePtr =
Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());
// Stride scaling is emitted at the end of the block that supplies the
// recurrence's start value (the non-incrementing edge).
Builder.SetInsertPoint(
BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
assert(Stride->getType() == IntPtrTy && "Unexpected type");
// Convert the index stride into a multiple of the indexed type's size.
if (TypeScale != 1)
Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
auto P = std::make_pair(BasePtr, Stride);
StridedAddrs[GEP] = P;
return P;
}
/// Replace the masked gather/scatter \p II by a RISC-V masked strided
/// load/store when its pointer operand \p Ptr is a GEP that
/// determineBaseAndStride can reduce to a scalar base and stride.
///
/// \p DataType is the loaded/stored vector type and \p AlignOp is the
/// intrinsic's alignment operand. On success \p II is erased; the GEP and its
/// trivially dead operands are deleted if the GEP has no remaining uses.
/// Returns true iff \p II was replaced.
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);
  std::pair<Value *, Value *> BaseAndStride =
      determineBaseAndStride(GEP, Builder);
  Value *BasePtr = BaseAndStride.first;
  if (!BasePtr)
    return false;
  Value *Stride = BaseAndStride.second;
  assert(Stride != nullptr);

  // Operand mapping:
  //   masked.gather(ptrs, align, mask, passthru)
  //     -> strided.load(passthru, base, stride, mask)
  //   masked.scatter(val, ptrs, align, mask)
  //     -> strided.store(val, base, stride, mask)
  const bool IsGather = II->getIntrinsicID() == Intrinsic::masked_gather;
  Intrinsic::ID NewID = IsGather ? Intrinsic::riscv_masked_strided_load
                                 : Intrinsic::riscv_masked_strided_store;
  Value *DataOrPassThru =
      IsGather ? II->getArgOperand(3) : II->getArgOperand(0);
  Value *Mask = IsGather ? II->getArgOperand(2) : II->getArgOperand(3);

  Builder.SetInsertPoint(II);
  CallInst *Call = Builder.CreateIntrinsic(
      NewID, {DataType, BasePtr->getType(), Stride->getType()},
      {DataOrPassThru, BasePtr, Stride, Mask});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}
/// Pass entry point.
///
/// Steps:
///  1. Collect fixed-length masked gathers and scatters.
///  2. Try to turn each one into a strided load/store.
///  3. Try to delete the vector PHIs made redundant by the new scalar
///     recurrences.
///
/// Does nothing unless the subtarget has V instructions and uses RVV for
/// fixed-length vectors. Returns true iff any gather/scatter was replaced.
bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TM = getAnalysis<TargetPassConfig>().getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  StridedAddrs.clear();

  // Collect candidates up front: rewriting erases instructions, which must not
  // happen while iterating over the blocks.
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  for (BasicBlock &BB : F) {
    for (Instruction &Inst : BB) {
      auto *Intr = dyn_cast<IntrinsicInst>(&Inst);
      if (!Intr)
        continue;
      switch (Intr->getIntrinsicID()) {
      case Intrinsic::masked_gather:
        if (isa<FixedVectorType>(Intr->getType()))
          Gathers.push_back(Intr);
        break;
      case Intrinsic::masked_scatter:
        if (isa<FixedVectorType>(Intr->getArgOperand(0)->getType()))
          Scatters.push_back(Intr);
        break;
      default:
        break;
      }
    }
  }

  bool Changed = false;
  // masked.gather(ptrs, align, ...): the result type is the data type.
  for (IntrinsicInst *Gather : Gathers)
    Changed |= tryCreateStridedLoadStore(Gather, Gather->getType(),
                                         Gather->getArgOperand(0),
                                         Gather->getArgOperand(1));
  // masked.scatter(val, ptrs, align, ...): the stored value gives the type.
  for (IntrinsicInst *Scatter : Scatters)
    Changed |= tryCreateStridedLoadStore(
        Scatter, Scatter->getArgOperand(0)->getType(),
        Scatter->getArgOperand(1), Scatter->getArgOperand(2));

  // RecursivelyDeleteDeadPHINode removes a PHI only if it is dead. Handles
  // whose PHI was already erased are null and skipped.
  while (!MaybeDeadPHIs.empty()) {
    Value *Candidate = MaybeDeadPHIs.pop_back_val();
    if (auto *Phi = dyn_cast_or_null<PHINode>(Candidate))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}