#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/BranchProbability.h"
#include <vector>
namespace llvm {
class BlockFrequencyInfo;
class ConstantInt;
class FunctionLoweringInfo;
class MachineBasicBlock;
class ProfileSummaryInfo;
class TargetLowering;
class TargetMachine;
namespace SwitchCG {
/// Tags how a CaseCluster is represented. The value stored in
/// CaseCluster::Kind tells which member of the cluster's anonymous union was
/// written by the factory function that created it.
enum CaseClusterKind {
/// Kind assigned by CaseCluster::range(), which stores a MachineBasicBlock.
CC_Range,
/// Kind assigned by CaseCluster::jumpTable(), which stores a JTCasesIndex.
CC_JumpTable,
/// Kind assigned by CaseCluster::bitTests(), which stores a BTCasesIndex.
CC_BitTests
};
/// A group of switch cases described by a pair of case values (Low, High), a
/// kind tag, a kind-specific payload and a branch probability.
///
/// Instances are meant to be created through the static factories range(),
/// jumpTable() and bitTests(); each one fills in every field and writes the
/// union member that matches the Kind it sets.
struct CaseCluster {
  /// Which factory produced this cluster, and hence which union member below
  /// holds a meaningful value.
  CaseClusterKind Kind;
  /// Case value passed as \c Low to the factory (stored verbatim).
  const ConstantInt *Low;
  /// Case value passed as \c High to the factory (stored verbatim).
  const ConstantInt *High;
  union {
    /// Valid when Kind == CC_Range.
    MachineBasicBlock *MBB;
    /// Valid when Kind == CC_JumpTable. Presumably an index into
    /// SwitchLowering::JTCases -- NOTE(review): confirm against the users.
    unsigned JTCasesIndex;
    /// Valid when Kind == CC_BitTests. Presumably an index into
    /// SwitchLowering::BitTestCases -- NOTE(review): confirm against the
    /// users.
    unsigned BTCasesIndex;
  };
  /// Branch probability associated with the cluster.
  BranchProbability Prob;

  /// Creates a CC_Range cluster whose payload is the block \p MBB.
  static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
                           MachineBasicBlock *MBB, BranchProbability Prob) {
    CaseCluster Cluster;
    Cluster.Kind = CC_Range;
    Cluster.MBB = MBB;
    Cluster.Low = Low;
    Cluster.High = High;
    Cluster.Prob = Prob;
    return Cluster;
  }

  /// Creates a CC_JumpTable cluster whose payload is \p JTCasesIndex.
  static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
                               unsigned JTCasesIndex, BranchProbability Prob) {
    CaseCluster Cluster;
    Cluster.Kind = CC_JumpTable;
    Cluster.JTCasesIndex = JTCasesIndex;
    Cluster.Low = Low;
    Cluster.High = High;
    Cluster.Prob = Prob;
    return Cluster;
  }

  /// Creates a CC_BitTests cluster whose payload is \p BTCasesIndex.
  static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
                              unsigned BTCasesIndex, BranchProbability Prob) {
    CaseCluster Cluster;
    Cluster.Kind = CC_BitTests;
    Cluster.BTCasesIndex = BTCasesIndex;
    Cluster.Low = Low;
    Cluster.High = High;
    Cluster.Prob = Prob;
    return Cluster;
  }
};
/// Growable sequence of case clusters.
using CaseClusterVector = std::vector<CaseCluster>;
/// Mutable iterator into a CaseClusterVector.
using CaseClusterIt = CaseClusterVector::iterator;
/// Declared here and defined elsewhere. \p Clusters is taken by non-const
/// reference, so the function is able to modify the vector in place.
/// NOTE(review): the name suggests it sorts the clusters and merges
/// neighbouring ones into ranges -- confirm against the definition.
void sortAndRangeify(CaseClusterVector &Clusters);
/// Plain record that pairs a 64-bit mask with a block, a bit count and a
/// probability. A default-constructed object has a zero mask, a null block,
/// a zero bit count and a default-constructed probability.
struct CaseBits {
  /// Bit mask; zero by default.
  uint64_t Mask = 0;
  /// Block associated with the mask; null by default.
  MachineBasicBlock *BB = nullptr;
  /// Count associated with the mask; zero by default.
  /// NOTE(review): presumably the number of bits set in Mask -- confirm
  /// against the code that fills this in.
  unsigned Bits = 0;
  /// Probability stored alongside the mask.
  BranchProbability ExtraProb;

  CaseBits() = default;

