#include "CodeGenFunction.h"
#include "clang/Basic/Builtins.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
using namespace clang;
using namespace CodeGen;
namespace {
/// Returns the module's `vprintf` function, creating an external declaration
/// of type `i32 (i8*, i8*)` in \p M when no function of that name exists yet.
llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
  llvm::LLVMContext &Ctx = M.getContext();
  llvm::Type *BytePtrTy = llvm::Type::getInt8PtrTy(Ctx);
  llvm::Type *ParamTys[] = {BytePtrTy, BytePtrTy};
  llvm::FunctionType *FnTy =
      llvm::FunctionType::get(llvm::Type::getInt32Ty(Ctx), ParamTys, false);

  llvm::Function *Existing = M.getFunction("vprintf");
  if (!Existing)
    return llvm::Function::Create(FnTy, llvm::GlobalVariable::ExternalLinkage,
                                  "vprintf", &M);

  // A function that is already present must carry the expected signature;
  // this is only checked in builds with assertions enabled.
  assert(Existing->getFunctionType() == FnTy);
  return Existing;
}
/// Returns the module's `__llvm_omp_vprintf` function, creating an external
/// declaration of type `i32 (i8*, i8*, i32)` when none exists yet.
///
/// If a function of that name already exists with a different type, an error
/// is reported through \p CGM and nullptr is returned.
llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) {
  llvm::Module &M = CGM.getModule();
  llvm::LLVMContext &Ctx = M.getContext();

  llvm::Type *BytePtrTy = llvm::Type::getInt8PtrTy(Ctx);
  llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
  llvm::Type *ParamTys[] = {BytePtrTy, BytePtrTy, Int32Ty};
  llvm::FunctionType *FnTy = llvm::FunctionType::get(Int32Ty, ParamTys, false);

  const char *Name = "__llvm_omp_vprintf";
  llvm::Function *Existing = M.getFunction(Name);
  if (!Existing)
    return llvm::Function::Create(FnTy, llvm::GlobalVariable::ExternalLinkage,
                                  Name, &M);

  if (Existing->getFunctionType() == FnTy)
    return Existing;

  // Same name, incompatible signature: diagnose instead of asserting.
  CGM.Error(SourceLocation(),
            "Invalid type declaration for __llvm_omp_vprintf");
  return nullptr;
}
/// Packs every argument after the first one into a stack temporary and
/// returns an `i8*` to that temporary together with its allocation size.
///
/// The temporary has a newly created struct type named "printf_args" whose
/// fields are the LLVM types of the trailing arguments, in order. When there
/// are no trailing arguments, a null `i8*` and a size of zero are returned and
/// nothing is emitted. Every trailing argument is read via getScalarVal(), so
/// callers must already have rejected non-scalar arguments.
std::pair<llvm::Value *, llvm::TypeSize>
packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) {
  const llvm::DataLayout &DL = CGF->CGM.getDataLayout();
  llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext();
  CGBuilderTy &Builder = CGF->Builder;

  // Nothing to pack: only the leading argument (or no argument) is present.
  if (Args.size() <= 1) {
    llvm::Value *NullBuffer =
        llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
    return {NullBuffer, llvm::TypeSize::Fixed(0)};
  }

  const unsigned NumArgs = Args.size();

  // First pass: collect the field types and build the buffer's struct type.
  llvm::SmallVector<llvm::Type *, 8> FieldTys;
  for (unsigned ArgNo = 1; ArgNo != NumArgs; ++ArgNo) {
    llvm::Value *Scalar = Args[ArgNo].getRValue(*CGF).getScalarVal();
    FieldTys.push_back(Scalar->getType());
  }
  llvm::Type *BufferTy = llvm::StructType::create(FieldTys, "printf_args");
  llvm::Value *Buffer = CGF->CreateTempAlloca(BufferTy);

  // Second pass: store argument ArgNo into struct field ArgNo - 1.
  for (unsigned ArgNo = 1; ArgNo != NumArgs; ++ArgNo) {
    const unsigned FieldNo = ArgNo - 1;
    llvm::Value *FieldPtr = Builder.CreateStructGEP(BufferTy, Buffer, FieldNo);
    llvm::Value *Scalar = Args[ArgNo].getRValue(*CGF).getScalarVal();
    Builder.CreateAlignedStore(Scalar, FieldPtr,
                               DL.getPrefTypeAlign(Scalar->getType()));
  }

  llvm::Value *BufferAsBytes =
      Builder.CreatePointerCast(Buffer, llvm::Type::getInt8PtrTy(Ctx));
  return {BufferAsBytes, DL.getTypeAllocSize(BufferTy)};
}
bool containsNonScalarVarargs(CodeGenFunction *CGF, CallArgList Args) {
return llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
return !A.getRValue(*CGF).isScalar();
});
}
/// Lowers the printf call \p E to a call of \p Decl.
///
/// The emitted call receives the first call argument unchanged, followed by
/// an `i8*` to a buffer holding the remaining arguments (see
/// packArgsIntoNVPTXFormatBuffer). When \p WithSizeArg is true, the buffer's
/// allocation size is appended as a third, i32 argument.
///
/// \param Decl the function to call. May be null when the caller's
///        declaration lookup already reported an error; in that case nothing
///        is emitted and a zero constant is returned.
/// \returns the result of the emitted call, or a zero constant of CGF's IntTy
///          on the error paths.
RValue EmitDevicePrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF,
                                llvm::Function *Decl, bool WithSizeArg) {
  CodeGenModule &CGM = CGF->CGM;
  CGBuilderTy &Builder = CGF->Builder;
  // Accept the same builtin IDs as the AMDGPU lowering in this file.
  assert(E->getBuiltinCallee() == Builtin::BIprintf ||
         E->getBuiltinCallee() == Builtin::BI__builtin_printf);
  assert(E->getNumArgs() >= 1);

  // GetOpenMPVprintfDeclaration returns null after diagnosing a conflicting
  // declaration. Creating a call through a null callee would crash, so bail
  // out with a placeholder value instead.
  if (!Decl)
    return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0));

  CallArgList Args;
  CGF->EmitCallArgs(Args,
                    E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
                    E->arguments(), E->getDirectCallee(),
                    0);

  // Only scalar values can be packed into the argument buffer.
  if (containsNonScalarVarargs(CGF, Args)) {
    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
    return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0));
  }

  auto Packed = packArgsIntoNVPTXFormatBuffer(CGF, Args);
  llvm::Value *BufferPtr = Packed.first;

  llvm::SmallVector<llvm::Value *, 3> CallOperands = {
      Args[0].getRValue(*CGF).getScalarVal(), BufferPtr};
  if (WithSizeArg) {
    // The size is passed as i32; the buffer's fixed alloc size is narrowed to
    // 32 bits here.
    llvm::Constant *Size =
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM.getLLVMContext()),
                               static_cast<uint32_t>(Packed.second.getFixedSize()));
    CallOperands.push_back(Size);
  }
  return RValue::get(Builder.CreateCall(Decl, CallOperands));
}
}
/// Lowers the printf call \p E for an NVPTX target to a call of the module's
/// `vprintf` function, without a trailing buffer-size argument.
RValue CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E) {
  assert(getTarget().getTriple().isNVPTX());
  llvm::Function *VprintfFn = GetVprintfDeclaration(CGM.getModule());
  return EmitDevicePrintfCallExpr(E, this, VprintfFn, false);
}
/// Lowers the printf call \p E for an amdgcn target by handing the scalar
/// values of all call arguments to llvm::emitAMDGPUPrintfCall.
///
/// If any argument is not a scalar, an "unsupported" error is reported and a
/// constant -1 of IntTy is returned instead.
RValue CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E) {
  assert(getTarget().getTriple().getArch() == llvm::Triple::amdgcn);
  assert(E->getBuiltinCallee() == Builtin::BIprintf ||
         E->getBuiltinCallee() == Builtin::BI__builtin_printf);
  assert(E->getNumArgs() >= 1);

  CallArgList CallArgs;
  EmitCallArgs(CallArgs,
               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
               E->arguments(), E->getDirectCallee(),
               0);

  SmallVector<llvm::Value *, 8> Args;
  // Iterate by const reference so each CallArg is not copied, and evaluate
  // the RValue once per argument.
  for (const auto &A : CallArgs) {
    RValue RV = A.getRValue(*this);
    if (!RV.isScalar()) {
      CGM.ErrorUnsupported(E, "non-scalar arg to printf");
      return RValue::get(llvm::ConstantInt::get(IntTy, -1));
    }
    Args.push_back(RV.getScalarVal());
  }

  // The helper takes a plain llvm::IRBuilder<>, so build one positioned at
  // the current insertion point and carrying the current debug location.
  llvm::IRBuilder<> IRB(Builder.GetInsertBlock(), Builder.GetInsertPoint());
  IRB.SetCurrentDebugLocation(Builder.getCurrentDebugLocation());
  auto Printf = llvm::emitAMDGPUPrintfCall(IRB, Args);
  // Continue emitting wherever the helper left IRB's insertion point.
  Builder.SetInsertPoint(IRB.GetInsertBlock(), IRB.GetInsertPoint());
  return RValue::get(Printf);
}
/// Lowers the printf call \p E for OpenMP device code on NVPTX or AMDGCN to
/// a call of `__llvm_omp_vprintf`, passing the argument buffer's size as an
/// extra trailing argument.
RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) {
  assert(getTarget().getTriple().isNVPTX() ||
         getTarget().getTriple().isAMDGCN());
  llvm::Function *OmpVprintfFn = GetOpenMPVprintfDeclaration(CGM);
  return EmitDevicePrintfCallExpr(E, this, OmpVprintfFn, true);
}