#ifndef LLVM_ADT_TYPESWITCH_H
#define LLVM_ADT_TYPESWITCH_H
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
namespace llvm {
namespace detail {
template <typename DerivedT, typename T> class TypeSwitchBase {
public:
TypeSwitchBase(const T &value) : value(value) {}
TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
~TypeSwitchBase() = default;
TypeSwitchBase(const TypeSwitchBase &) = delete;
void operator=(const TypeSwitchBase &) = delete;
void operator=(TypeSwitchBase &&other) = delete;
template <typename CaseT, typename CaseT2, typename... CaseTs,
typename CallableT>
LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT &
Case(CallableT &&caseFn) {
DerivedT &derived = static_cast<DerivedT &>(*this);
return derived.template Case<CaseT>(caseFn)
.template Case<CaseT2, CaseTs...>(caseFn);
}
template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
using Traits = function_traits<std::decay_t<CallableT>>;
using CaseT = std::remove_cv_t<std::remove_pointer_t<
std::remove_reference_t<typename Traits::template arg_t<0>>>>;
DerivedT &derived = static_cast<DerivedT &>(*this);
return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
}
protected:
template <typename ValueT, typename CastT>
using has_dyn_cast_t =
decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
template <typename CastT, typename ValueT>
static auto castValue(
ValueT value,
typename std::enable_if_t<
is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
return value.template dyn_cast<CastT>();
}
template <typename CastT, typename ValueT>
static auto castValue(
ValueT value,
typename std::enable_if_t<
!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
return dyn_cast<CastT>(value);
}
const T value;
};
}
/// Switch on the dynamic type of a value and produce a `ResultT` from the
/// first case that matches.
///
/// Cases are tried in order. Once one matches, its callable's result is
/// stored and every later case is skipped. The result is then obtained from
/// `Default(...)` or from the implicit conversion to `ResultT`.
template <typename T, typename ResultT = void>
class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
public:
  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
  using BaseT::BaseT;
  using BaseT::Case;
  TypeSwitch(TypeSwitch &&other) = default;

  /// Add a case for `CaseT`. If no earlier case matched and the value casts
  /// to `CaseT`, invoke `caseFn` on the cast value and record its result.
  template <typename CaseT, typename CallableT>
  TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
    if (result)
      return *this;

    // `result` is known to be empty here, so construct the value in place.
    // Assigning through Optional's operator= would also require ResultT to
    // be assignable, which a freshly produced result never needs.
    if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
      result.emplace(caseFn(caseValue));
    return *this;
  }

  /// Return the matched result. If no case matched, return
  /// `defaultFn(value)` instead.
  ///
  /// The extra template parameter makes this overload viable only when
  /// `defaultFn` can actually be called with the stored value. Without it,
  /// an argument that merely converts to ResultT (e.g. `nullptr` for a
  /// pointer ResultT, or a string literal for a string-like ResultT) would
  /// bind here as an exact match. It would then fail to compile instead of
  /// selecting the `Default(ResultT)` overload below.
  template <typename CallableT,
            typename = decltype(std::declval<CallableT &>()(
                std::declval<const T &>()))>
  LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
    if (result)
      return std::move(*result);
    return defaultFn(this->value);
  }

  /// Return the matched result, or `defaultResult` if no case matched.
  LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
    if (result)
      return std::move(*result);
    return defaultResult;
  }

  /// Return the matched result. It asserts that some case matched.
  LLVM_NODISCARD
  operator ResultT() {
    assert(result && "Fell off the end of a type-switch");
    return std::move(*result);
  }

private:
  /// Holds the result of the first matching case; empty until a case
  /// matches.
  Optional<ResultT> result;
};
/// Specialization for switches whose cases return nothing.
///
/// Cases run purely for their side effects. At most one of them runs: the
/// first case that matches. If none matches, the default runs instead.
template <typename T>
class TypeSwitch<T, void>
    : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
public:
  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
  using BaseT::BaseT;
  using BaseT::Case;
  TypeSwitch(TypeSwitch &&other) = default;

  /// Invoke `caseFn` on the value cast to `CaseT`. This only happens when
  /// no earlier case has matched and the cast succeeds.
  template <typename CaseT, typename CallableT>
  TypeSwitch<T, void> &Case(CallableT &&caseFn) {
    if (!foundMatch) {
      if (auto matched = BaseT::template castValue<CaseT>(this->value)) {
        caseFn(matched);
        foundMatch = true;
      }
    }
    return *this;
  }

  /// Invoke `defaultFn` on the value when none of the cases matched.
  template <typename CallableT> void Default(CallableT &&defaultFn) {
    if (foundMatch)
      return;
    defaultFn(this->value);
  }

private:
  /// Set after a case has run, so that later cases and the default are
  /// skipped.
  bool foundMatch = false;
};
}
#endif