#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
using namespace llvm;
using namespace llvm::PatternMatch;
#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
namespace {

/// Legacy-pass-manager module pass that rewrites SVE-related IR: it coalesces
/// ptrue intrinsic calls per basic block and turns fixed-width predicate
/// loads/stores into scalable predicate loads/stores.
struct SVEIntrinsicOpts : public ModulePass {
  /// Handed to the ModulePass base constructor to identify this pass.
  static char ID;

  /// Registers the pass with the global PassRegistry on construction.
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  /// Declares the analyses this pass needs and what it preserves.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Pass entry point; returns true when the module was modified.
  bool runOnModule(Module &M) override;

private:
  // Drivers: each one walks the given set of functions and reports whether
  // anything changed.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  // Merges the ptrue calls collected from a single basic block.
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);

  // Per-instruction rewrites dispatched from optimizeInstructions().
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);
};

} // end anonymous namespace
/// Reports the pass's analysis requirements: the CFG is preserved, and a
/// dominator tree is required (it is queried in optimizeInstructions()).
void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesCFG();
  AU.addRequired<DominatorTreeWrapperPass>();
}
// Out-of-line definition of the pass identifier declared in the struct.
char SVEIntrinsicOpts::ID = 0;
// Descriptive pass name handed to the registration macros below.
static const char *name = "SVE intrinsics optimizations";
// Pass registration: DEBUG_TYPE doubles as the pass argument string, and the
// dominator tree wrapper pass is declared as a dependency.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags -- confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
/// Factory for the SVE intrinsic optimization pass. Returns a freshly
/// heap-allocated instance; the caller receives the raw pointer.
ModulePass *llvm::createSVEIntrinsicOptsPass() {
  ModulePass *Pass = new SVEIntrinsicOpts();
  return Pass;
}
/// Returns true when \p PTrue is "promoted": it is fed through a
/// convert-to-svbool intrinsic whose result is in turn consumed by a
/// convert-from-svbool intrinsic producing a vector with *more* (minimum)
/// lanes than \p PTrue itself has.
///
/// The function only inspects the use lists; it does not modify the IR.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  for (User *PTrueUser : PTrue->users()) {
    // Only users that are convert-to-svbool calls are of interest.
    if (!match(PTrueUser,
               m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>()))
      continue;

    const auto SrcLanes = cast<ScalableVectorType>(PTrue->getType())
                              ->getElementCount()
                              .getKnownMinValue();

    // Look through the svbool value for a conversion back to a predicate
    // type that is wider (in lanes) than the original ptrue.
    for (User *SVBoolUser : PTrueUser->users()) {
      auto *FromSVBool = dyn_cast<IntrinsicInst>(SVBoolUser);
      if (!FromSVBool || FromSVBool->getIntrinsicID() !=
                             Intrinsic::aarch64_sve_convert_from_svbool)
        continue;

      const auto DstLanes = cast<ScalableVectorType>(FromSVBool->getType())
                                ->getElementCount()
                                .getKnownMinValue();
      if (DstLanes > SrcLanes)
        return true;
    }
  }

  // No convert-to/convert-from chain widening the ptrue was found.
  return false;
}
/// Coalesces the ptrue calls in \p PTrues (all taken from \p BB) into the one
/// with the largest minimum lane count.
///
/// The widest ptrue is hoisted to the first insertion point of \p BB. Every
/// remaining, non-promoted ptrue is then replaced either by the widest ptrue
/// directly (same type) or by a convert-from-svbool of a shared
/// convert-to-svbool of the widest ptrue (different type), and erased.
///
/// \p PTrues is modified: the widest ptrue and any promoted ptrues are removed
/// from it. Returns false only when there are fewer than two candidates;
/// otherwise returns true.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Pick the ptrue covering the most lanes.
  auto MinLanes = [](auto *PTrue) {
    return cast<ScalableVectorType>(PTrue->getType())
        ->getElementCount()
        .getKnownMinValue();
  };
  auto WidestIt = std::max_element(
      PTrues.begin(), PTrues.end(), [&MinLanes](auto *LHS, auto *RHS) {
        const auto LHSLanes = MinLanes(LHS);
        const auto RHSLanes = MinLanes(RHS);
        return LHSLanes < RHSLanes;
      });
  auto *Widest = *WidestIt;

  // What is left in PTrues afterwards is exactly the set to be replaced.
  PTrues.remove(Widest);
  PTrues.remove_if(isPTruePromoted);

  // Hoist the widest ptrue so that it precedes every instruction it may end
  // up feeding.
  // NOTE(review): this assumes the ptrue has no instruction operands that
  // would need to dominate the new position -- confirm.
  Widest->moveBefore(BB, BB.getFirstInsertionPt());

  IRBuilder<> Builder(BB.getContext());
  Builder.SetInsertPoint(&BB, ++Widest->getIterator());

  // Speculatively create the shared convert-to-svbool right after the hoisted
  // ptrue; it is deleted again below if nothing ends up using it.
  auto *WidestVTy = cast<VectorType>(Widest->getType());
  auto *ToSVBool =
      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_to_svbool,
                              {WidestVTy}, {Widest});

  bool ToSVBoolUsed = false;
  for (auto *Narrower : PTrues) {
    auto *NarrowerVTy = cast<VectorType>(Narrower->getType());

    // Same type: reuse the widest ptrue as-is. Otherwise reinterpret it via
    // svbool down to the narrower predicate type.
    Value *Replacement = Widest;
    if (WidestVTy != NarrowerVTy) {
      ToSVBoolUsed = true;
      Builder.SetInsertPoint(&BB, ++ToSVBool->getIterator());
      Replacement =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {NarrowerVTy}, {ToSVBool});
    }

    Narrower->replaceAllUsesWith(Replacement);
    Narrower->eraseFromParent();
  }

  // Drop the speculative conversion when no convert-from was created.
  if (!ToSVBoolUsed)
    ToSVBool->eraseFromParent();

  return true;
}
/// For every basic block of every function in \p Functions, gathers the ptrue
/// intrinsic calls that have at least one use, bucketed by predicate pattern
/// (`all` and `pow2`; other patterns are ignored), and tries to coalesce each
/// bucket. Returns true if any block was changed.
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (Function *F : Functions) {
    for (BasicBlock &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> AllPatternPTrues;
      SmallSetVector<IntrinsicInst *, 4> Pow2PatternPTrues;

      for (Instruction &I : BB) {
        // Skip anything that is not a ptrue call with at least one use.
        auto *PTrue = dyn_cast<IntrinsicInst>(&I);
        if (!PTrue || PTrue->use_empty() ||
            PTrue->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        // Operand 0 carries the predicate pattern as a constant integer.
        const auto Pattern =
            cast<ConstantInt>(PTrue->getOperand(0))->getZExtValue();
        if (Pattern == AArch64SVEPredPattern::all)
          AllPatternPTrues.insert(PTrue);
        else if (Pattern == AArch64SVEPredPattern::pow2)
          Pow2PatternPTrues.insert(PTrue);
      }

      // Both buckets are always processed, `all` first.
      Changed |= coalescePTrueIntrinsicCalls(BB, AllPatternPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, Pow2PatternPTrues);
    }
  }

  return Changed;
}
/// Rewrites a fixed-width store of a bit-cast predicate into a direct store of
/// the scalable predicate.
///
/// The transform applies only when the enclosing function has a vscale_range
/// attribute whose minimum equals its maximum (the vector length is known
/// exactly). The pattern matched, from \p I upwards, is:
///   - a simple store
///   - of a `<MinVScale * 2 x i8>` value
///   - produced by a vector_extract intrinsic at index 0
///   - whose source is a bitcast
///   - from a `<vscale x 16 x i1>` value.
/// The store is replaced by a store of the `<vscale x 16 x i1>` value through
/// a pointer bitcast, and the extract/bitcast are erased if left unused.
///
/// The replacement store keeps the original store's alignment; letting the
/// builder pick a default could claim a stricter alignment than the original
/// store guaranteed.
///
/// \returns true if \p I was replaced (and erased), false otherwise.
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  // The exact runtime vector length must be known: min == max vscale.
  unsigned MinVScale = Attr.getVScaleRangeMin();
  Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // A simple store...
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ...of a fixed-width vector with the expected byte count...
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ...coming from a vector_extract...
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ...at index 0...
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ...of a bitcast...
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ...whose source is the scalable predicate type.
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  auto *PtrBitCast = Builder.CreateBitCast(
      Store->getPointerOperand(),
      PredType->getPointerTo(Store->getPointerAddressSpace()));
  // Carry over the alignment of the store being replaced.
  Builder.CreateAlignedStore(BitCast->getOperand(0), PtrBitCast,
                             Store->getAlign());

  // Erase users before their operands so each use list is empty when checked.
  Store->eraseFromParent();
  if (IntrI->use_empty())
    IntrI->eraseFromParent();
  if (BitCast->use_empty())
    BitCast->eraseFromParent();

  return true;
}
/// Rewrites a bitcast-to-predicate of an inserted fixed-width load into a
/// direct load of the scalable predicate.
///
/// The transform applies only when the enclosing function has a vscale_range
/// attribute whose minimum equals its maximum (the vector length is known
/// exactly). The pattern matched, from \p I upwards, is:
///   - a bitcast producing `<vscale x 16 x i1>`
///   - of a vector_insert intrinsic
///   - inserting at index 0 into an undef vector
///   - a value produced by a simple load
///   - of type `<MinVScale * 2 x i8>`.
/// A `<vscale x 16 x i1>` load through a pointer bitcast is emitted at the
/// position of the original load and replaces all uses of the bitcast; the
/// insert and the original load are erased if left unused.
///
/// The replacement load keeps the original load's alignment; letting the
/// builder pick a default could claim a stricter alignment than the original
/// load guaranteed.
///
/// \returns true if \p I was replaced (and erased), false otherwise.
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  // The exact runtime vector length must be known: min == max vscale.
  unsigned MinVScale = Attr.getVScaleRangeMin();
  Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // A bitcast to the scalable predicate type...
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ...of a vector_insert...
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ...into index 0 of an undef vector...
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ...of a simple load...
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ...of a fixed-width vector with the expected byte count.
  if (Load->getType() != FixedPredType)
    return false;

  // Emit the new load where the old one was.
  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *PtrBitCast = Builder.CreateBitCast(
      Load->getPointerOperand(),
      PredType->getPointerTo(Load->getPointerAddressSpace()));
  // Carry over the alignment of the load being replaced.
  auto *LoadPred =
      Builder.CreateAlignedLoad(PredType, PtrBitCast, Load->getAlign());

  // Erase users before their operands so each use list is empty when checked.
  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->use_empty())
    IntrI->eraseFromParent();
  if (Load->use_empty())
    Load->eraseFromParent();

  return true;
}
/// Visits every instruction reachable from the dominator-tree root of each
/// function in \p Functions, in reverse post-order over basic blocks, and
/// applies the predicate store/load rewrites to store and bitcast
/// instructions respectively. Returns true if anything was rewritten.
bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (Function *F : Functions) {
    DominatorTree &DomTree =
        getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
    BasicBlock *Entry = DomTree.getRoot();
    ReversePostOrderTraversal<BasicBlock *> Blocks(Entry);

    for (BasicBlock *BB : Blocks) {
      // The rewrites may erase the visited instruction, hence the
      // early-increment range.
      for (Instruction &I : make_early_inc_range(*BB)) {
        const unsigned Opcode = I.getOpcode();
        if (Opcode == Instruction::Store)
          Changed |= optimizePredicateStore(&I);
        else if (Opcode == Instruction::BitCast)
          Changed |= optimizePredicateLoad(&I);
      }
    }
  }

  return Changed;
}
/// Runs both optimization stages over \p Functions -- ptrue coalescing first,
/// then the per-instruction rewrites -- and returns true if either stage
/// changed anything. Both stages always run.
bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  const bool PTruesChanged = optimizePTrueIntrinsicCalls(Functions);
  const bool InstructionsChanged = optimizeInstructions(Functions);
  return PTruesChanged || InstructionsChanged;
}
/// Pass entry point. Scans the module's function declarations for the
/// intrinsics this pass cares about (vector_extract, vector_insert and
/// aarch64_sve_ptrue), collects the functions containing their users, and
/// optimizes only those. Returns true if the module was modified.
bool SVEIntrinsicOpts::runOnModule(Module &M) {
  SmallSetVector<Function *, 4> Functions;

  for (auto &F : M.getFunctionList()) {
    // Only declarations are inspected here.
    if (!F.isDeclaration())
      continue;

    const auto IID = F.getIntrinsicID();
    const bool IsRelevantIntrinsic = IID == Intrinsic::vector_extract ||
                                     IID == Intrinsic::vector_insert ||
                                     IID == Intrinsic::aarch64_sve_ptrue;
    if (!IsRelevantIntrinsic)
      continue;

    // Record the function enclosing each user of the declaration.
    for (User *U : F.users())
      Functions.insert(cast<Instruction>(U)->getFunction());
  }

  // Nothing to do when no relevant intrinsic is used anywhere.
  return !Functions.empty() && optimizeFunctions(Functions);
}