#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
#define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
#include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
#include "llvm/Support/Error.h"
#include <type_traits>
namespace llvm {
namespace orc {
namespace shared {
// Payload storage for a CWrapperFunctionResult: either a pointer to an
// out-of-line buffer, or the payload bytes themselves stored inline in the
// space the pointer would otherwise occupy. WrapperFunctionResult reads
// Value when Size <= sizeof(Value) and ValuePtr otherwise.
union CWrapperFunctionResultDataUnion {
char *ValuePtr;
// Inline buffer exactly as large as the pointer member.
char Value[sizeof(ValuePtr)];
};
// Plain C-layout result of a wrapper function call. As interpreted by
// WrapperFunctionResult in this header:
//   Size == 0, Data.ValuePtr == nullptr : empty result.
//   0 < Size <= sizeof(Data.Value)     : payload stored inline in Data.Value.
//   Size > sizeof(Data.Value)          : Data.ValuePtr points to a heap buffer
//                                        of Size bytes (released with free()).
//   Size == 0, Data.ValuePtr != nullptr : out-of-band error; Data.ValuePtr is a
//                                        NUL-terminated message string.
typedef struct {
CWrapperFunctionResultDataUnion Data;
size_t Size;
} CWrapperFunctionResult;
/// Move-only owning wrapper around a CWrapperFunctionResult.
///
/// The wrapped value is in one of four states:
///   - empty:               Size == 0, Data.ValuePtr == nullptr
///   - inline payload:      0 < Size <= sizeof(Data.Value)
///   - out-of-line payload: Size > sizeof(Data.Value), malloc'd Data.ValuePtr
///   - out-of-band error:   Size == 0, Data.ValuePtr -> malloc'd C string
/// The destructor frees any malloc'd buffer; release() hands ownership back
/// to the caller as a raw CWrapperFunctionResult.
class WrapperFunctionResult {
public:
  /// Create an empty result.
  WrapperFunctionResult() { init(R); }

  /// Take ownership of a CWrapperFunctionResult.
  WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {
    // The parameter shadows the member: this resets the by-value copy that
    // was passed in, not the member that now owns the data.
    init(R);
  }

  WrapperFunctionResult(const WrapperFunctionResult &) = delete;
  WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;

  /// Steal Other's contents, leaving Other empty. Only swaps a trivially
  /// copyable struct, so it cannot throw.
  WrapperFunctionResult(WrapperFunctionResult &&Other) noexcept {
    init(R);
    std::swap(R, Other.R);
  }

  /// Move-assign via a temporary so the previous contents of *this are
  /// released when Tmp is destroyed (also safe for self-move).
  WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) noexcept {
    WrapperFunctionResult Tmp(std::move(Other));
    std::swap(R, Tmp.R);
    return *this;
  }

  /// Free the out-of-line payload or the out-of-band error string, if any.
  ~WrapperFunctionResult() {
    if ((R.Size > sizeof(R.Data.Value)) ||
        (R.Size == 0 && R.Data.ValuePtr != nullptr))
      free(R.Data.ValuePtr);
  }

  /// Relinquish ownership: return the raw value and leave *this empty.
  CWrapperFunctionResult release() {
    CWrapperFunctionResult Tmp;
    init(Tmp);
    std::swap(R, Tmp);
    return Tmp;
  }

  /// Pointer to the payload bytes (inline or out-of-line).
  /// Must not be called on an out-of-band error value.
  char *data() {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Const overload of data().
  const char *data() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Payload size in bytes. Must not be called on an out-of-band error value.
  size_t size() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size;
  }

  /// True for a default-constructed (no payload, no error) result.
  bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }

  /// Create a result with an uninitialized payload of \p Size bytes. Payloads
  /// larger than the inline buffer are malloc'd.
  static WrapperFunctionResult allocate(size_t Size) {
    WrapperFunctionResult WFR;
    WFR.R.Size = Size;
    if (WFR.R.Size > sizeof(WFR.R.Data.Value))
      WFR.R.Data.ValuePtr = (char *)malloc(WFR.R.Size);
    return WFR;
  }

