#include "DXILPointerType.h"
#include "PointerTypeAnalysis.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using ::testing::Contains;
using ::testing::Pair;
using namespace llvm;
using namespace llvm::dxil;
template <typename T> struct IsA {
friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
};
TEST(DXILPointerType, PrintTest) {
std::string Buffer;
LLVMContext Context;
raw_string_ostream OS(Buffer);
Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
I8Ptr->print(OS);
EXPECT_TRUE(StringRef(Buffer).startswith("dxil-ptr ("));
}
TEST(PointerTypeAnalysis, DigressToi8) {
StringRef Assembly = R"(
define i64 @test(ptr %p) {
store i32 0, ptr %p
%v = load i64, ptr %p
ret i64 %v
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 2u);
Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
}
TEST(PointerTypeAnalysis, DiscoverStore) {
StringRef Assembly = R"(
define i32 @test(ptr %p) {
store i32 0, ptr %p
ret i32 0
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 2u);
Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
TEST(PointerTypeAnalysis, DiscoverLoad) {
StringRef Assembly = R"(
define i32 @test(ptr %p) {
%v = load i32, ptr %p
ret i32 %v
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 2u);
Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
TEST(PointerTypeAnalysis, DiscoverGEP) {
StringRef Assembly = R"(
define ptr @test(ptr %p) {
%p2 = getelementptr i64, ptr %p, i64 1
ret ptr %p2
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 3u);
Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}
TEST(PointerTypeAnalysis, TraceIndirect) {
StringRef Assembly = R"(
define i64 @test(ptr %p) {
%p2 = load ptr, ptr %p
%v = load i64, ptr %p2
ret i64 %v
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 3u);
Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
}
TEST(PointerTypeAnalysis, WithNoOpCasts) {
StringRef Assembly = R"(
define i64 @test(ptr %p) {
%1 = bitcast ptr %p to ptr
%2 = bitcast ptr %p to ptr
store i32 0, ptr %1, align 4
%3 = load i64, ptr %2, align 8
ret i64 %3
}
)";
LLVMContext Context;
SMDiagnostic Error;
auto M = parseAssemblyString(Assembly, Error, Context);
ASSERT_TRUE(M) << "Bad assembly?";
PointerTypeMap Map = PointerTypeAnalysis::run(*M);
ASSERT_EQ(Map.size(), 4u);
Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));
}