#include "R600.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <tuple>
using namespace llvm;
static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
static StringRef GetSamplerResourceIDFunc =
"llvm.OpenCL.sampler.get.resource.id";
static StringRef ImageSizeArgMDType = "__llvm_image_size";
static StringRef ImageFormatArgMDType = "__llvm_image_format";
static StringRef KernelsMDNodeName = "opencl.kernels";
static StringRef KernelArgMDNodeNames[] = {
"kernel_arg_addr_space",
"kernel_arg_access_qual",
"kernel_arg_type",
"kernel_arg_base_type",
"kernel_arg_type_qual"};
static const unsigned NumKernelArgMDNodes = 5;
namespace {

/// A list of metadata operands. GetArgMD uses it for one argument's entries
/// across all kernel_arg_* nodes. KernelArgMD uses it for one node's entries
/// across all arguments.
using MDVector = SmallVector<Metadata *, 8>;

/// Operand lists for the kernel_arg_* nodes of a rebuilt kernel:
/// ArgVector[i] accumulates the operands of the node named
/// KernelArgMDNodeNames[i].
struct KernelArgMD {
  MDVector ArgVector[NumKernelArgMDNodes];
};

} // end anonymous namespace
/// Returns true if \p TypeString is one of the OpenCL image type names this
/// pass lowers. Only "image2d_t" and "image3d_t" are recognized.
static inline bool IsImageType(StringRef TypeString) {
  if (TypeString == "image2d_t")
    return true;
  return TypeString == "image3d_t";
}
/// Returns true if \p TypeString is the OpenCL sampler type name.
static inline bool IsSamplerType(StringRef TypeString) {
  const bool IsSampler = "sampler_t" == TypeString;
  return IsSampler;
}
/// Returns the kernel function described by one `opencl.kernels` entry, or
/// nullptr if \p Node does not have the expected shape.
///
/// A well-formed entry has 1 + NumKernelArgMDNodes operands. The first is the
/// function itself. The rest are one MDNode per name in KernelArgMDNodeNames,
/// in that order. Each of those nodes starts with its MDString name and then
/// holds one operand per function argument.
static Function *
GetFunctionFromMDNode(MDNode *Node) {
  if (!Node)
    return nullptr;

  size_t NumOps = Node->getNumOperands();
  if (NumOps != NumKernelArgMDNodes + 1)
    return nullptr;

  // Operand 0 may be null, e.g. if the function it referred to was erased.
  // The plain dyn_extract asserts on null.
  auto F = mdconst::dyn_extract_or_null<Function>(Node->getOperand(0));
  if (!F)
    return nullptr;

  // Each argument node holds its name plus one entry per function argument.
  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
    // A null or non-MDNode operand makes the entry malformed. Reject it
    // instead of dereferencing a null cast result.
    MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
    if (!ArgNode || ArgNode->getNumOperands() != ExpectNumArgNodeOps)
      return nullptr;

    MDString *StringNode = dyn_cast_or_null<MDString>(ArgNode->getOperand(0));
    if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
      return nullptr;
  }
  return F;
}
/// Returns the access-qualifier string (e.g. "read_only") recorded for
/// argument \p ArgIdx of the kernel described by \p KernelMDNode.
///
/// Operand 2 of the kernel node is the kernel_arg_access_qual node. Its
/// operand 0 is the node's name, so argument entries start at operand 1.
static StringRef
AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
  const unsigned AccessQualNodeIdx = 2;
  MDNode *QualNode = cast<MDNode>(KernelMDNode->getOperand(AccessQualNodeIdx));
  MDString *Entry = cast<MDString>(QualNode->getOperand(ArgIdx + 1));
  return Entry->getString();
}
/// Returns the type-name string (e.g. "image2d_t") recorded for argument
/// \p ArgIdx of the kernel described by \p KernelMDNode.
///
/// Operand 3 of the kernel node is the kernel_arg_type node. Its operand 0 is
/// the node's name, so argument entries start at operand 1.
static StringRef
ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
  const unsigned TypeNodeIdx = 3;
  MDNode *TypeNode = cast<MDNode>(KernelMDNode->getOperand(TypeNodeIdx));
  MDString *Entry = cast<MDString>(TypeNode->getOperand(ArgIdx + 1));
  return Entry->getString();
}
/// Collects operand \p OpIdx of each kernel_arg_* node of \p KernelMDNode
/// (operands 1..NumKernelArgMDNodes of the kernel node), in
/// KernelArgMDNodeNames order.
///
/// OpIdx 0 selects the node names. OpIdx k > 0 selects the entries for
/// argument k - 1.
static MDVector
GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
  MDVector Result;
  for (unsigned NodeNo = 1; NodeNo <= NumKernelArgMDNodes; ++NodeNo) {
    MDNode *ArgNode = cast<MDNode>(KernelMDNode->getOperand(NodeNo));
    Result.push_back(ArgNode->getOperand(OpIdx));
  }
  return Result;
}
/// Appends one argument's metadata, \p V, to the per-node operand lists in
/// \p MD. \p V holds one entry per kernel_arg_* node: entry i goes to
/// MD.ArgVector[i].
static void
PushArgMD(KernelArgMD &MD, const MDVector &V) {
  assert(V.size() == NumKernelArgMDNodes);
  for (unsigned NodeNo = 0; NodeNo != NumKernelArgMDNodes; ++NodeNo)
    MD.ArgVector[NodeNo].push_back(V[NodeNo]);
}
namespace {
class R600OpenCLImageTypeLoweringPass : public ModulePass {
static char ID;
LLVMContext *Context;
Type *Int32Type;
Type *ImageSizeType;
Type *ImageFormatType;
SmallVector<Instruction *, 4> InstsToErase;
bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
Argument &ImageSizeArg,
Argument &ImageFormatArg) {
bool Modified = false;
for (auto &Use : ImageArg.uses()) {
auto Inst = dyn_cast<CallInst>(Use.getUser());
if (!Inst) {
continue;
}
Function *F = Inst->getCalledFunction();
if (!F)
continue;
Value *Replacement = nullptr;
StringRef Name = F->getName();
if (Name.startswith(GetImageResourceIDFunc)) {
Replacement = ConstantInt::get(Int32Type, ResourceID);
} else if (Name.startswith(GetImageSizeFunc)) {
Replacement = &ImageSizeArg;
} else if (Name.startswith(GetImageFormatFunc)) {
Replacement = &ImageFormatArg;
} else {
continue;
}
Inst->replaceAllUsesWith(Replacement);
InstsToErase.push_back(Inst);
Modified = true;
}
return Modified;
}
bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
bool Modified = false;
for (const auto &Use : SamplerArg.uses()) {
auto Inst = dyn_cast<CallInst>(Use.getUser());
if (!Inst) {
continue;
}
Function *F = Inst->getCalledFunction();
if (!F)
continue;
Value *Replacement = nullptr;
StringRef Name = F->getName();
if (Name == GetSamplerResourceIDFunc) {
Replacement = ConstantInt::get(Int32Type, ResourceID);
} else {
continue;
}
Inst->replaceAllUsesWith(Replacement);
InstsToErase.push_back(Inst);
Modified = true;
}
return Modified;
}
bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
uint32_t NumReadOnlyImageArgs = 0;
uint32_t NumWriteOnlyImageArgs = 0;
uint32_t NumSamplerArgs = 0;
bool Modified = false;
InstsToErase.clear();
for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
Argument &Arg = *ArgI;
StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
if (IsImageType(Type)) {
StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
uint32_t ResourceID;
if (AccessQual == "read_only") {
ResourceID = NumReadOnlyImageArgs++;
} else if (AccessQual == "write_only") {
ResourceID = NumWriteOnlyImageArgs++;
} else {
llvm_unreachable("Wrong image access qualifier.");
}
Argument &SizeArg = *(++ArgI);
Argument &FormatArg = *(++ArgI);
Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
} else if (IsSamplerType(Type)) {
uint32_t ResourceID = NumSamplerArgs++;
Modified |= replaceSamplerUses(Arg, ResourceID);
}
}
for (unsigned i = 0; i < InstsToErase.size(); ++i) {
InstsToErase[i]->eraseFromParent();
}
return Modified;
}
std::tuple<Function *, MDNode *>
addImplicitArgs(Function *F, MDNode *KernelMDNode) {
bool Modified = false;
FunctionType *FT = F->getFunctionType();
SmallVector<Type *, 8> ArgTypes;
KernelArgMD NewArgMDs;
PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
for (unsigned i = 0; i < FT->getNumParams(); ++i) {
ArgTypes.push_back(FT->getParamType(i));
MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
PushArgMD(NewArgMDs, ArgMD);
if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
continue;
ArgTypes.push_back(ImageSizeType);
ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
PushArgMD(NewArgMDs, ArgMD);
ArgTypes.push_back(ImageFormatType);
ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
PushArgMD(NewArgMDs, ArgMD);
Modified = true;
}
if (!Modified) {
return std::make_tuple(nullptr, nullptr);
}
auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
ValueToValueMapTy VMap;
auto NewFArgIt = NewF->arg_begin();
for (auto &Arg: F->args()) {
auto ArgName = Arg.getName();
NewFArgIt->setName(ArgName);
VMap[&Arg] = &(*NewFArgIt++);
if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
(NewFArgIt++)->setName(Twine("__size_") + ArgName);
(NewFArgIt++)->setName(Twine("__format_") + ArgName);
}
}
SmallVector<ReturnInst*, 8> Returns;
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns);
SmallVector<Metadata *, 6> KernelMDArgs;
KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
for (const MDVector &MDV : NewArgMDs.ArgVector)
KernelMDArgs.push_back(MDNode::get(*Context, MDV));
MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
return std::make_tuple(NewF, NewMDNode);
}
bool transformKernels(Module &M) {
NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
if (!KernelsMDNode)
return false;
bool Modified = false;
for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
Function *F = GetFunctionFromMDNode(KernelMDNode);
if (!F)
continue;
Function *NewF;
MDNode *NewMDNode;
std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
if (NewF) {
F->eraseFromParent();
M.getFunctionList().push_back(NewF);
M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
NewF->getAttributes());
KernelsMDNode->setOperand(i, NewMDNode);
F = NewF;
KernelMDNode = NewMDNode;
Modified = true;
}
Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
}
return Modified;
}
public:
R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
bool runOnModule(Module &M) override {
Context = &M.getContext();
Int32Type = Type::getInt32Ty(M.getContext());
ImageSizeType = ArrayType::get(Int32Type, 3);
ImageFormatType = ArrayType::get(Int32Type, 2);
return transformKernels(M);
}
StringRef getPassName() const override {
return "R600 OpenCL Image Type Pass";
}
};
}
// Storage for the pass identifier; its address is what ModulePass(ID) records.
char R600OpenCLImageTypeLoweringPass::ID = 0;

/// Factory for the R600 OpenCL image-type lowering pass. Returns a newly
/// heap-allocated pass that the caller is responsible for.
ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
  auto *P = new R600OpenCLImageTypeLoweringPass();
  return P;
}