  /// Stores each argument in the member of the corresponding name.
  CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
           BranchProbability Prob)
      : Mask{mask}, BB{bb}, Bits{bits}, ExtraProb{Prob} {}
};
/// Growable sequence of CaseBits records.
using CaseBitsVector = std::vector<CaseBits>;
/// Describes one conditional branch: a comparison (expressed either as an
/// ISD condition code or as an IR predicate), up to three operands, the two
/// possible successor blocks, the block the branch belongs to, a source
/// location and the probabilities of both outcomes.
///
/// The struct carries no tag saying which union member is active; that is
/// determined solely by which constructor was used.
struct CaseBlock {
  /// IR-level comparison description used by the second constructor.
  struct PredInfoPair {
    CmpInst::Predicate Pred;
    /// NOTE(review): presumably set when no comparison has to be emitted --
    /// confirm against the consumers of this flag.
    bool NoCmp;
  };
  union {
    /// Active when the object was built with the ISD::CondCode constructor.
    ISD::CondCode CC;
    /// Active when the object was built with the CmpInst::Predicate
    /// constructor.
    PredInfoPair PredInfo;
  };

  /// Comparison operands: left, middle and right.
  const Value *CmpLHS;
  const Value *CmpMHS;
  const Value *CmpRHS;

  /// Successors for the true and the false outcome of the comparison.
  MachineBasicBlock *TrueBB;
  MachineBasicBlock *FalseBB;

  /// Block passed as \c me to the constructor.
  MachineBasicBlock *ThisBB;

  /// Location given to the ISD::CondCode constructor; default-constructed by
  /// the other constructor.
  SDLoc DL;
  /// Location given to the CmpInst::Predicate constructor; default-constructed
  /// by the other constructor.
  DebugLoc DbgLoc;

  /// Probabilities of the true and the false outcome.
  BranchProbability TrueProb;
  BranchProbability FalseProb;

  /// Builds a case block described by an ISD condition code.
  ///
  /// Mind the argument order: the operands are passed as left, right and
  /// only then middle. Both probabilities default to "unknown".
  CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
            const Value *cmpmiddle, MachineBasicBlock *truebb,
            MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
            BranchProbability trueprob = BranchProbability::getUnknown(),
            BranchProbability falseprob = BranchProbability::getUnknown())
      : CC(cc),
        CmpLHS(cmplhs),
        CmpMHS(cmpmiddle),
        CmpRHS(cmprhs),
        TrueBB(truebb),
        FalseBB(falsebb),
        ThisBB(me),
        DL(dl),
        TrueProb(trueprob),
        FalseProb(falseprob) {}

  /// Builds a case block described by an IR predicate plus the \p nocmp flag.
  ///
  /// Mind the argument order: the operands are passed as left, right and
  /// only then middle. Both probabilities default to "unknown".
  CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
            const Value *cmprhs, const Value *cmpmiddle,
            MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
            MachineBasicBlock *me, DebugLoc dl,
            BranchProbability trueprob = BranchProbability::getUnknown(),
            BranchProbability falseprob = BranchProbability::getUnknown())
      : PredInfo({pred, nocmp}),
        CmpLHS(cmplhs),
        CmpMHS(cmpmiddle),
        CmpRHS(cmprhs),
        TrueBB(truebb),
        FalseBB(falsebb),
        ThisBB(me),
        DbgLoc(dl),
        TrueProb(trueprob),
        FalseProb(falseprob) {}
};
/// Bundles the data that identifies one jump table: a register number, a
/// jump-table index and two machine basic blocks.
struct JumpTable {
  /// Register number supplied as \c R.
  /// NOTE(review): presumably the virtual register holding the table index
  /// to jump through -- confirm against the lowering code.
  unsigned Reg;
  /// Jump-table index supplied as \c J.
  unsigned JTI;
  /// Block supplied as \c M.
  MachineBasicBlock *MBB;
  /// Block supplied as \c D (named "Default").
  MachineBasicBlock *Default;

  /// Stores the four arguments in Reg, JTI, MBB and Default respectively.
  JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
      : Reg{R}, JTI{J}, MBB{M}, Default{D} {}
};
struct JumpTableHeader {
APInt First;
APInt Last;
const Value *SValue;
MachineBasicBlock *HeaderBB;
bool Emitted;
bool FallthroughUnreachable = false;
JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
bool E = false)
: First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
Emitted(E) {}
};
/// A jump table together with its header: `first` is the JumpTableHeader,
/// `second` is the JumpTable.
using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
/// One entry of a BitTestBlock: a 64-bit mask, two machine basic blocks and
/// a probability.
struct BitTestCase {
  /// Mask supplied as \c M.
  uint64_t Mask;
  /// Block supplied as \c T.
  MachineBasicBlock *ThisBB;
  /// Block supplied as \c Tr.
  /// NOTE(review): presumably the branch target taken when the mask test
  /// succeeds -- confirm against the lowering code.
  MachineBasicBlock *TargetBB;
  /// Probability supplied as \c Prob.
  BranchProbability ExtraProb;