  /// Copy \p Size bytes from \p Source into a new result.
  static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
    auto WFR = allocate(Size);
    // memcpy with a null Source is undefined even for Size == 0, so skip it.
    if (Size != 0)
      memcpy(WFR.data(), Source, Size);
    return WFR;
  }

  /// Copy a NUL-terminated C string, including the terminator.
  static WrapperFunctionResult copyFrom(const char *Source) {
    return copyFrom(Source, strlen(Source) + 1);
  }

  /// Copy the contents of \p Source, including the trailing NUL.
  static WrapperFunctionResult copyFrom(const std::string &Source) {
    // Use size() rather than strlen() so embedded NULs are preserved (and the
    // string is not rescanned); c_str() guarantees a NUL at index size().
    return copyFrom(Source.c_str(), Source.size() + 1);
  }

  /// Create an out-of-band error result carrying a malloc'd copy of \p Msg.
  static WrapperFunctionResult createOutOfBandError(const char *Msg) {
    WrapperFunctionResult WFR;
    char *Tmp = (char *)malloc(strlen(Msg) + 1);
    strcpy(Tmp, Msg);
    WFR.R.Data.ValuePtr = Tmp;
    return WFR;
  }

  /// std::string overload of createOutOfBandError.
  static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
    return createOutOfBandError(Msg.c_str());
  }

  /// The out-of-band error message, or nullptr if this is not an error value.
  const char *getOutOfBandError() const {
    return R.Size == 0 ? R.Data.ValuePtr : nullptr;
  }

