#include "AArch64InstrInfo.h"
#include "AArch64Subtarget.h"
#include "MCTargetDesc/AArch64InstPrinter.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream>
using namespace llvm;
// Human-readable name of the pass. It is handed to INITIALIZE_PASS and is
// also what getPassName() returns.
#define AARCH64_LOWER_HOMOGENEOUS_PROLOG_EPILOG_NAME                           \
  "AArch64 homogeneous prolog/epilog lowering pass"
// Hidden command-line option (-frame-helper-size-threshold, default 2).
// shouldUseFrameHelper() only agrees to call an outlined frame helper when the
// number of instructions the helper would replace is at least this value.
cl::opt<int> FrameHelperSizeThreshold(
    "frame-helper-size-threshold", cl::init(2), cl::Hidden,
    cl::desc("The minimum number of instructions that are outlined in a frame "
             "helper (default = 2)"));
namespace {
/// Implementation of the lowering, independent of the pass wrapper.
///
/// Walks every function of a module that has a MachineFunction and replaces
/// the HOM_Prolog / HOM_Epilog pseudo instructions either with a call to an
/// outlined frame helper or with the equivalent inline store/load pairs.
class AArch64LowerHomogeneousPE {
public:
  /// Instruction info of the machine function currently being processed.
  /// It is assigned by runOnMachineFunction() before any lowering happens;
  /// it starts out null so that it never holds an indeterminate value.
  const AArch64InstrInfo *TII = nullptr;

  AArch64LowerHomogeneousPE(Module *M, MachineModuleInfo *MMI)
      : M(M), MMI(MMI) {}

  /// Process every non-empty function of the module that has a
  /// MachineFunction. Returns true if any machine function was modified.
  bool run();

  /// Process all basic blocks of \p Fn. Returns true if \p Fn was modified.
  bool runOnMachineFunction(MachineFunction &Fn);

private:
  Module *M;              ///< Module being processed; helpers are added to it.
  MachineModuleInfo *MMI; ///< Source of the per-function MachineFunctions.

  /// Visit every instruction of \p MBB. Returns true on any change.
  bool runOnMBB(MachineBasicBlock &MBB);

  /// Dispatch on the opcode of the instruction at \p MBBI. \p NextMBBI is the
  /// instruction the caller will visit next and may be advanced by the
  /// lowering routines.
  bool runOnMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
               MachineBasicBlock::iterator &NextMBBI);

  /// Lower a HOM_Prolog pseudo at \p MBBI.
  bool lowerProlog(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                   MachineBasicBlock::iterator &NextMBBI);

  /// Lower a HOM_Epilog pseudo at \p MBBI.
  bool lowerEpilog(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                   MachineBasicBlock::iterator &NextMBBI);
};
/// Module pass wrapper around AArch64LowerHomogeneousPE. It obtains the
/// MachineModuleInfo through the pass manager and forwards the actual work to
/// the implementation class in runOnModule().
class AArch64LowerHomogeneousPrologEpilog : public ModulePass {
public:
  // Pass identifier handed to the ModulePass base class constructor.
  static char ID;
  // Registers the pass with the global pass registry on construction.
  AArch64LowerHomogeneousPrologEpilog() : ModulePass(ID) {
    initializeAArch64LowerHomogeneousPrologEpilogPass(
        *PassRegistry::getPassRegistry());
  }
  // Requires MachineModuleInfoWrapperPass (needed to look up and create
  // MachineFunctions) and declares all analyses as preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MachineModuleInfoWrapperPass>();
    AU.addPreserved<MachineModuleInfoWrapperPass>();
    AU.setPreservesAll();
    ModulePass::getAnalysisUsage(AU);
  }
  // Entry point; defined out of line below the class.
  bool runOnModule(Module &M) override;
  // Returns the human-readable pass name defined by the macro above.
  StringRef getPassName() const override {
    return AARCH64_LOWER_HOMOGENEOUS_PROLOG_EPILOG_NAME;
  }
};
} 
// Storage for the pass identifier declared in the class.
char AArch64LowerHomogeneousPrologEpilog::ID = 0;
// Register the pass under the command-line name
// "aarch64-lower-homogeneous-prolog-epilog" with the descriptive name above.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS — confirm against
// PassSupport.h.
INITIALIZE_PASS(AArch64LowerHomogeneousPrologEpilog,
                "aarch64-lower-homogeneous-prolog-epilog",
                AARCH64_LOWER_HOMOGENEOUS_PROLOG_EPILOG_NAME, false, false)
