#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;
namespace llvm {
void initializeGenericToNVVMPass(PassRegistry &);
}
namespace {
/// Legacy module pass that clones eligible generic-address-space global
/// variables into the global address space and redirects their uses to the
/// clones (see runOnModule for the eligibility rules).
class GenericToNVVM : public ModulePass {
public:
  /// Pass identifier handed to the ModulePass base constructor.
  static char ID;

  GenericToNVVM() : ModulePass(ID) {}

  /// Performs the rewrite; returns true when the module was changed.
  bool runOnModule(Module &M) override;

  /// Intentionally empty: this pass requests no analyses.
  void getAnalysisUsage(AnalysisUsage &AU) const override {}

private:
  /// Returns a replacement for constant \p C usable inside \p F, emitting any
  /// required instructions through \p Builder. Results are memoized.
  Value *remapConstant(Module *M, Function *F, Constant *C,
                       IRBuilder<> &Builder);

  /// Remapping for ConstantAggregate constants (vectors and aggregates).
  Value *remapConstantVectorOrConstantAggregate(Module *M, Function *F,
                                                Constant *C,
                                                IRBuilder<> &Builder);

  /// Remapping for ConstantExpr constants, rebuilt from remapped operands.
  Value *remapConstantExpr(Module *M, Function *F, ConstantExpr *C,
                           IRBuilder<> &Builder);

  /// Original generic-space global -> its clone in the global address space.
  using GVMapTy = ValueMap<GlobalVariable *, GlobalVariable *>;
  /// Constant -> value replacing it in the function currently being rewritten.
  using ConstantToValueMapTy = ValueMap<Constant *, Value *>;

  GVMapTy GVMap;
  ConstantToValueMapTy ConstantToValueMap;
};
} // end anonymous namespace
// Out-of-line definition of the static pass identifier declared in the class.
char GenericToNVVM::ID = 0;

// Factory returning a freshly allocated instance of the pass to the caller.
ModulePass *llvm::createGenericToNVVMPass() {
  GenericToNVVM *Pass = new GenericToNVVM();
  return Pass;
}

// Legacy pass-manager registration; "generic-to-nvvm" is the pass argument.
INITIALIZE_PASS(GenericToNVVM, "generic-to-nvvm",
                "Ensure that the global variables are in the global address space",
                /*CFGOnly=*/false, /*IsAnalysis=*/false)