private:
  /// Reset \p R to the empty state.
  static void init(CWrapperFunctionResult &R) {
    R.Data.ValuePtr = nullptr;
    R.Size = 0;
  }

  CWrapperFunctionResult R;
};
namespace detail {
/// Serialize \p Args with the SPS argument list \p SPSArgListT into a freshly
/// allocated WrapperFunctionResult of exactly the required size.
/// On serialization failure an out-of-band error result is returned instead.
template <typename SPSArgListT, typename... ArgTs>
WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
  const size_t BlobSize = SPSArgListT::size(Args...);
  WrapperFunctionResult Blob = WrapperFunctionResult::allocate(BlobSize);
  SPSOutputBuffer Out(Blob.data(), Blob.size());
  if (SPSArgListT::serialize(Out, Args...))
    return Blob;
  return WrapperFunctionResult::createOutOfBandError(
      "Error serializing arguments to blob in call");
}
/// Invokes a handler with selected elements of an argument tuple.
/// The primary template returns the handler's result unchanged
/// (decltype(auto) preserves reference-ness).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Call \p Handler with std::get<Idx>(\p ArgTuple)... — each element is
  /// passed as an lvalue, so the handler may bind it by reference.
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
/// Specialization for void-returning handlers: after invoking the handler,
/// an SPSEmpty is returned so callers always have a value to serialize.
template <> class WrapperFunctionHandlerCaller<void> {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static SPSEmpty call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                       std::index_sequence<Idx...>) {
    std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
    return SPSEmpty();
  }
};
// Maps a synchronous wrapper-function handler onto its argument/return shape
// so that serialized arguments can be decoded, the handler invoked, and its
// result re-serialized.
//
// Primary template: for a callable object (e.g. a lambda), recurse on the type
// of a pointer to its operator(). The specializations below strip the
// function-pointer / member-function-pointer layer down to a plain function
// type, where the real implementation lives.
template <typename WrapperFunctionImplT,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
: public WrapperFunctionHandlerHelper<
decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
ResultSerializer, SPSTagTs...> {};
// Plain function-type specialization: performs the actual work.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {
public:
// Decayed storage for the handler's arguments. Each element is handed to the
// handler as an lvalue by WrapperFunctionHandlerCaller::call.
using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
// Deserialize ArgData/ArgSize into an ArgTuple, call H with it, and serialize
// the handler's result with ResultSerializer. Returns an out-of-band error
// result if argument deserialization fails.
template <typename HandlerT>
static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
size_t ArgSize) {
ArgTuple Args;
if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
return WrapperFunctionResult::createOutOfBandError(
"Could not deserialize arguments for wrapper function call");
// For RetT == void, WrapperFunctionHandlerCaller<void> yields SPSEmpty, so
// there is always a value to serialize.
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
std::forward<HandlerT>(H), Args, ArgIndices{});
return ResultSerializer<decltype(HandlerResult)>::serialize(
std::move(HandlerResult));
}
private:
// Decode the SPS buffer into the tuple elements using the tags SPSTagTs.
// Returns false if SPSArgList reports a decoding failure.
template <std::size_t... I>
static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
std::index_sequence<I...>) {
SPSInputBuffer IB(ArgData, ArgSize);
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
}
};
// Function pointer: forward to the plain function-type specialization.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Non-const member function pointer (e.g. &Lambda::operator() for a mutable
// lambda): drop the class and forward to the plain function type.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Const member function pointer (e.g. &Lambda::operator() for a non-mutable
// lambda): drop the class and forward to the plain function type.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
ResultSerializer, SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Asynchronous counterpart of WrapperFunctionHandlerHelper: the handler's
// first parameter is a "send result" callback, and the remaining parameters
// are the deserialized arguments.
//
// Primary template: for a callable object, recurse on the type of a pointer to
// its operator(); the specializations below reduce to a plain function type.
template <typename WrapperFunctionImplT,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper
: public WrapperFunctionAsyncHandlerHelper<
decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
ResultSerializer, SPSTagTs...> {};
// Plain function-type specialization: SendResultT is the handler's first
// parameter (the result callback); ArgTs are the wrapper-function arguments.
template <typename RetT, typename SendResultT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
ResultSerializer, SPSTagTs...> {
public:
// Decayed storage for the handler's (non-callback) arguments.
using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
// Deserialize the arguments, then call H with a callback that serializes
// whatever result H passes to it and forwards the WrapperFunctionResult to
// SendWrapperFunctionResult. On argument-deserialization failure an
// out-of-band error is sent immediately and H is not called.
template <typename HandlerT, typename SendWrapperFunctionResultT>
static void applyAsync(HandlerT &&H,
SendWrapperFunctionResultT &&SendWrapperFunctionResult,
const char *ArgData, size_t ArgSize) {
ArgTuple Args;
if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
"Could not deserialize arguments for wrapper function call"));
return;
}
// Generic lambda: the result type is deduced from what the handler passes,
// and ResultSerializer is chosen for that type.
auto SendResult =
[SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
using ResultT = decltype(Result);
SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
};
callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args),
ArgIndices{});
}
private:
// Decode the SPS buffer into the tuple elements using the tags SPSTagTs.
// Returns false if SPSArgList reports a decoding failure.
template <std::size_t... I>
static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
std::index_sequence<I...>) {
SPSInputBuffer IB(ArgData, ArgSize);
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
}
// Invoke H(SerializeAndSendResult, args...). Args is taken by value so each
// element can be moved into the handler. The (void)Args cast presumably
// silences an unused-parameter warning when the index pack is empty.
template <typename HandlerT, typename SerializeAndSendResultT,
typename ArgTupleT, std::size_t... I>
static void callAsync(HandlerT &&H,
SerializeAndSendResultT &&SerializeAndSendResult,
ArgTupleT Args, std::index_sequence<I...>) {
(void)Args; return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
std::move(std::get<I>(Args))...);
}
};
// Function pointer: forward to the plain function-type specialization.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Non-const member function pointer: drop the class.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
ResultSerializer, SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Const member function pointer: drop the class.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
ResultSerializer, SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
/// Serializes a handler's return value of type \p RetT with the SPS tag
/// \p SPSRetTagT. The primary template serializes the value as-is.
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
public:
  static WrapperFunctionResult serialize(RetT Value) {
    using RetArgList = SPSArgList<SPSRetTagT>;
    return serializeViaSPSToWrapperFunctionResult<RetArgList>(Value);
  }
};

/// Error results are converted to their SPS-serializable form first.
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
public:
  static WrapperFunctionResult serialize(Error Err) {
    auto Serializable = toSPSSerializable(std::move(Err));
    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
        Serializable);
  }
};

/// ErrorSuccess results are converted to their SPS-serializable form first.
template <typename SPSRetTagT>
class ResultSerializer<SPSRetTagT, ErrorSuccess> {
public:
  static WrapperFunctionResult serialize(ErrorSuccess Err) {
    auto Serializable = toSPSSerializable(std::move(Err));
    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
        Serializable);
  }
};

