#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include <array>
#include <map>
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-amx-type"
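// Return true if the instruction is one of the AMX cast intrinsics,
// i.e. x86_cast_vector_to_tile or x86_cast_tile_to_vector.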
static bool isAMXCast(Instruction *II) {
return match(II,
m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
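// Return true if V is an AMX intrinsic other than the casts, i.e. an
// intrinsic that produces or consumes an x86_amx tile value.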
static bool isAMXIntrinsic(Value *I) {
auto *II = dyn_cast<IntrinsicInst>(I);
if (!II)
return false;
if (isAMXCast(II))
return false;
if (II->getType()->isX86_AMXTy())
return true;
for (Value *V : II->args()) {
if (V->getType()->isX86_AMXTy())
return true;
}
return false;
}
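// Create an alloca in the entry block, aligned to the preferred alignment
// of the x86_amx type, to serve as a stack slot for tile data.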
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
Type *Ty) {
Function &F = *BB->getParent();
Module *M = BB->getModule();
const DataLayout &DL = M->getDataLayout();
LLVMContext &Ctx = Builder.getContext();
auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
unsigned AllocaAS = DL.getAllocaAddrSpace();
AllocaInst *AllocaRes =
new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
AllocaRes->setAlignment(AllocaAlignment);
return AllocaRes;
}
static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
for (Instruction &I : F.getEntryBlock())
if (!isa<AllocaInst>(&I))
return &I;
llvm_unreachable("No terminator in the entry block!");
}
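// Return the (row, col) shape of the tile operand OpNo of an AMX intrinsic.
// For the tile load/store intrinsics the shape is simply operands 0 and 1.
// For the tdp* intrinsics the shape depends on which tile operand is asked
// for; note that for operand 5 the row is operand 2 divided by 4.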
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
IRBuilder<> Builder(II);
Value *Row = nullptr, *Col = nullptr;
switch (II->getIntrinsicID()) {
default:
llvm_unreachable("Expect amx intrinsics");
case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tileloaddt164_internal:
case Intrinsic::x86_tilestored64_internal: {
Row = II->getArgOperand(0);
Col = II->getArgOperand(1);
break;
}
case Intrinsic::x86_tdpbssd_internal:
case Intrinsic::x86_tdpbsud_internal:
case Intrinsic::x86_tdpbusd_internal:
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tdpbf16ps_internal: {
switch (OpNo) {
case 3:
Row = II->getArgOperand(0);
Col = II->getArgOperand(1);
break;
case 4:
Row = II->getArgOperand(0);
Col = II->getArgOperand(2);
break;
case 5:
if (isa<ConstantInt>(II->getArgOperand(2)))
Row = Builder.getInt16(
(cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
else if (isa<Instruction>(II->getArgOperand(2))) {
Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
} else {
IRBuilder<> NewBuilder(
getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
}
Col = II->getArgOperand(1);
break;
}
break;
}
}
return std::make_pair(Row, Col);
}
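// Infer the shape of a tile-related PHI by following its first use through
// AMX casts and further PHIs until an AMX intrinsic that fixes the shape is
// reached. Returns {nullptr, nullptr} if no such intrinsic is found.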
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
Use &U = *(Phi->use_begin());
unsigned OpNo = U.getOperandNo();
User *V = U.getUser();
while (V) {
if (isAMXCast(dyn_cast<Instruction>(V))) {
if (V->use_empty())
break;
Use &U = *(V->use_begin());
OpNo = U.getOperandNo();
V = U.getUser();
} else if (isAMXIntrinsic(V)) {
return getShape(cast<IntrinsicInst>(V), OpNo);
} else if (isa<PHINode>(V)) {
if (V->use_empty())
break;
Use &U = *(V->use_begin());
V = U.getUser();
} else {
break;
}
}
return std::make_pair(nullptr, nullptr);
}
namespace {
class X86LowerAMXType {
Function &Func;
std::map<Value *, Value *> Col2Row;
public:
X86LowerAMXType(Function &F) : Func(F) {}
bool visit();
void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
bool transformBitcast(BitCastInst *Bitcast);
};
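// Fold a load feeding a bitcast to x86_amx into a tile load, e.g.:
//   %src = load <256 x i32>, <256 x i32>* %addr, align 64
//   %2 = bitcast <256 x i32> %src to x86_amx
// -->
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %addr, i64 64)
// The shape is taken from the AMX intrinsic that uses the bitcast.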
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
Value *Row = nullptr, *Col = nullptr;
Use &U = *(Bitcast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = cast<IntrinsicInst>(U.getUser());
std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(Bitcast);
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
Bitcast->replaceAllUsesWith(NewInst);
}
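// Fold a bitcast from x86_amx feeding a store into a tile store, e.g.:
//   %src = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %2 = bitcast x86_amx %src to <256 x i32>
//   store <256 x i32> %2, <256 x i32>* %addr, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %addr,
//                                             i64 64, x86_amx %src)
// If the bitcast has other users, reload the vector from %addr for them.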
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
Value *Tile = Bitcast->getOperand(0);
auto *II = cast<IntrinsicInst>(Tile);
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
IRBuilder<> Builder(ST);
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
if (Bitcast->hasOneUse())
return;
Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
Bitcast->replaceAllUsesWith(Vec);
}
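// Lower a bitcast that cannot be folded into an adjacent load or store:
// spill through an entry-block stack slot, i.e. store the source value and
// reload it in the other representation (vector store + tile load, or tile
// store + vector load).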
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
IRBuilder<> Builder(Bitcast);
AllocaInst *AllocaAddr;
Value *I8Ptr, *Stride;
auto *Src = Bitcast->getOperand(0);
auto Prepare = [&](Type *MemTy) {
AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
Stride = Builder.getInt64(64);
};
if (Bitcast->getType()->isX86_AMXTy()) {
Use &U = *(Bitcast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
if (!II)
return false;
Prepare(Bitcast->getOperand(0)->getType());
Builder.CreateStore(Src, AllocaAddr);
Value *Row = nullptr, *Col = nullptr;
std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst = Builder.CreateIntrinsic(
Intrinsic::x86_tileloadd64_internal, None, Args);
Bitcast->replaceAllUsesWith(NewInst);
} else {
auto *II = dyn_cast<IntrinsicInst>(Src);
if (!II)
return false;
Prepare(Bitcast->getType());
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
Bitcast->replaceAllUsesWith(NewInst);
}
return true;
}
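// Walk the function and rewrite every bitcast between <256 x i32> and
// x86_amx into tile load/store intrinsics, collecting the now-dead
// instructions and erasing them at the end.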
bool X86LowerAMXType::visit() {
SmallVector<Instruction *, 8> DeadInsts;
Col2Row.clear();
for (BasicBlock *BB : post_order(&Func)) {
for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
if (!Bitcast)
continue;
Value *Src = Bitcast->getOperand(0);
if (Bitcast->getType()->isX86_AMXTy()) {
if (Bitcast->user_empty()) {
DeadInsts.push_back(Bitcast);
continue;
}
LoadInst *LD = dyn_cast<LoadInst>(Src);
if (!LD) {
if (transformBitcast(Bitcast))
DeadInsts.push_back(Bitcast);
continue;
}
combineLoadBitcast(LD, Bitcast);
DeadInsts.push_back(Bitcast);
if (LD->hasOneUse())
DeadInsts.push_back(LD);
} else if (Src->getType()->isX86_AMXTy()) {
if (Bitcast->user_empty()) {
DeadInsts.push_back(Bitcast);
continue;
}
StoreInst *ST = nullptr;
for (Use &U : Bitcast->uses()) {
ST = dyn_cast<StoreInst>(U.getUser());
if (ST)
break;
}
if (!ST) {
if (transformBitcast(Bitcast))
DeadInsts.push_back(Bitcast);
continue;
}
combineBitcastStore(Bitcast, ST);
DeadInsts.push_back(ST);
DeadInsts.push_back(Bitcast);
}
}
}
bool C = !DeadInsts.empty();
for (auto *Inst : DeadInsts)
Inst->eraseFromParent();
return C;
}
}
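// Allocate a <256 x i32> stack slot in the entry block and return it as an
// i8* suitable for the tile load/store intrinsics.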
static Value *getAllocaPos(BasicBlock *BB) {
Module *M = BB->getModule();
Function *F = BB->getParent();
IRBuilder<> Builder(&F->getEntryBlock().front());
const DataLayout &DL = M->getDataLayout();
unsigned AllocaAS = DL.getAllocaAddrSpace();
Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
AllocaInst *AllocaRes =
new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
BasicBlock::iterator Iter = AllocaRes->getIterator();
++Iter;
Builder.SetInsertPoint(&*Iter);
Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
return I8Ptr;
}
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
assert(TileDef->getType()->isX86_AMXTy() && "Not a tile definition!");
auto *II = cast<IntrinsicInst>(TileDef);
assert(II && "Not tile intrinsic!");
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
BasicBlock *BB = TileDef->getParent();
BasicBlock::iterator Iter = TileDef->getIterator();
IRBuilder<> Builder(BB, ++Iter);
Value *Stride = Builder.getInt64(64);
std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
Instruction *TileStore =
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
return TileStore;
}
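// Replace the tile use U with a tile load from Ptr. The shape is taken from
// the defining intrinsic of the used value, or, for a PHI, from the
// intrinsic defining its first incoming value.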
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
Value *V = U.get();
assert(V->getType()->isX86_AMXTy() && "Not a tile definition!");
IntrinsicInst *II = nullptr;
if (IsPHI) {
Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
II = cast<IntrinsicInst>(PhiOp);
} else {
II = cast<IntrinsicInst>(V);
}
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
Instruction *UserI = dyn_cast<Instruction>(U.getUser());
IRBuilder<> Builder(UserI);
Value *Stride = Builder.getInt64(64);
std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
Value *TileLoad =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
UserI->replaceUsesOfWith(V, TileLoad);
}
static bool isIncomingOfPHI(Instruction *I) {
for (Use &U : I->uses()) {
User *V = U.getUser();
if (isa<PHINode>(V))
return true;
}
return false;
}
namespace {
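// At -O0 tile data is made "volatile": every x86_amx definition is stored
// to a stack slot right after it is defined, and every use (other than that
// store) reloads the tile from the slot. x86_amx PHI nodes are likewise
// rewritten to go through memory.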
class X86VolatileTileData {
Function &F;
public:
X86VolatileTileData(Function &Func) : F(Func) {}
Value *updatePhiIncomings(BasicBlock *BB,
SmallVector<Instruction *, 2> &Incomings);
void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
bool volatileTileData();
void volatileTilePHI(PHINode *Inst);
void volatileTileNonPHI(Instruction *I);
};
Value *X86VolatileTileData::updatePhiIncomings(
BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
Value *I8Ptr = getAllocaPos(BB);
for (auto *I : Incomings) {
User *Store = createTileStore(I, I8Ptr);
for (Use &U : I->uses()) {
User *V = U.getUser();
if (isa<PHINode>(V) || V == Store)
continue;
replaceWithTileLoad(U, I8Ptr);
}
}
return I8Ptr;
}
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
Value *StorePtr) {
for (Use &U : PHI->uses())
replaceWithTileLoad(U, StorePtr, true);
PHI->eraseFromParent();
}
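// Lower an x86_amx PHI: store each incoming tile to a shared stack slot,
// replace every use of the PHI with a tile load from that slot, and erase
// the PHI.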
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
BasicBlock *BB = PHI->getParent();
SmallVector<Instruction *, 2> Incomings;
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
Value *Op = PHI->getIncomingValue(I);
Instruction *Inst = dyn_cast<Instruction>(Op);
assert(Inst && "We shouldn't fold AMX instrution!");
Incomings.push_back(Inst);
}
Value *StorePtr = updatePhiIncomings(BB, Incomings);
replacePhiDefWithLoad(PHI, StorePtr);
}
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
BasicBlock *BB = I->getParent();
Value *I8Ptr = getAllocaPos(BB);
User *Store = createTileStore(I, I8Ptr);
for (Use &U : I->uses()) {
User *V = U.getUser();
assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
if (V != Store)
replaceWithTileLoad(U, I8Ptr);
}
}
bool X86VolatileTileData::volatileTileData() {
bool Changed = false;
for (BasicBlock &BB : F) {
SmallVector<Instruction *, 2> PHIInsts;
SmallVector<Instruction *, 8> AMXDefInsts;
for (Instruction &I : BB) {
if (!I.getType()->isX86_AMXTy())
continue;
if (isa<PHINode>(&I))
PHIInsts.push_back(&I);
else
AMXDefInsts.push_back(&I);
}
for (Instruction *I : AMXDefInsts) {
if (isIncomingOfPHI(I))
continue;
volatileTileNonPHI(I);
Changed = true;
}
for (Instruction *I : PHIInsts) {
volatileTilePHI(cast<PHINode>(I));
Changed = true;
}
}
return Changed;
}
}
namespace {
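// Rewrites the x86_cast_vector_to_tile / x86_cast_tile_to_vector
// intrinsics: cancels complementary cast pairs, folds casts into adjacent
// loads and stores, hoists casts across PHI nodes, and lowers any remaining
// casts through a stack slot.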
class X86LowerAMXCast {
Function &Func;
public:
X86LowerAMXCast(Function &F) : Func(F) {}
void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
bool combineAMXcast(TargetLibraryInfo *TLI);
bool transformAMXCast(IntrinsicInst *AMXCast);
bool transformAllAMXCast();
bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
SmallSetVector<Instruction *, 16> &DeadInst);
};
static bool DCEInstruction(Instruction *I,
SmallSetVector<Instruction *, 16> &WorkList,
const TargetLibraryInfo *TLI) {
if (isInstructionTriviallyDead(I, TLI)) {
salvageDebugInfo(*I);
salvageKnowledge(I);
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
Value *OpV = I->getOperand(i);
I->setOperand(i, nullptr);
if (!OpV->use_empty() || I == OpV)
continue;
if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
if (isInstructionTriviallyDead(OpI, TLI)) {
WorkList.insert(OpI);
}
}
}
I->eraseFromParent();
return true;
}
return false;
}
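// Given an AMX cast CI whose source operand is the PHI node PN, rebuild the
// whole PHI web in the cast's destination type: create parallel PHI nodes
// of DestTy fed by the operands of the incoming casts (materializing a
// tilezero plus cast for constant zero/undef incomings), then replace every
// cast user of the old PHIs with the new PHIs. Returns false if any
// incoming value or user does not fit this pattern.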
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
IntrinsicInst *CI, PHINode *PN,
SmallSetVector<Instruction *, 16> &DeadInst) {
IRBuilder<> Builder(CI);
Value *Src = CI->getOperand(0);
Type *SrcTy = Src->getType();
Type *DestTy = CI->getType();
SmallVector<PHINode *, 4> PhiWorklist;
SmallSetVector<PHINode *, 4> OldPhiNodes;
PhiWorklist.push_back(PN);
OldPhiNodes.insert(PN);
while (!PhiWorklist.empty()) {
auto *OldPN = PhiWorklist.pop_back_val();
for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
Value *IncValue = OldPN->getIncomingValue(I);
if (isa<Constant>(IncValue)) {
auto *IncConst = dyn_cast<Constant>(IncValue);
if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
return false;
Value *Row = nullptr, *Col = nullptr;
std::tie(Row, Col) = getShape(OldPN);
if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
return false;
auto *Block = OldPN->getIncomingBlock(I);
BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
Instruction *NewInst = Builder.CreateIntrinsic(
Intrinsic::x86_tilezero_internal, None, {Row, Col});
NewInst->moveBefore(&*Iter);
NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
{IncValue->getType()}, {NewInst});
NewInst->moveBefore(&*Iter);
OldPN->setIncomingValue(I, NewInst);
IncValue = NewInst;
}
if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
if (OldPhiNodes.insert(PNode))
PhiWorklist.push_back(PNode);
continue;
}
Instruction *ACI = dyn_cast<Instruction>(IncValue);
if (ACI && isAMXCast(ACI)) {
Type *TyA = ACI->getOperand(0)->getType();
Type *TyB = ACI->getType();
if (TyA != DestTy || TyB != SrcTy)
return false;
continue;
}
return false;
}
}
for (auto *OldPN : OldPhiNodes) {
for (User *V : OldPN->users()) {
Instruction *ACI = dyn_cast<Instruction>(V);
if (ACI && isAMXCast(ACI)) {
Type *TyB = ACI->getOperand(0)->getType();
Type *TyA = ACI->getType();
if (TyA != DestTy || TyB != SrcTy)
return false;
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
if (OldPhiNodes.count(PHI) == 0)
return false;
} else
return false;
}
}
SmallDenseMap<PHINode *, PHINode *> NewPNodes;
for (auto *OldPN : OldPhiNodes) {
Builder.SetInsertPoint(OldPN);
PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
NewPNodes[OldPN] = NewPN;
}
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
Value *V = OldPN->getOperand(j);
Value *NewV = nullptr;
Instruction *ACI = dyn_cast<Instruction>(V);
if (ACI && isAMXCast(ACI))
NewV = ACI->getOperand(0);
else if (auto *PrevPN = dyn_cast<PHINode>(V))
NewV = NewPNodes[PrevPN];
assert(NewV);
NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
}
}
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
for (User *V : make_early_inc_range(OldPN->users())) {
Instruction *ACI = dyn_cast<Instruction>(V);
if (ACI && isAMXCast(ACI)) {
Type *TyB = ACI->getOperand(0)->getType();
Type *TyA = ACI->getType();
assert(TyA == DestTy && TyB == SrcTy);
(void)TyA;
(void)TyB;
ACI->replaceAllUsesWith(NewPN);
DeadInst.insert(ACI);
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
assert(OldPhiNodes.contains(PHI));
(void)PHI;
} else
llvm_unreachable("all uses should be handled");
}
}
return true;
}
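// Fold a tile_to_vector cast feeding a store into a tile store, e.g.:
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %t)
//   store <256 x i32> %v, <256 x i32>* %p, align 64
// -->
//   call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                             i64 64, x86_amx %t)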
void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
Value *Tile = Cast->getOperand(0);
if (!isAMXIntrinsic(Tile))
return;
auto *II = cast<IntrinsicInst>(Tile);
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
IRBuilder<> Builder(ST);
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
}
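// Fold a load feeding a vector_to_tile cast into a tile load, e.g.:
//   %v = load <256 x i32>, <256 x i32>* %p, align 64
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %v)
// -->
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    i8* %p, i64 64)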
void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
Value *Row = nullptr, *Col = nullptr;
Use &U = *(Cast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = cast<IntrinsicInst>(U.getUser());
if (!isAMXIntrinsic(II))
return;
std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(LD);
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
Cast->replaceAllUsesWith(NewInst);
}
bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
bool Change = false;
for (auto *Cast : Casts) {
auto *II = cast<IntrinsicInst>(Cast);
if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
SmallVector<Instruction *, 2> DeadStores;
for (User *U : Cast->users()) {
StoreInst *Store = dyn_cast<StoreInst>(U);
if (!Store)
continue;
combineCastStore(cast<IntrinsicInst>(Cast), Store);
DeadStores.push_back(Store);
Change = true;
}
for (auto *Store : DeadStores)
Store->eraseFromParent();
} else {
auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
if (!Load || !Load->hasOneUse())
continue;
combineLoadCast(cast<IntrinsicInst>(Cast), Load);
Cast->setOperand(0, nullptr);
Load->eraseFromParent();
}
}
return Change;
}
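// Cancel vector->tile->vector and tile->vector->tile round trips, erase
// casts that become dead, fold the surviving casts into neighbouring loads
// and stores, hoist casts across PHI nodes, and finally run dead code
// elimination on everything that was made dead along the way.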
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
bool Change = false;
SmallVector<Instruction *, 8> Vec2TileInsts;
SmallVector<Instruction *, 8> Tile2VecInsts;
SmallVector<Instruction *, 8> PhiCastWorkList;
SmallSetVector<Instruction *, 16> DeadInst;
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
Value *Vec;
if (match(&I,
m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
Vec2TileInsts.push_back(&I);
else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
m_Value(Vec))))
Tile2VecInsts.push_back(&I);
}
}
auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
for (auto *Inst : Insts) {
for (User *U : Inst->users()) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
if (!II || II->getIntrinsicID() != IID)
continue;
II->replaceAllUsesWith(Inst->getOperand(0));
Change = true;
}
}
};
Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
SmallVector<Instruction *, 8> LiveCasts;
auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
for (auto *Inst : Insts) {
if (Inst->use_empty()) {
Inst->eraseFromParent();
Change = true;
} else {
LiveCasts.push_back(Inst);
}
}
};
EraseInst(Vec2TileInsts);
EraseInst(Tile2VecInsts);
Change |= combineLdSt(LiveCasts);
EraseInst(LiveCasts);
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
if (isAMXCast(&I)) {
if (isa<PHINode>(I.getOperand(0)))
PhiCastWorkList.push_back(&I);
}
}
}
for (auto *I : PhiCastWorkList) {
if (DeadInst.contains(I))
continue;
PHINode *PN = cast<PHINode>(I->getOperand(0));
if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
DeadInst.insert(PN);
Change = true;
}
}
while (!DeadInst.empty()) {
Instruction *I = DeadInst.pop_back_val();
Change |= DCEInstruction(I, DeadInst, TLI);
}
return Change;
}
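// Lower a cast that survived combining by spilling through an entry-block
// stack slot: store the vector and tile-load it back, or tile-store the
// tile and load it back as a vector. Note the stride here is the column
// count sign-extended to i64 rather than the constant 64.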
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
IRBuilder<> Builder(AMXCast);
AllocaInst *AllocaAddr;
Value *I8Ptr, *Stride;
auto *Src = AMXCast->getOperand(0);
auto Prepare = [&](Type *MemTy) {
AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
Stride = Builder.getInt64(64);
};
if (AMXCast->getType()->isX86_AMXTy()) {
if (AMXCast->use_empty()) {
AMXCast->eraseFromParent();
return true;
}
Use &U = *(AMXCast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
if (!II)
return false;
Prepare(AMXCast->getOperand(0)->getType());
Builder.CreateStore(Src, AllocaAddr);
Value *Row = nullptr, *Col = nullptr;
std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {
Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
Value *NewInst = Builder.CreateIntrinsic(
Intrinsic::x86_tileloadd64_internal, None, Args);
AMXCast->replaceAllUsesWith(NewInst);
AMXCast->eraseFromParent();
} else {
auto *II = dyn_cast<IntrinsicInst>(Src);
if (!II)
return false;
Prepare(AMXCast->getType());
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
std::array<Value *, 5> Args = {
Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
AMXCast->replaceAllUsesWith(NewInst);
AMXCast->eraseFromParent();
}
return true;
}
bool X86LowerAMXCast::transformAllAMXCast() {
bool Change = false;
SmallVector<Instruction *, 8> WorkLists;
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
if (isAMXCast(&I))
WorkLists.push_back(&I);
}
}
for (auto *Inst : WorkLists) {
Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
}
return Change;
}
}
namespace {
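// Legacy pass entry point: combine and lower the AMX cast intrinsics, lower
// amx/vector bitcasts, and, at -O0 for functions without the optnone
// attribute, make tile data volatile.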
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
static char ID;
X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override {
bool C = false;
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
TargetLibraryInfo *TLI =
&getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
X86LowerAMXCast LAC(F);
C |= LAC.combineAMXcast(TLI);
C |= LAC.transformAllAMXCast();
X86LowerAMXType LAT(F);
C |= LAT.visit();
if (TM->getOptLevel() == CodeGenOpt::None) {
if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
X86VolatileTileData VTD(F);
C = VTD.volatileTileData() || C;
}
}
return C;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
}
static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
FunctionPass *llvm::createX86LowerAMXTypePass() {
return new X86LowerAMXTypeLegacyPass();
}