/// Moves every eligible generic-address-space global into the global address
/// space and rewrites all of its uses.
///
/// A global is eligible when it lives in ADDRESS_SPACE_GENERIC, is not a
/// texture, surface or sampler, and its name does not start with "llvm.".
///
/// \returns true if at least one global was rewritten, false otherwise.
bool GenericToNVVM::runOnModule(Module &M) {
  // Phase 1: create a clone of each eligible global in the global address
  // space. The clone keeps the value type, constness, linkage, initializer,
  // thread-local mode, attributes and metadata of the original. It is inserted
  // before the original and left unnamed; it receives the original's name in
  // phase 3.
  for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) {
    if (GV.getType()->getAddressSpace() == llvm::ADDRESS_SPACE_GENERIC &&
        !llvm::isTexture(GV) && !llvm::isSurface(GV) && !llvm::isSampler(GV) &&
        !GV.getName().startswith("llvm.")) {
      GlobalVariable *NewGV = new GlobalVariable(
          M, GV.getValueType(), GV.isConstant(), GV.getLinkage(),
          GV.hasInitializer() ? GV.getInitializer() : nullptr, "", &GV,
          GV.getThreadLocalMode(), llvm::ADDRESS_SPACE_GLOBAL);
      NewGV->copyAttributesFrom(&GV);
      NewGV->copyMetadata(&GV, /*Offset=*/0);
      GVMap[&GV] = NewGV;
    }
  }

  // Nothing to do if every global already has a specific address space.
  if (GVMap.empty())
    return false;

  // Phase 2: in every function body, replace constant operands that refer
  // (directly or through aggregates / constant expressions) to an original
  // global with values built from its clone. Anything the builder emits goes
  // at the top of the entry block, ahead of all existing uses in the function.
  for (Function &F : M) {
    if (F.isDeclaration())
      continue;
    IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
    for (BasicBlock &BB : F) {
      for (Instruction &II : BB) {
        for (unsigned i = 0, e = II.getNumOperands(); i < e; ++i) {
          if (auto *C = dyn_cast<Constant>(II.getOperand(i)))
            II.setOperand(i, remapConstant(&M, &F, C, Builder));
        }
      }
    }
    // Cached remappings may be instructions emitted into F; they must not be
    // reused in another function.
    ConstantToValueMap.clear();
  }

  // Phase 3: redirect the remaining uses of each original global (expected to
  // be in global initializers, which cannot reference instructions) to a
  // constant pointer cast of its clone, then erase the original.
  for (GVMapTy::iterator I = GVMap.begin(), E = GVMap.end(); I != E;) {
    GlobalVariable *GV = I->first;
    GlobalVariable *NewGV = I->second;

    // Drop GV from GVMap before RAUW so the map does not observe the
    // replacement; advance first so I is never used after erase().
    auto Next = std::next(I);
    GVMap.erase(I);
    I = Next;

    Constant *BitCastNewGV = ConstantExpr::getPointerCast(NewGV, GV->getType());
    GV->replaceAllUsesWith(BitCastNewGV);
    // GV is erased before renaming so its name is free for NewGV to take.
    std::string Name = std::string(GV->getName());
    GV->eraseFromParent();
    NewGV->setName(Name);
  }
  assert(GVMap.empty() && "Expected it to be empty by now");

  return true;
}
/// Returns a value usable inside function \p F in place of constant \p C.
///
/// If \p C is, or refers through aggregates / constant expressions to, a
/// global recorded in GVMap, the result is built from that global's clone,
/// cast back to the generic address space, with any needed code emitted via
/// \p Builder. Otherwise \p C itself is returned. Results are memoized in
/// ConstantToValueMap, which runOnModule clears after each function.
Value *GenericToNVVM::remapConstant(Module *M, Function *F, Constant *C,
                                    IRBuilder<> &Builder) {
  // Reuse an earlier remapping of the same constant.
  ConstantToValueMapTy::iterator CTII = ConstantToValueMap.find(C);
  if (CTII != ConstantToValueMap.end())
    return CTII->second;

  Value *NewValue = C;
  if (auto *OrigGV = dyn_cast<GlobalVariable>(C)) {
    GVMapTy::iterator I = GVMap.find(OrigGV);
    if (I != GVMap.end()) {
      // Cast the clone back to exactly C's type, so the result can replace C
      // as an operand without changing the user's operand type.
      NewValue = Builder.CreateAddrSpaceCast(I->second, C->getType());
    }
  } else if (isa<ConstantAggregate>(C)) {
    NewValue = remapConstantVectorOrConstantAggregate(M, F, C, Builder);
  } else if (auto *CE = dyn_cast<ConstantExpr>(C)) {
    NewValue = remapConstantExpr(M, F, CE, Builder);
  }

  ConstantToValueMap[C] = NewValue;
  return NewValue;
}
/// Rebuilds a ConstantAggregate whose elements may need remapping.
///
/// Every element is passed through remapConstant. If no element changed,
/// \p C is returned as-is. Otherwise an equivalent value is assembled via
/// \p Builder, starting from poison and filling element by element:
/// insertelement for ConstantVector, insertvalue for everything else.
Value *GenericToNVVM::remapConstantVectorOrConstantAggregate(
    Module *M, Function *F, Constant *C, IRBuilder<> &Builder) {
  const unsigned NumElts = C->getNumOperands();
  SmallVector<Value *, 4> Elts;
  bool AnyChanged = false;
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    Value *Orig = C->getOperand(Idx);
    Value *Mapped = remapConstant(M, F, cast<Constant>(Orig), Builder);
    if (Mapped != Orig)
      AnyChanged = true;
    Elts.push_back(Mapped);
  }

  // No element referred to a moved global: keep the original constant.
  if (!AnyChanged)
    return C;

  const bool IsVector = isa<ConstantVector>(C);
  Value *Result = PoisonValue::get(C->getType());
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    if (IsVector) {
      Value *Lane = ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx);
      Result = Builder.CreateInsertElement(Result, Elts[Idx], Lane);
    } else {
      Result = Builder.CreateInsertValue(Result, Elts[Idx], makeArrayRef(Idx));
    }
  }
  return Result;
}
/// Re-materializes constant expression \p C with remapped operands.
///
/// Each operand is passed through remapConstant. If none changed, \p C is
/// returned unchanged; otherwise an equivalent value with the same opcode is
/// built through \p Builder from the remapped operands. Opcodes not handled
/// here are a fatal internal error.
Value *GenericToNVVM::remapConstantExpr(Module *M, Function *F, ConstantExpr *C,
                                        IRBuilder<> &Builder) {
  bool OperandChanged = false;
  SmallVector<Value *, 4> NewOperands;
  unsigned NumOperands = C->getNumOperands();
  for (unsigned i = 0; i < NumOperands; ++i) {
    Value *Operand = C->getOperand(i);
    Value *NewOperand = remapConstant(M, F, cast<Constant>(Operand), Builder);
    OperandChanged |= Operand != NewOperand;
    NewOperands.push_back(NewOperand);
  }
  // Nothing referred to a moved global: keep the original expression.
  if (!OperandChanged)
    return C;

  unsigned Opcode = C->getOpcode();
  switch (Opcode) {
  case Instruction::ICmp:
    return Builder.CreateICmp(CmpInst::Predicate(C->getPredicate()),
                              NewOperands[0], NewOperands[1]);
  case Instruction::FCmp:
    // Floating-point operands can still depend on a moved global (e.g. via
    // ptrtoint followed by sitofp), so rebuild the comparison rather than
    // treating this case as unreachable.
    return Builder.CreateFCmp(CmpInst::Predicate(C->getPredicate()),
                              NewOperands[0], NewOperands[1]);
  case Instruction::ExtractElement:
    return Builder.CreateExtractElement(NewOperands[0], NewOperands[1]);
  case Instruction::InsertElement:
    return Builder.CreateInsertElement(NewOperands[0], NewOperands[1],
                                       NewOperands[2]);
  case Instruction::ShuffleVector:
    // A shufflevector constant expression has only the two input vectors as
    // operands; the mask is stored separately. Reading NewOperands[2] would
    // index past the end of NewOperands.
    return Builder.CreateShuffleVector(NewOperands[0], NewOperands[1],
                                       C->getShuffleMask());
  case Instruction::GetElementPtr:
    // Operand 0 is the base pointer; the remaining operands (possibly none)
    // are the indices.
    return Builder.CreateGEP(cast<GEPOperator>(C)->getSourceElementType(),
                             NewOperands[0],
                             makeArrayRef(NewOperands).drop_front(), "",
                             cast<GEPOperator>(C)->isInBounds());
  case Instruction::Select:
    return Builder.CreateSelect(NewOperands[0], NewOperands[1], NewOperands[2]);
  default:
    if (Instruction::isBinaryOp(Opcode)) {
      return Builder.CreateBinOp(Instruction::BinaryOps(C->getOpcode()),
                                 NewOperands[0], NewOperands[1]);
    }
    if (Instruction::isCast(Opcode)) {
      return Builder.CreateCast(Instruction::CastOps(C->getOpcode()),
                                NewOperands[0], C->getType());
    }
    llvm_unreachable("GenericToNVVM encountered an unsupported ConstantExpr");
  }
}