/// Expected<T> results are converted to their SPS-serializable form first.
template <typename SPSRetTagT, typename T>
class ResultSerializer<SPSRetTagT, Expected<T>> {
public:
  static WrapperFunctionResult serialize(Expected<T> E) {
    auto Serializable = toSPSSerializable(std::move(E));
    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
        Serializable);
  }
};
/// Deserializes a wrapper-function result into a \p RetT using the SPS tag
/// \p SPSRetTagT.
///   makeValue(): a placeholder value to hand out when no real result exists.
///   makeSafe():  consume any state in an existing value so it may be
///                overwritten (a no-op for ordinary types).
template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
public:
  static RetT makeValue() { return RetT(); }
  static void makeSafe(RetT &) {}
  static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
    SPSInputBuffer In(ArgData, ArgSize);
    if (SPSArgList<SPSRetTagT>::deserialize(In, Result))
      return Error::success();
    return make_error<StringError>(
        "Error deserializing return value from blob in call",
        inconvertibleErrorCode());
  }
};

/// Error results: decode an SPSSerializableError, then rebuild the Error.
template <> class ResultDeserializer<SPSError, Error> {
public:
  static Error makeValue() { return Error::success(); }
  static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
  static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
    SPSInputBuffer In(ArgData, ArgSize);
    SPSSerializableError Wire;
    if (SPSArgList<SPSError>::deserialize(In, Wire)) {
      Err = fromSPSSerializable(std::move(Wire));
      return Error::success();
    }
    return make_error<StringError>(
        "Error deserializing return value from blob in call",
        inconvertibleErrorCode());
  }
};

/// Expected<T> results: decode an SPSSerializableExpected<T>, then rebuild.
template <typename SPSTagT, typename T>
class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
public:
  static Expected<T> makeValue() { return T(); }
  static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
  static Error deserialize(Expected<T> &E, const char *ArgData,
                           size_t ArgSize) {
    SPSInputBuffer In(ArgData, ArgSize);
    SPSSerializableExpected<T> Wire;
    if (SPSArgList<SPSExpected<SPSTagT>>::deserialize(In, Wire)) {
      E = fromSPSSerializable(std::move(Wire));
      return Error::success();
    }
    return make_error<StringError>(
        "Error deserializing return value from blob in call",
        inconvertibleErrorCode());
  }
};
// Empty primary template with no specializations in this header.
// NOTE(review): nothing in this header instantiates it; it may be referenced
// elsewhere — confirm before removing.
template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
};
}
/// Primary template; only the function-type specializations are defined.
template <typename SPSSignature> class WrapperFunction;

/// Utilities for calling and implementing wrapper functions whose arguments
/// are SPS-serialized with tags \p SPSTagTs and whose result is serialized
/// with tag \p SPSRetTagT.
template <typename SPSRetTagT, typename... SPSTagTs>
class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
private:
  template <typename RetT>
  using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;