/// Pass entry point: unless the module is to be skipped, fetch the
/// MachineModuleInfo and let AArch64LowerHomogeneousPE do the lowering.
/// Returns true if any machine function in \p M was changed.
bool AArch64LowerHomogeneousPrologEpilog::runOnModule(Module &M) {
  if (skipModule(M))
    return false;

  MachineModuleInfo &ModuleInfo =
      getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
  AArch64LowerHomogeneousPE Lowering(&M, &ModuleInfo);
  return Lowering.run();
}
bool AArch64LowerHomogeneousPE::run() {
  bool Changed = false;
  for (auto &F : *M) {
    if (F.empty())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(F);
    if (!MF)
      continue;
    Changed |= runOnMachineFunction(*MF);
  }
  return Changed;
}
/// Kinds of outlined frame helpers this pass can create and call.
enum FrameHelperType {
  /// Helper that stores the callee-saved register pairs.
  Prolog,
  /// Like Prolog, and additionally sets FP to SP plus an offset.
  PrologFrame,
  /// Helper that reloads the register pairs and returns through X16.
  Epilog,
  /// Helper that reloads the register pairs and returns through LR.
  EpilogTail
};
/// Compose the symbol name of a frame helper.
///
/// The name is a prefix selected by \p Type (the PrologFrame prefix also
/// embeds \p FpOffset) followed by the printed names of all registers in
/// \p Regs, in order. Helpers with the same kind, offset and register list
/// therefore share one name.
static std::string getFrameHelperName(SmallVectorImpl<unsigned> &Regs,
                                      FrameHelperType Type, unsigned FpOffset) {
  std::ostringstream Name;

  if (Type == FrameHelperType::Prolog)
    Name << "OUTLINED_FUNCTION_PROLOG_";
  else if (Type == FrameHelperType::PrologFrame)
    Name << "OUTLINED_FUNCTION_PROLOG_FRAME" << FpOffset << "_";
  else if (Type == FrameHelperType::Epilog)
    Name << "OUTLINED_FUNCTION_EPILOG_";
  else if (Type == FrameHelperType::EpilogTail)
    Name << "OUTLINED_FUNCTION_EPILOG_TAIL_";

  for (unsigned Reg : Regs)
    Name << AArch64InstPrinter::getRegisterName(Reg);

  return Name.str();
}
/// Create a new IR function called \p Name in \p M together with its
/// MachineFunction, and return the MachineFunction.
///
/// The IR function is `void()` with linkonce_odr linkage, a global
/// unnamed_addr, and the optnone/noinline/minsize/naked attributes; its IR
/// body is a single "entry" block that returns. The MachineFunction gets one
/// empty machine basic block into which the caller emits the helper body.
/// A function named \p Name must not exist yet.
static MachineFunction &createFrameHelperMachineFunction(Module *M,
                                                         MachineModuleInfo *MMI,
                                                         StringRef Name) {
  LLVMContext &Ctx = M->getContext();
  assert(M->getFunction(Name) == nullptr && "Function has been created before");

  FunctionType *HelperTy = FunctionType::get(Type::getVoidTy(Ctx), false);
  Function *F =
      Function::Create(HelperTy, Function::ExternalLinkage, Name, M);
  assert(F && "Function was null!");

  // ODR linkage so identical helpers from several modules can be merged.
  F->setLinkage(GlobalValue::LinkOnceODRLinkage);
  F->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

  // The body is hand-built machine code; attributes are added in this order.
  for (Attribute::AttrKind Kind : {Attribute::OptimizeNone, Attribute::NoInline,
                                   Attribute::MinSize, Attribute::Naked})
    F->addFnAttr(Kind);

  MachineFunction &MF = MMI->getOrCreateMachineFunction(*F);

  // The helper has no liveness tracking, is not in SSA form and uses no
  // virtual registers.
  auto &Props = MF.getProperties();
  Props.reset(MachineFunctionProperties::Property::TracksLiveness);
  Props.reset(MachineFunctionProperties::Property::IsSSA);
  Props.set(MachineFunctionProperties::Property::NoVRegs);
  MF.getRegInfo().freezeReservedRegs(MF);

  // Give the IR function a trivial body: an entry block that just returns.
  IRBuilder<> Builder(BasicBlock::Create(Ctx, "entry", F));
  Builder.CreateRetVoid();

  // Add the (still empty) machine basic block the caller will fill in.
  MachineBasicBlock *EntryMBB = MF.CreateMachineBasicBlock();
  MF.insert(MF.begin(), EntryMBB);
  return MF;
}
/// Insert a store-pair of \p Reg2 and \p Reg1 relative to SP before \p Pos.
///
/// The D-register opcode is used when \p Reg1 is in FPR64, otherwise the
/// X-register opcode; both registers must be of the same kind. With
/// \p IsPreDec the pre-indexed form is emitted, which also defines SP.
/// \p Offset is passed through unchanged as the instruction's immediate, and
/// the instruction is flagged FrameSetup. \p MF is not used.
static void emitStore(MachineFunction &MF, MachineBasicBlock &MBB,
                      MachineBasicBlock::iterator Pos,
                      const TargetInstrInfo &TII, unsigned Reg1, unsigned Reg2,
                      int Offset, bool IsPreDec) {
  const bool IsFPPair = AArch64::FPR64RegClass.contains(Reg1);
  assert(IsFPPair == AArch64::FPR64RegClass.contains(Reg2));

  const unsigned Opcode =
      IsPreDec ? (IsFPPair ? AArch64::STPDpre : AArch64::STPXpre)
               : (IsFPPair ? AArch64::STPDi : AArch64::STPXi);

  MachineInstrBuilder Store = BuildMI(MBB, Pos, DebugLoc(), TII.get(Opcode));
  // The pre-indexed form writes the updated stack pointer back.
  if (IsPreDec)
    Store.addDef(AArch64::SP);
  Store.addReg(Reg2);
  Store.addReg(Reg1);
  Store.addReg(AArch64::SP);
  Store.addImm(Offset);
  Store.setMIFlag(MachineInstr::FrameSetup);
}
/// Insert a load-pair into \p Reg2 and \p Reg1 relative to SP before \p Pos.
///
/// The D-register opcode is used when \p Reg1 is in FPR64, otherwise the
/// X-register opcode; both registers must be of the same kind. With
/// \p IsPostDec the post-indexed form is emitted, which also defines SP.
/// \p Offset is passed through unchanged as the instruction's immediate, and
/// the instruction is flagged FrameDestroy. \p MF is not used.
static void emitLoad(MachineFunction &MF, MachineBasicBlock &MBB,
                     MachineBasicBlock::iterator Pos,
                     const TargetInstrInfo &TII, unsigned Reg1, unsigned Reg2,
                     int Offset, bool IsPostDec) {
  const bool IsFPPair = AArch64::FPR64RegClass.contains(Reg1);
  assert(IsFPPair == AArch64::FPR64RegClass.contains(Reg2));

  const unsigned Opcode =
      IsPostDec ? (IsFPPair ? AArch64::LDPDpost : AArch64::LDPXpost)
                : (IsFPPair ? AArch64::LDPDi : AArch64::LDPXi);

  MachineInstrBuilder Load = BuildMI(MBB, Pos, DebugLoc(), TII.get(Opcode));
  // The post-indexed form writes the updated stack pointer back.
  if (IsPostDec)
    Load.addDef(AArch64::SP);
  Load.addReg(Reg2, getDefRegState(true));
  Load.addReg(Reg1, getDefRegState(true));
  Load.addReg(AArch64::SP);
  Load.addImm(Offset);
  Load.setMIFlag(MachineInstr::FrameDestroy);
}
/// Return the frame helper function for the given register list and kind,
/// creating and populating it on first use.
///
/// The helper is looked up by the name getFrameHelperName() derives from
/// \p Regs, \p Type and \p FpOffset. If it does not exist yet, a new function
/// is created and filled with:
///  - Prolog / PrologFrame: store-pairs for every register pair except the one
///    whose first register is LR (the call site stores that pair itself),
///    optionally `FP = SP + FpOffset` for PrologFrame, and a return via LR.
///  - Epilog / EpilogTail: load-pairs for all register pairs, the last one
///    post-indexed. Epilog first copies LR into X16 and returns through X16;
///    EpilogTail returns through the reloaded LR.
///
/// \p Regs must contain at least two registers.
static Function *getOrCreateFrameHelper(Module *M, MachineModuleInfo *MMI,
                                        SmallVectorImpl<unsigned> &Regs,
                                        FrameHelperType Type,
                                        unsigned FpOffset = 0) {
  assert(Regs.size() >= 2);
  const std::string HelperName = getFrameHelperName(Regs, Type, FpOffset);
  if (Function *Existing = M->getFunction(HelperName))
    return Existing;

  MachineFunction &HelperMF =
      createFrameHelperMachineFunction(M, MMI, HelperName);
  MachineBasicBlock &Body = *HelperMF.begin();
  const TargetInstrInfo &TII = *HelperMF.getSubtarget().getInstrInfo();
  const int NumRegs = static_cast<int>(Regs.size());

  if (Type == FrameHelperType::Prolog ||
      Type == FrameHelperType::PrologFrame) {
    // Position of LR in the register list.
    const auto LRPos = std::distance(
        Regs.begin(), std::find(Regs.begin(), Regs.end(), AArch64::LR));

    // When the last pair is not the LR pair, the helper itself has to store
    // that pair with a pre-indexed store that adjusts SP further.
    if (LRPos != NumRegs - 2) {
      assert(Regs[NumRegs - 2] != AArch64::LR);
      emitStore(HelperMF, Body, Body.end(), TII, Regs[NumRegs - 2],
                Regs[NumRegs - 1], LRPos - NumRegs + 2, /*IsPreDec=*/true);
    }

    // Store the remaining pairs from the back of the list to the front.
    for (int Hi = NumRegs - 3; Hi >= 0; Hi -= 2) {
      const int Lo = Hi - 1;
      // The pair starting with LR is skipped here.
      if (Regs[Lo] != AArch64::LR)
        emitStore(HelperMF, Body, Body.end(), TII, Regs[Lo], Regs[Hi],
                  NumRegs - Hi - 1, /*IsPreDec=*/false);
    }

    // PrologFrame additionally establishes the frame pointer.
    if (Type == FrameHelperType::PrologFrame)
      BuildMI(Body, Body.end(), DebugLoc(), TII.get(AArch64::ADDXri))
          .addDef(AArch64::FP)
          .addUse(AArch64::SP)
          .addImm(FpOffset)
          .addImm(0)
          .setMIFlag(MachineInstr::FrameSetup);

    BuildMI(Body, Body.end(), DebugLoc(), TII.get(AArch64::RET))
        .addReg(AArch64::LR);
  } else if (Type == FrameHelperType::Epilog ||
             Type == FrameHelperType::EpilogTail) {
    const bool ReturnsViaX16 = Type == FrameHelperType::Epilog;

    // The loads below overwrite LR, so keep the helper's own return address
    // in X16 when the helper has to return to its caller.
    if (ReturnsViaX16)
      BuildMI(Body, Body.end(), DebugLoc(), TII.get(AArch64::ORRXrs))
          .addDef(AArch64::X16)
          .addReg(AArch64::XZR)
          .addUse(AArch64::LR)
          .addImm(0);

    // Reload every pair but the last with plain offset loads.
    for (int Idx = 0; Idx + 2 < NumRegs; Idx += 2)
      emitLoad(HelperMF, Body, Body.end(), TII, Regs[Idx], Regs[Idx + 1],
               NumRegs - Idx - 2, /*IsPostDec=*/false);

    // The last pair is reloaded with the post-indexed form that updates SP.
    emitLoad(HelperMF, Body, Body.end(), TII, Regs[NumRegs - 2],
             Regs[NumRegs - 1], NumRegs, /*IsPostDec=*/true);

    BuildMI(Body, Body.end(), DebugLoc(), TII.get(AArch64::RET))
        .addReg(ReturnsViaX16 ? AArch64::X16 : AArch64::LR);
  }

  return M->getFunction(HelperName);
}
/// Decide whether a frame helper of kind \p Type should be called instead of
/// emitting the stores/loads for \p Regs inline.
///
/// A helper is never used when LR is not among \p Regs. Otherwise the number
/// of instructions the helper would replace is estimated and compared with
/// FrameHelperSizeThreshold:
///  - Prolog: one less than the number of pairs.
///  - PrologFrame: the number of pairs.
///  - Epilog: the number of pairs; rejected when an instruction from
///    \p NextMBBI to the end of \p MBB reads W16, or W16/X16 is live into a
///    successor, because the Epilog helper writes X16.
///  - EpilogTail: one more than the number of pairs; only possible when
///    \p NextMBBI is a RET_ReallyLR.
static bool shouldUseFrameHelper(MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator &NextMBBI,
                                 SmallVectorImpl<unsigned> &Regs,
                                 FrameHelperType Type) {
  const auto *TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
  const auto NumRegs = Regs.size();
  assert(NumRegs > 0 && (NumRegs % 2 == 0));

  // Without LR in the list no helper call is used.
  if (!llvm::is_contained(Regs, AArch64::LR))
    return false;

  // One store/load instruction per register pair.
  int OutlinedInsts = NumRegs / 2;

  switch (Type) {
  case FrameHelperType::Prolog:
    // One pair is stored at the call site rather than in the helper.
    --OutlinedInsts;
    break;

  case FrameHelperType::PrologFrame:
    // Count left unchanged.
    break;

  case FrameHelperType::Epilog: {
    // X16 is written by the helper, so it must not be read afterwards ...
    for (auto It = NextMBBI, End = MBB.end(); It != End; ++It)
      if (It->readsRegister(AArch64::W16, TRI))
        return false;
    // ... nor be live into any successor block.
    for (const MachineBasicBlock *Succ : MBB.successors())
      if (Succ->isLiveIn(AArch64::W16) || Succ->isLiveIn(AArch64::X16))
        return false;
    break;
  }

  case FrameHelperType::EpilogTail:
    // Only applicable when the epilog is directly followed by the return,
    // which is then folded into the helper as well.
    if (NextMBBI == MBB.end() ||
        NextMBBI->getOpcode() != AArch64::RET_ReallyLR)
      return false;
    ++OutlinedInsts;
    break;
  }

  return OutlinedInsts >= FrameHelperSizeThreshold;
}
/// Lower the HOM_Epilog pseudo at \p MBBI.
///
/// The register operands of the pseudo form the list of register pairs to
/// restore. Depending on shouldUseFrameHelper() the pseudo is replaced by:
///  1. a TCRETURNdi to the EpilogTail helper, which also absorbs the return
///     instruction that follows (\p NextMBBI is advanced past it), or
///  2. a BL to the Epilog helper, or
///  3. inline load-pairs, the last one post-indexed.
///
/// Returns false without changing anything if the pseudo has no register
/// operands, true otherwise. The replaced instructions are erased, not merely
/// unlinked, so no orphaned MachineInstr is left behind.
bool AArch64LowerHomogeneousPE::lowerEpilog(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    MachineBasicBlock::iterator &NextMBBI) {
  auto &MF = *MBB.getParent();
  MachineInstr &MI = *MBBI;

  DebugLoc DL = MI.getDebugLoc();
  SmallVector<unsigned, 8> Regs;
  for (auto &MO : MI.operands())
    if (MO.isReg())
      Regs.push_back(MO.getReg());
  int Size = (int)Regs.size();
  if (Size == 0)
    return false;
  // Registers come in pairs.
  assert(Size % 2 == 0);
  assert(MI.getOpcode() == AArch64::HOM_Epilog);

  auto Return = NextMBBI;
  if (shouldUseFrameHelper(MBB, NextMBBI, Regs, FrameHelperType::EpilogTail)) {
    // The block ends with a return: jump to the helper and let it return.
    auto *EpilogTailHelper =
        getOrCreateFrameHelper(M, MMI, Regs, FrameHelperType::EpilogTail);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::TCRETURNdi))
        .addGlobalAddress(EpilogTailHelper)
        .addImm(0)
        .setMIFlag(MachineInstr::FrameDestroy)
        .copyImplicitOps(MI)
        .copyImplicitOps(*Return);
    // Step over the return before deleting it so the caller's iteration
    // continues with a valid iterator.
    NextMBBI = std::next(Return);
    Return->eraseFromParent();
  } else if (shouldUseFrameHelper(MBB, NextMBBI, Regs,
                                  FrameHelperType::Epilog)) {
    // Regular helper call that returns to this function.
    auto *EpilogHelper =
        getOrCreateFrameHelper(M, MMI, Regs, FrameHelperType::Epilog);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
        .addGlobalAddress(EpilogHelper)
        .setMIFlag(MachineInstr::FrameDestroy)
        .copyImplicitOps(MI);
  } else {
    // No helper: restore all pairs inline.
    for (int I = 0; I < Size - 2; I += 2)
      emitLoad(MF, MBB, MBBI, *TII, Regs[I], Regs[I + 1], Size - I - 2, false);
    // The last pair is restored with the post-indexed form that updates SP.
    emitLoad(MF, MBB, MBBI, *TII, Regs[Size - 2], Regs[Size - 1], Size, true);
  }

  // The pseudo has been fully replaced; delete it.
  MI.eraseFromParent();
  return true;
}
/// Lower the HOM_Prolog pseudo at \p MBBI.
///
/// The register operands of the pseudo form the list of register pairs to
/// save; an immediate operand, if present, is the frame-pointer offset.
/// Depending on shouldUseFrameHelper() the pseudo is replaced by:
///  1. (with an offset) a pre-indexed store of FP/LR followed by a BL to the
///     PrologFrame helper, which also sets up FP, or
///  2. (without an offset) a pre-indexed store of FP/LR followed by a BL to
///     the Prolog helper, or
///  3. inline store-pairs, plus `FP = SP + offset` when an offset is present.
///
/// Returns false without changing anything if the pseudo has no register
/// operands, true otherwise. The replaced pseudo is erased, not merely
/// unlinked, so no orphaned MachineInstr is left behind.
bool AArch64LowerHomogeneousPE::lowerProlog(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    MachineBasicBlock::iterator &NextMBBI) {
  auto &MF = *MBB.getParent();
  MachineInstr &MI = *MBBI;

  DebugLoc DL = MI.getDebugLoc();
  SmallVector<unsigned, 8> Regs;
  // Index of LR within Regs; only meaningful on the helper paths, which are
  // taken only when LR is present.
  int LRIdx = 0;
  Optional<int> FpOffset;
  for (auto &MO : MI.operands()) {
    if (MO.isReg()) {
      if (MO.getReg() == AArch64::LR)
        LRIdx = Regs.size();
      Regs.push_back(MO.getReg());
    } else if (MO.isImm()) {
      FpOffset = MO.getImm();
    }
  }
  int Size = (int)Regs.size();
  if (Size == 0)
    return false;
  // Registers come in pairs.
  assert(Size % 2 == 0);
  assert(MI.getOpcode() == AArch64::HOM_Prolog);

  if (FpOffset &&
      shouldUseFrameHelper(MBB, NextMBBI, Regs, FrameHelperType::PrologFrame)) {
    // FP/LR are stored here, before the BL to the helper overwrites LR.
    emitStore(MF, MBB, MBBI, *TII, AArch64::LR, AArch64::FP, -LRIdx - 2, true);
    auto *PrologFrameHelper = getOrCreateFrameHelper(
        M, MMI, Regs, FrameHelperType::PrologFrame, *FpOffset);
    // The helper also sets FP from SP, which is recorded as implicit operands.
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
        .addGlobalAddress(PrologFrameHelper)
        .setMIFlag(MachineInstr::FrameSetup)
        .copyImplicitOps(MI)
        .addReg(AArch64::FP, RegState::Implicit | RegState::Define)
        .addReg(AArch64::SP, RegState::Implicit);
  } else if (!FpOffset && shouldUseFrameHelper(MBB, NextMBBI, Regs,
                                               FrameHelperType::Prolog)) {
    // FP/LR are stored here, before the BL to the helper overwrites LR.
    emitStore(MF, MBB, MBBI, *TII, AArch64::LR, AArch64::FP, -LRIdx - 2, true);
    auto *PrologHelper =
        getOrCreateFrameHelper(M, MMI, Regs, FrameHelperType::Prolog);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
        .addGlobalAddress(PrologHelper)
        .setMIFlag(MachineInstr::FrameSetup)
        .copyImplicitOps(MI);
  } else {
    // No helper: the last pair is stored with the pre-indexed form that
    // adjusts SP, the remaining pairs with plain offset stores.
    emitStore(MF, MBB, MBBI, *TII, Regs[Size - 2], Regs[Size - 1], -Size, true);
    for (int I = Size - 3; I >= 0; I -= 2)
      emitStore(MF, MBB, MBBI, *TII, Regs[I - 1], Regs[I], Size - I - 1, false);
    if (FpOffset) {
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri))
          .addDef(AArch64::FP)
          .addUse(AArch64::SP)
          .addImm(*FpOffset)
          .addImm(0)
          .setMIFlag(MachineInstr::FrameSetup);
    }
  }

  // The pseudo has been fully replaced; delete it.
  MI.eraseFromParent();
  return true;
}
/// Lower the instruction at \p MBBI if it is one of the homogeneous
/// prolog/epilog pseudos; every other instruction is left untouched.
/// Returns true if the instruction was lowered.
bool AArch64LowerHomogeneousPE::runOnMI(MachineBasicBlock &MBB,
                                        MachineBasicBlock::iterator MBBI,
                                        MachineBasicBlock::iterator &NextMBBI) {
  const unsigned Opc = MBBI->getOpcode();

  if (Opc == AArch64::HOM_Prolog)
    return lowerProlog(MBB, MBBI, NextMBBI);
  if (Opc == AArch64::HOM_Epilog)
    return lowerEpilog(MBB, MBBI, NextMBBI);

  return false;
}
/// Visit all instructions of \p MBB in order and lower the pseudos among
/// them. Returns true if the block was changed.
bool AArch64LowerHomogeneousPE::runOnMBB(MachineBasicBlock &MBB) {
  bool Changed = false;
  MachineBasicBlock::iterator End = MBB.end();
  for (MachineBasicBlock::iterator Cur = MBB.begin(); Cur != End;) {
    // Determine the successor up front: lowering removes the current
    // instruction and may move the successor further on.
    MachineBasicBlock::iterator Following = std::next(Cur);
    if (runOnMI(MBB, Cur, Following))
      Changed = true;
    Cur = Following;
  }
  return Changed;
}
/// Cache the instruction info of \p MF in TII and process each of its basic
/// blocks. Returns true if \p MF was changed.
bool AArch64LowerHomogeneousPE::runOnMachineFunction(MachineFunction &MF) {
  TII = MF.getSubtarget<AArch64Subtarget>().getInstrInfo();

  bool Changed = false;
  for (MachineBasicBlock &Block : MF) {
    if (runOnMBB(Block))
      Changed = true;
  }
  return Changed;
}
/// Factory function: returns a newly allocated instance of the pass.
/// NOTE(review): ownership presumably passes to the pass manager the result
/// is added to — confirm at the call site.
ModulePass *llvm::createAArch64LowerHomogeneousPrologEpilogPass() {
  return new AArch64LowerHomogeneousPrologEpilog();
}