#ifndef LLVM_ADT_FUNCTIONEXTRAS_H
#define LLVM_ADT_FUNCTIONEXTRAS_H
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/type_traits.h"
#include <cstring>
#include <memory>
#include <type_traits>
namespace llvm {
template <typename FunctionT> class unique_function;
namespace detail {
template <typename T>
using EnableIfTrivial =
std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
std::is_trivially_destructible<T>::value>;
template <typename CallableT, typename ThisT>
using EnableUnlessSameType =
std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>;
template <typename CallableT, typename Ret, typename... Params>
using EnableIfCallable = std::enable_if_t<llvm::disjunction<
std::is_void<Ret>,
std::is_same<decltype(std::declval<CallableT>()(std::declval<Params>()...)),
Ret>,
std::is_same<const decltype(std::declval<CallableT>()(
std::declval<Params>()...)),
Ret>,
std::is_convertible<decltype(std::declval<CallableT>()(
std::declval<Params>()...)),
Ret>>::value>;
template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
protected:
static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
template <typename T, class = void>
struct IsSizeLessThanThresholdT : std::false_type {};
template <typename T>
struct IsSizeLessThanThresholdT<
T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {};
template <typename T> struct AdjustedParamTBase {
static_assert(!std::is_reference<T>::value,
"references should be handled by template specialization");
using type = typename std::conditional<
llvm::is_trivially_copy_constructible<T>::value &&
llvm::is_trivially_move_constructible<T>::value &&
IsSizeLessThanThresholdT<T>::value,
T, T &>::type;
};
template <typename T> struct AdjustedParamTBase<T &> { using type = T &; };
template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; };
template <typename T>
using AdjustedParamT = typename AdjustedParamTBase<T>::type;
using CallPtrT = ReturnT (*)(void *CallableAddr,
AdjustedParamT<ParamTs>... Params);
using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr);
using DestroyPtrT = void (*)(void *CallableAddr);
struct alignas(8) TrivialCallback {
CallPtrT CallPtr;
};
struct alignas(8) NonTrivialCallbacks {
CallPtrT CallPtr;
MovePtrT MovePtr;
DestroyPtrT DestroyPtr;
};
using CallbackPointerUnionT =
PointerUnion<TrivialCallback *, NonTrivialCallbacks *>;
union StorageUnionT {
struct OutOfLineStorageT {
void *StoragePtr;
size_t Size;
size_t Alignment;
} OutOfLineStorage;
static_assert(
sizeof(OutOfLineStorageT) <= InlineStorageSize,
"Should always use all of the out-of-line storage for inline storage!");
mutable
typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
InlineStorage;
} StorageUnion;
PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag;
bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); }
bool isTrivialCallback() const {
return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>();
}
CallPtrT getTrivialCallback() const {
return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr;
}
NonTrivialCallbacks *getNonTrivialCallbacks() const {
return CallbackAndInlineFlag.getPointer()
.template get<NonTrivialCallbacks *>();
}
CallPtrT getCallPtr() const {
return isTrivialCallback() ? getTrivialCallback()
: getNonTrivialCallbacks()->CallPtr;
}
void *getCalleePtr() const {
return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
}
void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
void *getOutOfLineStorage() const {
return StorageUnion.OutOfLineStorage.StoragePtr;
}
size_t getOutOfLineStorageSize() const {
return StorageUnion.OutOfLineStorage.Size;
}
size_t getOutOfLineStorageAlignment() const {
return StorageUnion.OutOfLineStorage.Alignment;
}
void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) {
StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
}
template <typename CalledAsT>
static ReturnT CallImpl(void *CallableAddr,
AdjustedParamT<ParamTs>... Params) {
auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
return Func(std::forward<ParamTs>(Params)...);
}
template <typename CallableT>
static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept {
new (LHSCallableAddr)
CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr)));
}
template <typename CallableT>
static void DestroyImpl(void *CallableAddr) noexcept {
reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
}
template <typename CallableT, typename CalledAs, typename Enable = void>
struct CallbacksHolder {
static NonTrivialCallbacks Callbacks;
};
template <typename CallableT, typename CalledAs>
struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> {
static TrivialCallback Callbacks;
};
template <typename T> struct CalledAs {};
template <typename CallableT, typename CalledAsT>
UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
bool IsInlineStorage = true;
void *CallableAddr = getInlineStorage();
if (sizeof(CallableT) > InlineStorageSize ||
alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
IsInlineStorage = false;
auto Size = sizeof(CallableT);
auto Alignment = alignof(CallableT);
CallableAddr = allocate_buffer(Size, Alignment);
setOutOfLineStorage(CallableAddr, Size, Alignment);
}
new (CallableAddr) CallableT(std::move(Callable));
CallbackAndInlineFlag.setPointerAndInt(
&CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage);
}
~UniqueFunctionBase() {
if (!CallbackAndInlineFlag.getPointer())
return;
bool IsInlineStorage = isInlineStorage();
if (!isTrivialCallback())
getNonTrivialCallbacks()->DestroyPtr(
IsInlineStorage ? getInlineStorage() : getOutOfLineStorage());
if (!IsInlineStorage)
deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(),
getOutOfLineStorageAlignment());
}
UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
if (!RHS)
return;
if (!isInlineStorage()) {
StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage;
} else if (isTrivialCallback()) {
memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize);
} else {
getNonTrivialCallbacks()->MovePtr(getInlineStorage(),
RHS.getInlineStorage());
}
RHS.CallbackAndInlineFlag = {};
#ifndef NDEBUG
memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize);
#endif
}
UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
if (this == &RHS)
return *this;
this->~UniqueFunctionBase();
new (this) UniqueFunctionBase(std::move(RHS));
return *this;
}
UniqueFunctionBase() = default;
public:
explicit operator bool() const {
return (bool)CallbackAndInlineFlag.getPointer();
}
};
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT, typename Enable>
typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase<
R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = {
&CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT>
typename UniqueFunctionBase<R, P...>::TrivialCallback
UniqueFunctionBase<R, P...>::CallbacksHolder<
CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{
&CallImpl<CalledAsT>};
}
template <typename R, typename... P>
class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
using Base = detail::UniqueFunctionBase<R, P...>;
public:
unique_function() = default;
unique_function(std::nullptr_t) {}
unique_function(unique_function &&) = default;
unique_function(const unique_function &) = delete;
unique_function &operator=(unique_function &&) = default;
unique_function &operator=(const unique_function &) = delete;
template <typename CallableT>
unique_function(
CallableT Callable,
detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
detail::EnableIfCallable<CallableT, R, P...> * = nullptr)
: Base(std::forward<CallableT>(Callable),
typename Base::template CalledAs<CallableT>{}) {}
R operator()(P... Params) {
return this->getCallPtr()(this->getCalleePtr(), Params...);
}
};
template <typename R, typename... P>
class unique_function<R(P...) const>
: public detail::UniqueFunctionBase<R, P...> {
using Base = detail::UniqueFunctionBase<R, P...>;
public:
unique_function() = default;
unique_function(std::nullptr_t) {}
unique_function(unique_function &&) = default;
unique_function(const unique_function &) = delete;
unique_function &operator=(unique_function &&) = default;
unique_function &operator=(const unique_function &) = delete;
template <typename CallableT>
unique_function(
CallableT Callable,
detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
detail::EnableIfCallable<const CallableT, R, P...> * = nullptr)
: Base(std::forward<CallableT>(Callable),
typename Base::template CalledAs<const CallableT>{}) {}
R operator()(P... Params) const {
return this->getCallPtr()(this->getCalleePtr(), Params...);
}
};
}
#endif