public:
  /// Synchronously call a wrapper function.
  ///
  /// \p Args are serialized into an argument buffer, and \p Caller is invoked
  /// as Caller(ArgData, ArgSize). It must return a WrapperFunctionResult,
  /// whose bytes are then deserialized into \p Result.
  ///
  /// \returns an error if argument serialization fails, if the call returns
  /// an out-of-band error, or if the result cannot be deserialized.
  template <typename CallerFn, typename RetT, typename... ArgTs>
  static Error call(const CallerFn &Caller, RetT &Result,
                    const ArgTs &...Args) {
    // RetT may be an Error or Expected: makeSafe consumes its current state
    // so it can be overwritten by deserialize below, and so the caller is not
    // forced to check it if this call fails early.
    detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);

    auto ArgBuffer =
        detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
            Args...);
    if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
      return make_error<StringError>(ErrMsg, inconvertibleErrorCode());

    WrapperFunctionResult ResultBuffer =
        Caller(ArgBuffer.data(), ArgBuffer.size());
    if (auto ErrMsg = ResultBuffer.getOutOfBandError())
      return make_error<StringError>(ErrMsg, inconvertibleErrorCode());

    return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
        Result, ResultBuffer.data(), ResultBuffer.size());
  }

  /// Asynchronously call a wrapper function.
  ///
  /// \p SendDeserializedResult is a callable taking (Error, RetT); RetT is
  /// deduced from the type of its second parameter.
  ///
  /// \p Caller is invoked as Caller(SendSerializedResult, ArgData, ArgSize).
  /// When Caller delivers the WrapperFunctionResult, SendDeserializedResult
  /// is invoked exactly once, with either:
  ///   - an error and a placeholder value, if the call or deserialization
  ///     failed; or
  ///   - Error::success() and the deserialized value.
  ///
  /// If argument serialization fails, SendDeserializedResult is called
  /// immediately with the error and Caller is not invoked.
  template <typename AsyncCallerFn, typename SendDeserializedResultFn,
            typename... ArgTs>
  static void callAsync(AsyncCallerFn &&Caller,
                        SendDeserializedResultFn &&SendDeserializedResult,
                        const ArgTs &...Args) {
    // The result type is the second parameter of the user's callback.
    using RetT = typename std::tuple_element<
        1, typename detail::WrapperFunctionHandlerHelper<
               std::remove_reference_t<SendDeserializedResultFn>,
               ResultSerializer, SPSRetTagT>::ArgTuple>::type;

    auto ArgBuffer =
        detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
            Args...);
    if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
      SendDeserializedResult(
          make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
          detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
      return;
    }

    auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
                                    WrapperFunctionResult R) mutable {
      RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
      detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);

      if (auto *ErrMsg = R.getOutOfBandError()) {
        SDR(make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
            std::move(RetVal));
        return;
      }

      if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
              RetVal, R.data(), R.size())) {
        SDR(std::move(Err), std::move(RetVal));
        // SDR has consumed RetVal and must not be invoked a second time.
        return;
      }

      SDR(Error::success(), std::move(RetVal));
    };

    Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
  }

  /// Handle a synchronous call.
  ///
  /// Deserializes the arguments and calls \p Handler, then serializes its
  /// result.
  template <typename HandlerT>
  static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
                                      HandlerT &&Handler) {
    using WFHH =
        detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
                                             ResultSerializer, SPSTagTs...>;
    return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
  }

  /// Handle an asynchronous call.
  ///
  /// \p Handler receives a result callback followed by the deserialized
  /// arguments. \p SendResult receives the serialized WrapperFunctionResult.
  template <typename HandlerT, typename SendResultT>
  static void handleAsync(const char *ArgData, size_t ArgSize,
                          HandlerT &&Handler, SendResultT &&SendResult) {
    using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
        std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
    WFAHH::applyAsync(std::forward<HandlerT>(Handler),
                      std::forward<SendResultT>(SendResult), ArgData, ArgSize);
  }

private:
  template <typename T> static const T &makeSerializable(const T &Value) {
    return Value;
  }

  static detail::SPSSerializableError makeSerializable(Error Err) {
    return detail::toSPSSerializable(std::move(Err));
  }

  template <typename T>
  static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
    return detail::toSPSSerializable(std::move(E));
  }
};
template <typename... SPSTagTs>
class WrapperFunction<void(SPSTagTs...)>
: private WrapperFunction<SPSEmpty(SPSTagTs...)> {
public:
template <typename CallerFn, typename... ArgTs>
static Error call(const CallerFn &Caller, const ArgTs &...Args) {
SPSEmpty BE;
return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...);
}
template <typename AsyncCallerFn, typename SendDeserializedResultFn,
typename... ArgTs>
static void callAsync(AsyncCallerFn &&Caller,
SendDeserializedResultFn &&SendDeserializedResult,
const ArgTs &...Args) {
WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync(
std::forward<AsyncCallerFn>(Caller),
[SDR = std::move(SendDeserializedResult)](Error SerializeErr,
SPSEmpty E) mutable {
SDR(std::move(SerializeErr));
},
Args...);
}
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
};
/// Adapts a member-function pointer into a handler callable as
/// (ExecutorAddr ObjAddr, Args...). It converts ObjAddr to a ClassT pointer
/// and invokes the method on that object.
template <typename RetT, typename ClassT, typename... ArgTs>
class MethodWrapperHandler {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MethodWrapperHandler(MethodT Fn) : Method(Fn) {}

  RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
    ClassT *Target = ObjAddr.toPtr<ClassT *>();
    return (Target->*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  MethodT Method;
};
/// Deduce the MethodWrapperHandler type from a member-function pointer.
template <typename RetT, typename ClassT, typename... ArgTs>
MethodWrapperHandler<RetT, ClassT, ArgTs...>
makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
  using HandlerT = MethodWrapperHandler<RetT, ClassT, ArgTs...>;
  HandlerT Handler(Method);
  return Handler;
}
/// A serialized call to a wrapper function.
///
/// Holds the callee's executor address plus a pre-serialized argument buffer.
class WrapperFunctionCall {
public:
  using ArgDataBufferType = SmallVector<char, 24>;

  /// Serialize \p Args with \p SPSSerializer and build a call to \p FnAddr.
  /// \returns an error if argument serialization fails.
  template <typename SPSSerializer, typename... ArgTs>
  static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
                                              const ArgTs &...Args) {
    ArgDataBufferType ArgData;
    ArgData.resize(SPSSerializer::size(Args...));
    // The serialized size may be zero (e.g. no arguments). Taking
    // &ArgData[0] on an empty buffer indexes out of bounds, so use a null
    // pointer in that case instead.
    SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
                       ArgData.size());
    if (SPSSerializer::serialize(OB, Args...))
      return WrapperFunctionCall(FnAddr, std::move(ArgData));
    return make_error<StringError>("Cannot serialize arguments for "
                                   "WrapperFunctionCall",
                                   inconvertibleErrorCode());
  }

  WrapperFunctionCall() = default;

  /// Create a call from a callee address and already-serialized arguments.
  WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
      : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}

  /// The callee's address.
  const ExecutorAddr &getCallee() const { return FnAddr; }

  /// The serialized argument bytes.
  const ArgDataBufferType &getArgData() const { return ArgData; }

  /// True if a callee address has been set.
  explicit operator bool() const { return !!FnAddr; }

  /// Invoke the callee in-process and return its raw result.
  /// Precondition: FnAddr points to a function with the signature
  /// CWrapperFunctionResult(const char *ArgData, size_t ArgSize).
  shared::WrapperFunctionResult run() const {
    using FnTy =
        shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
    return shared::WrapperFunctionResult(
        FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
  }

  /// Run the call and deserialize its result into \p RetVal using the SPS
  /// tag \p SPSRetT.
  /// \returns an error on an out-of-band error or a deserialization failure.
  template <typename SPSRetT, typename RetT>
  std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
  runWithSPSRet(RetT &RetVal) const {
    auto WFR = run();
    if (const char *ErrMsg = WFR.getOutOfBandError())
      return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
    shared::SPSInputBuffer IB(WFR.data(), WFR.size());
    if (!shared::SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
      return make_error<StringError>("Could not deserialize result from "
                                     "serialized wrapper function call",
                                     inconvertibleErrorCode());
    return Error::success();
  }

  /// Overload for SPSRetT == void: the result is decoded as SPSEmpty.
  template <typename SPSRetT>
  std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
  runWithSPSRet() const {
    shared::SPSEmpty E;
    return runWithSPSRet<shared::SPSEmpty>(E);
  }

  /// Run a call whose result is an SPSError.
  ///
  /// Returns either the transport/deserialization error or the callee's own
  /// returned Error.
  Error runWithSPSRetErrorMerged() const {
    detail::SPSSerializableError RetErr;
    if (auto Err = runWithSPSRet<SPSError>(RetErr))
      return Err;
    return detail::fromSPSSerializable(std::move(RetErr));
  }

private:
  orc::ExecutorAddr FnAddr;
  ArgDataBufferType ArgData;
};
using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
template <>
class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
public:
static size_t size(const WrapperFunctionCall &WFC) {
return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(),
WFC.getArgData());
}
static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(),
WFC.getArgData());
}
static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
ExecutorAddr FnAddr;
WrapperFunctionCall::ArgDataBufferType ArgData;
if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
return false;
WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
return true;
}
};
} } }
#endif