#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include <type_traits>
namespace llvm {
namespace SPIRV {
/// One tracker record: a MapVector from MachineFunction to the Register
/// recorded for that function, extended with a list of other entries this one
/// depends on and two single-bit classification flags.
class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
  using DepListTy = SmallVector<DTSortableEntry *, 2>;

  /// Classification bits. Initialized in the constructor rather than with
  /// default member initializers (bit-field initializers require C++20).
  struct FlagsTy {
    unsigned IsFunc : 1;
    unsigned IsGV : 1;
    FlagsTy() : IsFunc(0), IsGV(0) {}
  };

  /// Entries this entry depends on (non-owning pointers).
  DepListTy Deps;
  FlagsTy Flags;

public:
  /// \returns true if the "function" flag has been set on this entry.
  bool getIsFunc() const { return Flags.IsFunc != 0; }

  /// \returns true if the "global variable" flag has been set on this entry.
  bool getIsGV() const { return Flags.IsGV != 0; }

  /// Sets or clears the "function" flag.
  void setIsFunc(bool V) { Flags.IsFunc = V ? 1u : 0u; }

  /// Sets or clears the "global variable" flag.
  void setIsGV(bool V) { Flags.IsGV = V ? 1u : 0u; }

  /// \returns the list of entries registered through addDep(), in the order
  /// they were added.
  const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }

  /// Appends \p E to the dependency list. No de-duplication is performed.
  void addDep(DTSortableEntry *E) { Deps.emplace_back(E); }
};
}
template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
public:
using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
private:
StorageTy Storage;
public:
void add(KeyTy V, const MachineFunction *MF, Register R) {
if (find(V, MF).isValid())
return;
Storage[V][MF] = R;
if (std::is_same<Function,
typename std::remove_const<
typename std::remove_pointer<KeyTy>::type>::type>() ||
std::is_same<Argument,
typename std::remove_const<
typename std::remove_pointer<KeyTy>::type>::type>())
Storage[V].setIsFunc(true);
if (std::is_same<GlobalVariable,
typename std::remove_const<
typename std::remove_pointer<KeyTy>::type>::type>())
Storage[V].setIsGV(true);
}
Register find(KeyTy V, const MachineFunction *MF) const {
auto iter = Storage.find(V);
if (iter != Storage.end()) {
auto Map = iter->second;
auto iter2 = Map.find(MF);
if (iter2 != Map.end())
return iter2->second;
}
return Register();
}
const StorageTy &getAllUses() const { return Storage; }
private:
StorageTy &getAllUses() { return Storage; }
friend class SPIRVGeneralDuplicatesTracker;
};
/// Convenience wrapper: a duplicates tracker whose keys are `const T *`.
/// Adds nothing beyond what SPIRVDuplicatesTrackerBase provides.
template <class T>
class SPIRVDuplicatesTracker
    : public SPIRVDuplicatesTrackerBase<const T *> {
};
/// Aggregates one duplicates tracker per supported key kind (Type, Constant,
/// GlobalVariable, Function, Argument) and dispatches add()/find() to the
/// matching tracker by overload.
class SPIRVGeneralDuplicatesTracker {
  SPIRVDuplicatesTracker<Type> TT;
  SPIRVDuplicatesTracker<Constant> CT;
  SPIRVDuplicatesTracker<GlobalVariable> GT;
  SPIRVDuplicatesTracker<Function> FT;
  SPIRVDuplicatesTracker<Argument> AT;

  // Maps a machine operand to the tracker entry associated with it.
  // NOTE(review): presumably populated by prebuildReg2Entry and consumed by
  // buildDepsGraph; both are defined out of line - confirm there.
  using SPIRVReg2EntryTy =
      MapVector<MachineOperand *, SPIRV::DTSortableEntry *>;

  template <typename T>
  void prebuildReg2Entry(SPIRVDuplicatesTracker<T> &DT,
                         SPIRVReg2EntryTy &Reg2Entry);

public:
  /// Fills \p Graph with pointers to tracker entries; defined out of line.
  void buildDepsGraph(std::vector<SPIRV::DTSortableEntry *> &Graph,
                      MachineModuleInfo *MMI);

  /// Records register \p R for type \p T in function \p MF.
  void add(const Type *T, const MachineFunction *MF, Register R) {
    TT.add(T, MF, R);
  }

  /// Records register \p R for constant \p C in function \p MF.
  void add(const Constant *C, const MachineFunction *MF, Register R) {
    CT.add(C, MF, R);
  }

  /// Records register \p R for global variable \p GV in function \p MF.
  void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
    GT.add(GV, MF, R);
  }

  /// Records register \p R for function \p F in function \p MF.
  void add(const Function *F, const MachineFunction *MF, Register R) {
    FT.add(F, MF, R);
  }

  /// Records register \p R for argument \p Arg in function \p MF.
  void add(const Argument *Arg, const MachineFunction *MF, Register R) {
    AT.add(Arg, MF, R);
  }

  // The find() overloads forward the const pointer directly: the underlying
  // trackers are keyed by `const T *`, so no const_cast is needed. They are
  // const because lookups do not modify the trackers.

  /// \returns the register recorded for \p T in \p MF, or an invalid
  /// (default-constructed) Register if none was recorded.
  Register find(const Type *T, const MachineFunction *MF) const {
    return TT.find(T, MF);
  }

  /// \returns the register recorded for \p C in \p MF, or an invalid
  /// (default-constructed) Register if none was recorded.
  Register find(const Constant *C, const MachineFunction *MF) const {
    return CT.find(C, MF);
  }

  /// \returns the register recorded for \p GV in \p MF, or an invalid
  /// (default-constructed) Register if none was recorded.
  Register find(const GlobalVariable *GV, const MachineFunction *MF) const {
    return GT.find(GV, MF);
  }

  /// \returns the register recorded for \p F in \p MF, or an invalid
  /// (default-constructed) Register if none was recorded.
  Register find(const Function *F, const MachineFunction *MF) const {
    return FT.find(F, MF);
  }

  /// \returns the register recorded for \p Arg in \p MF, or an invalid
  /// (default-constructed) Register if none was recorded.
  Register find(const Argument *Arg, const MachineFunction *MF) const {
    return AT.find(Arg, MF);
  }

  /// \returns read-only access to the Type tracker.
  const SPIRVDuplicatesTracker<Type> *getTypes() const { return &TT; }
};
} #endif