#include "DXILOpBuilder.h"
#include "DXILConstants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DXILOperationCommon.h"
#include "llvm/Support/ErrorHandling.h"
using namespace llvm;
using namespace llvm::DXIL;
/// Common prefix of every DXIL operation function name; the full name is
/// "dx.op.<opcode class>" optionally followed by ".<overload suffix>".
static constexpr StringLiteral DXILOpNamePrefix("dx.op.");
namespace {
/// Overload kinds a DXIL operation can be instantiated with. Every kind owns a
/// distinct bit, so the set of kinds an operation accepts fits in a single
/// uint16_t mask that is tested with bitwise AND.
enum OverloadKind : uint16_t {
  VOID = 0x001,
  HALF = 0x002,
  FLOAT = 0x004,
  DOUBLE = 0x008,
  I1 = 0x010,
  I8 = 0x020,
  I16 = 0x040,
  I32 = 0x080,
  I64 = 0x100,
  UserDefineType = 0x200,
  ObjectType = 0x400,
};
} // namespace
/// Returns the suffix ("f16", "f32", "i32", ...) that names a scalar overload
/// kind in a DXIL op function or type name.
/// Only scalar kinds have a suffix; passing VOID, ObjectType or
/// UserDefineType (or any non-enumerator value) is a programming error.
static const char *getOverloadTypeName(OverloadKind Kind) {
  switch (Kind) {
  case OverloadKind::VOID:
  case OverloadKind::ObjectType:
  case OverloadKind::UserDefineType:
    // Non-scalar kinds have no textual suffix.
    break;
  case OverloadKind::HALF:
    return "f16";
  case OverloadKind::FLOAT:
    return "f32";
  case OverloadKind::DOUBLE:
    return "f64";
  case OverloadKind::I1:
    return "i1";
  case OverloadKind::I8:
    return "i8";
  case OverloadKind::I16:
    return "i16";
  case OverloadKind::I32:
    return "i32";
  case OverloadKind::I64:
    return "i64";
  }
  llvm_unreachable("invalid overload type for name");
}
/// Classifies an LLVM type as a DXIL overload kind.
/// Integers map by bit width (1, 8, 16, 32, 64). Pointers map to
/// UserDefineType and structs to ObjectType. Any other type or integer
/// width is rejected as unreachable.
static OverloadKind getOverloadKind(Type *Ty) {
  if (Ty->isIntegerTy()) {
    switch (cast<IntegerType>(Ty)->getBitWidth()) {
    case 1:
      return OverloadKind::I1;
    case 8:
      return OverloadKind::I8;
    case 16:
      return OverloadKind::I16;
    case 32:
      return OverloadKind::I32;
    case 64:
      return OverloadKind::I64;
    default:
      llvm_unreachable("invalid overload type");
    }
  }

  switch (Ty->getTypeID()) {
  case Type::VoidTyID:
    return OverloadKind::VOID;
  case Type::HalfTyID:
    return OverloadKind::HALF;
  case Type::FloatTyID:
    return OverloadKind::FLOAT;
  case Type::DoubleTyID:
    return OverloadKind::DOUBLE;
  case Type::PointerTyID:
    return OverloadKind::UserDefineType;
  case Type::StructTyID:
    return OverloadKind::ObjectType;
  default:
    llvm_unreachable("invalid overload type");
  }
}
/// Returns the overload suffix used in a DXIL op function name for \p Ty,
/// whose overload kind is \p Kind.
///
/// Scalar kinds use their fixed suffix ("f32", "i16", ...). A named struct
/// uses its struct name. Any other type, such as a pointer (UserDefineType) or
/// a literal/unnamed struct, falls back to the printed IR type.
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
  if (Kind < OverloadKind::UserDefineType)
    return getOverloadTypeName(Kind);

  // getOverloadKind yields UserDefineType for pointer types, so Ty is not
  // necessarily a StructType here. An unconditional cast<StructType> asserts
  // in debug builds and is UB in release builds. Literal structs have no name
  // either, and getStructName() asserts on them.
  if (auto *ST = dyn_cast<StructType>(Ty)) {
    if (ST->hasName())
      return ST->getName().str();
  }

  std::string Str;
  raw_string_ostream OS(Str);
  Ty->print(OS);
  return OS.str();
}
// Static description of one DXIL operation, returned by getOpCodeProperty.
// The field order presumably has to match the generated table pulled in from
// DXILOperation.inc right below, so do not reorder fields without
// regenerating that table (TODO confirm).
struct OpCodeProperty {
DXIL::OpCode OpCode;
// Presumably an offset into the generated op-name string table used by
// getOpCodeName; verify against DXILOperation.inc.
unsigned OpCodeNameOffset;
DXIL::OpCodeClass OpCodeClass;
// Presumably an offset into the generated class-name string table used by
// getOpCodeClassName; verify against DXILOperation.inc.
unsigned OpCodeClassNameOffset;
// Bit mask of OverloadKind values accepted by this op. It is ANDed with a
// single kind when creating an op function. When the op has no overloaded
// slot, it is switched on directly and must then hold exactly one kind.
uint16_t OverloadTys;
// Function attribute kind for the op; not read in the code visible here.
llvm::Attribute::AttrKind FuncAttr;
// OverloadParamIndex: index into the op's parameter-kind list of the
//   overloaded slot; 0 selects the return type, a negative value means the
//   op has no overloaded slot.
// NumOfParameters: number of entries in the parameter-kind list, where entry 0
//   is the return type.
// ParameterTableOffset: presumably where this op's parameter kinds start in
//   the generated parameter table read by getOpCodeParameterKind (TODO confirm).
int OverloadParamIndex; unsigned NumOfParameters; unsigned ParameterTableOffset; };
#define DXIL_OP_OPERATION_TABLE
#include "DXILOperation.inc"
#undef DXIL_OP_OPERATION_TABLE
/// Builds the function name of a DXIL op: "dx.op.<class>" for the VOID
/// overload, otherwise "dx.op.<class>.<type suffix>".
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                         const OpCodeProperty &Prop) {
  std::string Name =
      (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
  if (Kind != OverloadKind::VOID)
    Name += "." + getTypeName(Kind, Ty);
  return Name;
}
/// Appends the scalar overload suffix of \p Kind to \p TypeName.
/// VOID leaves the name unchanged. Only scalar kinds (below UserDefineType)
/// are valid otherwise.
static std::string constructOverloadTypeName(OverloadKind Kind,
                                             StringRef TypeName) {
  std::string Result = TypeName.str();
  if (Kind != OverloadKind::VOID) {
    assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
    Result += getOverloadTypeName(Kind);
  }
  return Result;
}
/// Returns the identified struct named \p Name in \p Ctx, creating it with
/// element types \p EltTys if it does not exist yet. Repeated requests for the
/// same name return the same type.
/// NOTE(review): an existing struct is returned as-is; its body is not checked
/// against \p EltTys.
static StructType *getOrCreateStructType(StringRef Name,
                                         ArrayRef<Type *> EltTys,
                                         LLVMContext &Ctx) {
  if (StructType *Existing = StructType::getTypeByName(Ctx, Name))
    return Existing;
  return StructType::create(Ctx, EltTys, Name);
}
/// Returns the resource-return struct for \p OverloadTy, named
/// "dx.types.ResRet.<suffix>", with layout { T, T, T, T, i32 }
/// where T is \p OverloadTy.
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
  const std::string Name = constructOverloadTypeName(
      getOverloadKind(OverloadTy), "dx.types.ResRet.");
  // Four overload-typed fields followed by an i32.
  SmallVector<Type *, 5> Fields(4, OverloadTy);
  Fields.push_back(Type::getInt32Ty(Ctx));
  return getOrCreateStructType(Name, Fields, Ctx);
}
/// Returns the "dx.types.Handle" struct, whose single field is an i8 pointer.
static StructType *getHandleType(LLVMContext &Ctx) {
  Type *HandleFields[] = {Type::getInt8PtrTy(Ctx)};
  return getOrCreateStructType("dx.types.Handle", HandleFields, Ctx);
}
/// Maps a DXIL parameter kind to the concrete LLVM type used in an op's
/// function signature. OVERLOAD resolves to \p OverloadTy itself, and
/// RESOURCE_RET to the ResRet struct of \p OverloadTy. Kinds not listed here
/// are rejected as unreachable.
static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
  LLVMContext &Ctx = OverloadTy->getContext();
  switch (Kind) {
  // Kinds derived from the overload type or from named DXIL structs.
  case ParameterKind::OVERLOAD:
    return OverloadTy;
  case ParameterKind::RESOURCE_RET:
    return getResRetType(OverloadTy, Ctx);
  case ParameterKind::DXIL_HANDLE:
    return getHandleType(Ctx);
  // Fixed scalar kinds.
  case ParameterKind::VOID:
    return Type::getVoidTy(Ctx);
  case ParameterKind::HALF:
    return Type::getHalfTy(Ctx);
  case ParameterKind::FLOAT:
    return Type::getFloatTy(Ctx);
  case ParameterKind::DOUBLE:
    return Type::getDoubleTy(Ctx);
  case ParameterKind::I1:
    return Type::getInt1Ty(Ctx);
  case ParameterKind::I8:
    return Type::getInt8Ty(Ctx);
  case ParameterKind::I16:
    return Type::getInt16Ty(Ctx);
  case ParameterKind::I32:
    return Type::getInt32Ty(Ctx);
  case ParameterKind::I64:
    return Type::getInt64Ty(Ctx);
  default:
    llvm_unreachable("Invalid parameter kind");
  }
}
/// Builds the LLVM function type of a DXIL op from its parameter-kind list.
/// Entry 0 of the list is the return type, and entries 1..NumOfParameters-1
/// are the formal parameters, in order. \p OverloadTy is substituted for
/// overload-dependent kinds.
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
                                           Type *OverloadTy) {
  assert(Prop->NumOfParameters > 0 &&
         "DXIL op must at least describe its return type");
  auto ParamKinds = getOpCodeParameterKind(*Prop);

  SmallVector<Type *> ArgTys;
  ArgTys.reserve(Prop->NumOfParameters);
  for (unsigned I = 0; I < Prop->NumOfParameters; ++I)
    ArgTys.push_back(getTypeFromParameterKind(ParamKinds[I], OverloadTy));

  // drop_front() yields an empty parameter list when only the return type is
  // present. The previous &ArgTys[1] indexed past the end in that case.
  return FunctionType::get(ArgTys.front(),
                           ArrayRef<Type *>(ArgTys).drop_front(),
                           /*isVarArg=*/false);
}
/// Returns the declaration of the DXIL op function for \p DXILOp specialized
/// on \p OverloadTy, inserting it into \p M if it is not present yet.
/// Aborts with a fatal error if the op does not support that overload.
static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
                                                Type *OverloadTy, Module &M) {
  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
  OverloadKind Kind = getOverloadKind(OverloadTy);

  // The requested overload comes from the caller, so an unsupported one is a
  // reachable condition. llvm_unreachable would let the optimizer assume it
  // never happens, which is UB in release builds; report it instead.
  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
    report_fatal_error("invalid overload", /*gen_crash_diag=*/false);

  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
  // Reuse an existing declaration without rebuilding its function type.
  if (Function *Fn = M.getFunction(FnName))
    return FunctionCallee(Fn);

  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
  return M.getOrInsertFunction(FnName, DXILOpFT);
}
namespace llvm {
namespace DXIL {

/// Emits a call to the DXIL op function for \p OpCode specialized on
/// \p OverloadTy, declaring that function in the module first if needed.
/// The opcode is passed as the leading i32 argument, followed by \p Args.
CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
                                          llvm::iterator_range<Use *> Args) {
  FunctionCallee DXILOpFn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
  SmallVector<Value *> CallArgs;
  CallArgs.push_back(B.getInt32(static_cast<int32_t>(OpCode)));
  for (Use &ArgUse : Args)
    CallArgs.push_back(ArgUse.get());
  return B.CreateCall(DXILOpFn, CallArgs);
}

/// Recovers the overload type of \p OpCode from the function type \p FT.
///
/// If the op has no overloaded slot (negative OverloadParamIndex), the single
/// kind recorded in its OverloadTys mask is materialized as a type. Otherwise
/// the type is read from \p FT. Index 0 is the return type, and higher indices
/// count the i32 opcode as parameter 1, so one extra position is skipped when
/// \p NoOpCodeParam indicates that \p FT omits the opcode. Resource/cbuffer
/// return structs are unwrapped to their first element type.
Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
                                   bool NoOpCodeParam) {
  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
  const int OverloadIdx = Prop->OverloadParamIndex;

  if (OverloadIdx < 0) {
    LLVMContext &Ctx = FT->getContext();
    switch (Prop->OverloadTys) {
    case OverloadKind::VOID:   return Type::getVoidTy(Ctx);
    case OverloadKind::HALF:   return Type::getHalfTy(Ctx);
    case OverloadKind::FLOAT:  return Type::getFloatTy(Ctx);
    case OverloadKind::DOUBLE: return Type::getDoubleTy(Ctx);
    case OverloadKind::I1:     return Type::getInt1Ty(Ctx);
    case OverloadKind::I8:     return Type::getInt8Ty(Ctx);
    case OverloadKind::I16:    return Type::getInt16Ty(Ctx);
    case OverloadKind::I32:    return Type::getInt32Ty(Ctx);
    case OverloadKind::I64:    return Type::getInt64Ty(Ctx);
    default:
      llvm_unreachable("invalid overload type");
    }
  }

  Type *OverloadType;
  if (OverloadIdx == 0) {
    OverloadType = FT->getReturnType();
  } else {
    const unsigned NumSkippedParams = NoOpCodeParam ? 2 : 1;
    OverloadType = FT->getParamType(OverloadIdx - NumSkippedParams);
  }

  auto ParamKinds = getOpCodeParameterKind(*Prop);
  ParameterKind Kind = ParamKinds[OverloadIdx];
  if (Kind == ParameterKind::CBUFFER_RET || Kind == ParameterKind::RESOURCE_RET)
    OverloadType = cast<StructType>(OverloadType)->getElementType(0);
  return OverloadType;
}

/// Returns the name of \p DXILOp by forwarding to the global-namespace
/// getOpCodeName (presumably generated into DXILOperation.inc).
const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
  return ::getOpCodeName(DXILOp);
}

} // namespace DXIL
} // namespace llvm