#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>
using namespace llvm;
// DEBUG_TYPE is also the name under which the pass is registered by the
// INITIALIZE_PASS invocation further down in this file.
#define DEBUG_TYPE "mve-laneinterleaving"
// Hidden boolean command-line option "enable-mve-interleave", defaulting to
// true. MVELaneInterleaving::runOnFunction returns without making any change
// when it is false.
cl::opt<bool> EnableInterleave(
"enable-mve-interleave", cl::Hidden, cl::init(true),
cl::desc("Enable interleave MVE vector operation lowering"));
namespace {
class MVELaneInterleaving : public FunctionPass {
public:
static char ID;
explicit MVELaneInterleaving() : FunctionPass(ID) {
initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override;
StringRef getPassName() const override { return "MVE lane interleaving"; }
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
FunctionPass::getAnalysisUsage(AU);
}
};
}
// Out-of-class definition of the static ID member that the constructor passes
// to the FunctionPass base class.
char MVELaneInterleaving::ID = 0;
// Generates initializeMVELaneInterleavingPass(), which the constructor calls,
// registering the pass under the DEBUG_TYPE name with the description
// "MVE lane interleaving".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against the INITIALIZE_PASS macro.
INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
false)
/// Factory for the pass: returns a newly heap-allocated MVELaneInterleaving
/// instance as a generic Pass pointer. The function itself never frees it.
Pass *llvm::createMVELaneInterleavingPass() {
return new MVELaneInterleaving();
}
/// Heuristic deciding whether interleaving the group described by \p Exts and
/// \p Truncs is worthwhile.
///
/// The checks run in this order:
///  1. Any extend that is a floating-point extend, or whose input is not a
///     load, makes the transform beneficial.
///  2. Any truncate with exactly one use, where that use is not a store,
///     makes the transform beneficial.
///  3. Otherwise every extend must have exactly one use and that use must be
///     a `mul`; the first extend that does not match rejects the transform.
///
/// NOTE(review): the rationale is presumably that an extend of a load and a
/// truncate into a store are cheap on their own, whereas other extends and
/// truncates are not -- confirm against the MVE cost model.
///
/// \returns true when the rewrite is considered profitable.
static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // Step 1: look for an extend that is not a plain integer extend of a load.
  for (Instruction *Ext : Exts) {
    const bool ExtendsLoad = isa<LoadInst>(Ext->getOperand(0));
    if (!isa<FPExtInst>(Ext) && ExtendsLoad)
      continue;
    LLVM_DEBUG(dbgs() << "Beneficial due to " << *Ext << "\n");
    return true;
  }

  // Step 2: look for a single-use truncate that does not feed a store.
  for (Instruction *Trunc : Truncs) {
    if (!Trunc->hasOneUse())
      continue;
    if (isa<StoreInst>(*Trunc->user_begin()))
      continue;
    LLVM_DEBUG(dbgs() << "Beneficial due to " << *Trunc << "\n");
    return true;
  }

  // Step 3: every extend must feed exactly one multiply.
  for (Instruction *Ext : Exts) {
    const bool FeedsSingleMul =
        Ext->hasOneUse() &&
        cast<Instruction>(*Ext->user_begin())->getOpcode() == Instruction::Mul;
    if (FeedsSingleMul)
      continue;
    LLVM_DEBUG(dbgs() << "Not beneficial due to " << *Ext << "\n");
    return false;
  }

  return true;
}
/// Try to lane-interleave the group of instructions connected to the vector
/// truncate \p Start.
///
/// Starting from \p Start and its first operand, the function walks operands
/// and users to collect a connected group made of:
///  - Truncs:     trunc / fptrunc instructions (the outputs of the group),
///  - Exts:       sext / zext / fpext instructions (inputs of the group),
///  - Ops:        the supported lane-wise operations in between,
///  - OtherLeafs: vector operands of Ops that are not instructions.
/// Zero-element-splat shuffles are ignored; any other instruction aborts the
/// attempt. If the types are supported and isProfitableToInterleave agrees,
/// a shuffle using LeafMask is inserted in front of every extend and other
/// leaf, and a shuffle using TruncMask is inserted after every truncate.
/// TruncMask is the inverse permutation of LeafMask, so the values seen by
/// users of the truncates keep their original lane order.
///
/// \param Start   A trunc or fptrunc instruction producing a vector.
/// \param Visited Receives every truncate added to the group, including when
///                the attempt is later abandoned, so callers can avoid
///                restarting from the same truncate.
/// \returns true if the IR was modified.
static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  // The caller filters on isVectorTy(), which also admits scalable vectors;
  // bail out on anything that is not a fixed-width vector instead of
  // asserting inside cast<>.
  auto *VT = dyn_cast<FixedVectorType>(Start->getType());
  if (!VT)
    return false;

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Depth-first walk (LIFO worklist) over the instructions connected to Start.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncates terminate the walk: their users are not followed.
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extends are the leaves on the input side: their operand is not followed,
    // but every user must become part of the group.
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      // Only the intrinsics listed below are accepted; any other call aborts.
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      // Accepted intrinsics are handled exactly like the operators below.
      LLVM_FALLTHROUGH;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        // Non-vector operands need no shuffle and are not followed.
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          // Cast the value held by the use; `&Op` would be the address of the
          // Use slot itself, not of the defining instruction.
          Worklist.push_back(cast<Instruction>(Op.get()));
        else
          // Remember the use (not just the value) so that exactly this
          // operand slot can be rewritten later.
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A zero-element splat is skipped without joining any set; any other
      // shuffle is unsupported.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      LLVM_FALLTHROUGH;

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "Truncs:";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
  });

  // Start is a truncate and is still on the worklist whenever the walk has
  // not bailed out, so it must have been collected by now.
  assert(!Truncs.empty() && "Expected some truncs");

  // Type checks. BaseElts is the number of narrow lanes in one 128-bit chunk
  // (8 x 16 bits or 16 x 8 bits); the vector must be a whole number of chunks.
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  if (!isProfitableToInterleave(Exts, Truncs))
    return false;

  // Build the two shuffle masks, one 128-bit chunk at a time. For BaseElts == 8
  // a chunk contributes:
  //   LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   (even lanes first, then odd lanes)
  //   TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   (the inverse permutation)
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  // Replace every extend with an extend of the shuffled input. The original
  // extend is left in place without users.
  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                        : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  // Shuffle the remaining (non-instruction) vector operands in place.
  for (Use *I : OtherLeafs) {
    // Print the value held by the use; streaming the Use itself would not
    // print the IR value.
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I->get() << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  // Undo the interleaving right after every truncate.
  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    // replaceAllUsesWith also rewrote the new shuffle's own use of I; point
    // it back at the truncate.
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}
/// Pass entry point. Does nothing when the "enable-mve-interleave" option is
/// off or when the function's subtarget lacks MVE integer ops. Otherwise it
/// walks the function's instructions in reverse and calls tryInterleave on
/// every vector-typed trunc / fptrunc that an earlier attempt has not already
/// recorded in the visited set.
///
/// \returns true if any call to tryInterleave modified the function.
bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;

  auto &PassConfig = getAnalysis<TargetPassConfig>();
  auto &Machine = PassConfig.getTM<TargetMachine>();
  const ARMSubtarget &Subtarget = Machine.getSubtarget<ARMSubtarget>(F);
  if (!Subtarget.hasMVEIntegerOps())
    return false;

  bool MadeChange = false;
  // Truncates already pulled into a group by a previous tryInterleave call.
  SmallPtrSet<Instruction *, 16> Visited;

  for (Instruction &Inst : reverse(instructions(F))) {
    if (!Inst.getType()->isVectorTy())
      continue;
    if (!isa<TruncInst>(Inst) && !isa<FPTruncInst>(Inst))
      continue;
    if (Visited.count(&Inst))
      continue;
    if (tryInterleave(&Inst, Visited))
      MadeChange = true;
  }

  return MadeChange;
}