  /// Stores the arguments in Mask, ThisBB, TargetBB and ExtraProb
  /// respectively.
  BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
              BranchProbability Prob)
      : Mask{M}, ThisBB{T}, TargetBB{Tr}, ExtraProb{Prob} {}
};
/// List of BitTestCase entries; a SmallVector with an inline capacity of 3.
using BitTestInfo = SmallVector<BitTestCase, 3>;
struct BitTestBlock {
APInt First;
APInt Range;
const Value *SValue;
unsigned Reg;
MVT RegVT;
bool Emitted;
bool ContiguousRange;
MachineBasicBlock *Parent;
MachineBasicBlock *Default;
BitTestInfo Cases;
BranchProbability Prob;
BranchProbability DefaultProb;
bool FallthroughUnreachable = false;
BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
BitTestInfo C, BranchProbability Pr)
: First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
Cases(std::move(C)), Prob(Pr) {}
};
/// Declared here and defined elsewhere. Takes the cluster vector plus two
/// unsigned indices \p First and \p Last and returns a 64-bit value.
/// NOTE(review): presumably the size of the value range spanned by
/// Clusters[First..Last] -- confirm against the definition, including whether
/// \p Last is inclusive.
uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
unsigned Last);
/// Declared here and defined elsewhere. Takes a vector of unsigned counts
/// plus two unsigned indices \p First and \p Last and returns a 64-bit value.
/// NOTE(review): presumably the number of cases covered by clusters
/// First..Last, derived from \p TotalCases -- confirm against the definition.
uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
unsigned First, unsigned Last);
/// One pending unit of work while lowering a switch: a block, a pair of
/// iterators delimiting the clusters to handle, two optional case-value
/// bounds and a probability.
///
/// The pointer members default to null so that a default-initialized item
/// never carries indeterminate pointers. The type is still an aggregate, so
/// existing brace-initialization with all six fields keeps working.
struct SwitchWorkListItem {
  /// Block this work item belongs to; null until assigned.
  MachineBasicBlock *MBB = nullptr;
  /// Iterators delimiting the clusters covered by this item.
  /// NOTE(review): whether LastCluster is inclusive or one-past-the-end is
  /// not visible here -- confirm against the users.
  CaseClusterIt FirstCluster;
  CaseClusterIt LastCluster;
  /// Optional bounds; null when absent.
  /// NOTE(review): judging by the names, the switched value is known to be
  /// >= GE and < LT when these are set -- confirm against the users.
  const ConstantInt *GE = nullptr;
  const ConstantInt *LT = nullptr;
  /// Probability associated with the default destination.
  BranchProbability DefaultProb;
};
/// List of pending SwitchWorkListItem entries; a SmallVector with an inline
/// capacity of 4.
using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
/// Abstract base that holds the state shared by switch lowering: the lists
/// of generated case blocks, jump tables and bit-test blocks, plus the
/// target hooks needed to build them.
///
/// Usage is two-phase: construct with the FunctionLoweringInfo, then call
/// init() to supply the target lowering, target machine and data layout.
/// Subclasses must implement addSuccessorWithProb().
class SwitchLowering {
public:
  /// Binds the lowering helper to \p funcinfo. Only a reference is kept.
  SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}

  /// Records the target hooks. Only the addresses of the arguments are
  /// stored, so the referenced objects have to stay alive for as long as
  /// this object uses them.
  void init(const TargetLowering &tli, const TargetMachine &tm,
            const DataLayout &dl) {
    TLI = &tli;
    TM = &tm;
    DL = &dl;
  }

  /// Case blocks collected so far.
  std::vector<CaseBlock> SwitchCases;
  /// Jump tables (header + table) collected so far.
  std::vector<JumpTableBlock> JTCases;
  /// Bit-test blocks collected so far.
  std::vector<BitTestBlock> BitTestCases;

  /// Defined elsewhere. \p Clusters is taken by non-const reference and may
  /// therefore be rewritten.
  /// NOTE(review): presumably replaces suitable runs of clusters with
  /// jump-table clusters -- confirm against the definition.
  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
                      MachineBasicBlock *DefaultMBB,
                      ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);

  /// Defined elsewhere. Reports success through the return value and hands
  /// the resulting cluster back through \p JTCluster.
  /// NOTE(review): presumably covers Clusters[First..Last] -- confirm the
  /// index convention against the definition.
  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
                      unsigned Last, const SwitchInst *SI,
                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);

  /// Defined elsewhere. \p Clusters is taken by non-const reference and may
  /// therefore be rewritten.
  /// NOTE(review): presumably replaces suitable runs of clusters with
  /// bit-test clusters -- confirm against the definition.
  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);

  /// Defined elsewhere. Reports success through the return value and hands
  /// the resulting cluster back through \p BTCluster.
  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
                     const SwitchInst *SI, CaseCluster &BTCluster);

  /// Hook that subclasses implement to register \p Dst as a successor of
  /// \p Src with probability \p Prob (unknown when omitted).
  virtual void addSuccessorWithProb(
      MachineBasicBlock *Src, MachineBasicBlock *Dst,
      BranchProbability Prob = BranchProbability::getUnknown()) = 0;

  virtual ~SwitchLowering() = default;

private:
  // Null until init() has been called. Initializing them here turns any use
  // before init() into a deterministic null dereference instead of a read
  // through an indeterminate pointer.
  const TargetLowering *TLI = nullptr;
  const TargetMachine *TM = nullptr;
  const DataLayout *DL = nullptr;
  FunctionLoweringInfo &FuncInfo;
};
} }
#endif