#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMBasicBlockInfo.h"
#include "ARMSubtarget.h"
#include "MVETailPredUtils.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
using namespace llvm;
// Identifier for this pass; it is also handed to INITIALIZE_PASS in this file
// as the pass's registered argument name.
#define DEBUG_TYPE "arm-block-placement"
// Text prepended to every LLVM_DEBUG message emitted by this file.
#define DEBUG_PREFIX "ARM Block Placement: "
namespace llvm {
/// Machine function pass that looks for while-loop-start (WLS) instructions
/// whose target block is not laid out after the block containing the WLS.
/// Where possible the WLS block is moved in front of its target; WLS
/// instructions that cannot be fixed that way are recorded and later replaced
/// by a do-loop-start in a freshly created block.
class ARMBlockPlacement : public MachineFunctionPass {
private:
  /// Target instruction info; assigned at the start of runOnMachineFunction.
  const ARMBaseInstrInfo *TII = nullptr;
  /// Block size/offset bookkeeping, rebuilt for every function processed.
  std::unique_ptr<ARMBasicBlockUtils> BBUtils;
  /// Loop analysis result for the function currently being processed.
  MachineLoopInfo *MLI = nullptr;
  /// WLS instructions collected by fixBackwardsWLS that runOnMachineFunction
  /// hands to revertWhileToDoLoop once all loops have been visited.
  SmallVector<MachineInstr *> RevertedWhileLoops;

public:
  static char ID;
  ARMBlockPlacement() : MachineFunctionPass(ID) {}

  /// Pass entry point; returns true if the function was modified.
  bool runOnMachineFunction(MachineFunction &MF) override;
  /// Re-lays out \p BB immediately before \p Before, inserting explicit
  /// branches where a fall-through would otherwise be broken.
  void moveBasicBlock(MachineBasicBlock *BB, MachineBasicBlock *Before);
  /// Returns true if the recorded offset of \p BB is strictly smaller than
  /// that of \p Other.
  bool blockIsBefore(MachineBasicBlock *BB, MachineBasicBlock *Other);
  /// Handles the WLS (if any) guarding \p ML; returns true if a block moved.
  bool fixBackwardsWLS(MachineLoop *ML);
  /// Visits the sub-loops of \p ML recursively, then \p ML itself.
  bool processPostOrderLoops(MachineLoop *ML);
  /// Replaces \p WLS by a do-loop-start placed in a new block.
  bool revertWhileToDoLoop(MachineInstr *WLS);

  /// Declares the dependency on MachineLoopInfo.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }
};
} // namespace llvm
/// Creates a new heap-allocated ARMBlockPlacement pass instance.
FunctionPass *llvm::createARMBlockPlacementPass() {
  return new ARMBlockPlacement();
}
// Definition of the static pass ID whose address-bearing member is passed to
// the MachineFunctionPass constructor by ARMBlockPlacement().
char ARMBlockPlacement::ID = 0;
// Registers the pass under DEBUG_TYPE with the description
// "ARM block placement".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS — confirm.
INITIALIZE_PASS(ARMBlockPlacement, DEBUG_TYPE, "ARM block placement", false,
                false)
/// Scans the terminators of \p MBB in order and returns the first one for
/// which isWhileLoopStart() holds, or nullptr when there is none.
static MachineInstr *findWLSInBlock(MachineBasicBlock *MBB) {
  auto Terminators = MBB->terminators();
  for (auto TermIt = Terminators.begin(), TermEnd = Terminators.end();
       TermIt != TermEnd; ++TermIt) {
    MachineInstr &Candidate = *TermIt;
    if (isWhileLoopStart(Candidate))
      return &Candidate;
  }
  return nullptr;
}
/// Locates the while-loop-start instruction associated with \p ML.
///
/// The loop predecessor block is searched first. If it holds no WLS and has
/// exactly one predecessor of its own, that single predecessor is searched
/// instead. Returns nullptr when the loop has no loop predecessor or no WLS is
/// found in either place.
static MachineInstr *findWLS(MachineLoop *ML) {
  MachineBasicBlock *LoopPred = ML->getLoopPredecessor();
  if (!LoopPred)
    return nullptr;

  if (MachineInstr *InPred = findWLSInBlock(LoopPred))
    return InPred;

  // Only step one block further back when that block is unambiguous.
  if (LoopPred->pred_size() != 1)
    return nullptr;
  return findWLSInBlock(*LoopPred->pred_begin());
}
/// Replaces the while-loop-start \p WLS with a do-loop-start.
///
/// The block holding \p WLS must end with the WLS immediately followed by a
/// t2B. A new block is inserted right after that block, the t2B is moved into
/// it, and a t2DoLoopStart (or t2DoLoopStartTP for the tail-predicated form)
/// carrying the WLS operands is built in front of the moved branch. The WLS
/// itself is then handed to RevertWhileLoopStartLR with ARM::t2Bcc.
///
/// \returns true unconditionally, since the function is always modified.
bool ARMBlockPlacement::revertWhileToDoLoop(MachineInstr *WLS) {
  MachineBasicBlock *Preheader = WLS->getParent();
  MachineFunction *ParentFn = Preheader->getParent();
  MachineInstr *Br = &Preheader->back();

  // Expected shape: "... ; WLS ; t2B <target>" with the branch using
  // condition operand 14.
  assert(WLS != Br);
  assert(WLS->getNextNode() == Br);
  assert(Br->getOpcode() == ARM::t2B);
  assert(Br->getOperand(1).getImm() == 14);

  const bool IsTailPredicated = WLS->getOpcode() == ARM::t2WhileLoopStartTP;

  // Drop the kill flags before the operands are copied onto the new
  // instruction further down.
  WLS->getOperand(1).setIsKill(false);
  if (IsTailPredicated)
    WLS->getOperand(2).setIsKill(false);

  // New block placed directly after the preheader in layout order.
  MachineBasicBlock *NewBlock =
      ParentFn->CreateMachineBasicBlock(Preheader->getBasicBlock());
  ParentFn->insert(++Preheader->getIterator(), NewBlock);

  // Transfer the unconditional branch into the new block and rewire the CFG:
  // Preheader -> NewBlock -> (old branch target).
  MachineBasicBlock *BranchTarget = Br->getOperand(0).getMBB();
  Br->removeFromParent();
  NewBlock->insert(NewBlock->end(), Br);
  Preheader->replaceSuccessor(BranchTarget, NewBlock);
  NewBlock->addSuccessor(BranchTarget);

  // Build the do-loop-start ahead of the moved branch.
  const unsigned DoLoopOpcode =
      IsTailPredicated ? ARM::t2DoLoopStartTP : ARM::t2DoLoopStart;
  MachineInstrBuilder MIB =
      BuildMI(*NewBlock, Br, WLS->getDebugLoc(), TII->get(DoLoopOpcode));
  MIB.add(WLS->getOperand(0));
  MIB.add(WLS->getOperand(1));
  if (IsTailPredicated)
    MIB.add(WLS->getOperand(2));

  LLVM_DEBUG(dbgs() << DEBUG_PREFIX
                    << "Reverting While Loop to Do Loop: " << *WLS << "\n");

  // NOTE(review): presumably this rewrites the WLS into a compare plus
  // conditional branch — confirm against MVETailPredUtils.h.
  RevertWhileLoopStartLR(WLS, TII, ARM::t2Bcc, true);

  // The new block needs live-in information of its own.
  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *NewBlock);

  // Refresh numbering and the size/offset tables after the layout change.
  ParentFn->RenumberBlocks();
  BBUtils->computeAllBlockSizes();
  BBUtils->adjustBBOffsetsAfter(Preheader);
  return true;
}
/// Checks whether the WLS guarding \p ML targets a block that is not laid out
/// after the WLS block and, if so, tries to fix it.
///
/// When the target (the loop exit) is not after the WLS block, the WLS block
/// is moved to just before the target. The move is abandoned if a block lying
/// between the target and the WLS block contains a WLS that targets the WLS
/// block; in that case the original WLS is queued in RevertedWhileLoops.
///
/// \returns true only if a block was moved.
bool ARMBlockPlacement::fixBackwardsWLS(MachineLoop *ML) {
  MachineInstr *WlsInstr = findWLS(ML);
  if (!WlsInstr)
    return false;

  MachineBasicBlock *Predecessor = WlsInstr->getParent();
  MachineBasicBlock *LoopExit = getWhileLoopStartTargetBB(*WlsInstr);

  // A target with no previous block is the first block of the function;
  // nothing may be placed in front of it.
  if (!LoopExit->getPrevNode())
    return false;
  // Nothing to do when the WLS already branches forwards.
  if (blockIsBefore(Predecessor, LoopExit))
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Found a backwards WLS from "
                    << Predecessor->getFullName() << " to "
                    << LoopExit->getFullName() << "\n");

  // Walk the blocks strictly between LoopExit and Predecessor in layout
  // order, looking for a WLS that targets Predecessor.
  const auto StopAt = Predecessor->getIterator();
  auto Cursor = LoopExit->getIterator();
  while (++Cursor != StopAt) {
    for (MachineInstr &Term : Cursor->terminators()) {
      if (isWhileLoopStart(Term) &&
          getWhileLoopStartTargetBB(Term) == Predecessor) {
        LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Can't move Predecessor block as "
                          << "it would convert a WLS from forward to a "
                          << "backwards branching WLS\n");
        RevertedWhileLoops.push_back(WlsInstr);
        return false;
      }
    }
  }

  moveBasicBlock(Predecessor, LoopExit);
  return true;
}
/// Processes every loop nested inside \p ML before \p ML itself.
/// \returns true if any visited loop caused a change.
bool ARMBlockPlacement::processPostOrderLoops(MachineLoop *ML) {
  bool AnyChange = false;
  for (MachineLoop *SubLoop : *ML)
    AnyChange |= processPostOrderLoops(SubLoop);
  // Always run on ML itself, regardless of what the sub-loops reported.
  AnyChange |= fixBackwardsWLS(ML);
  return AnyChange;
}
/// Pass entry point.
///
/// Does nothing for functions that skipFunction() rejects or for subtargets
/// without hasLOB(). Otherwise the block offset tables are initialised, every
/// top-level loop (and, recursively, its sub-loops) is processed by
/// processPostOrderLoops, and finally each WLS queued in RevertedWhileLoops is
/// converted by revertWhileToDoLoop.
///
/// \returns true if the function was modified.
bool ARMBlockPlacement::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;
  const ARMSubtarget &ST = MF.getSubtarget<ARMSubtarget>();
  if (!ST.hasLOB())
    return false;
  LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Running on " << MF.getName() << "\n");

  MLI = &getAnalysis<MachineLoopInfo>();
  TII = static_cast<const ARMBaseInstrInfo *>(ST.getInstrInfo());
  BBUtils = std::make_unique<ARMBasicBlockUtils>(MF);

  // Bring block numbers and the size/offset tables up to date before any
  // offset comparison is made.
  MF.RenumberBlocks();
  BBUtils->computeAllBlockSizes();
  BBUtils->adjustBBOffsetsAfter(&MF.front());

  bool Changed = false;
  // The list is a member, so discard entries left over from a previous run.
  RevertedWhileLoops.clear();

  // First try to fix backwards WLS branches by moving blocks.
  for (auto *ML : *MLI)
    Changed |= processPostOrderLoops(ML);

  // Then convert every WLS that could not be fixed by a move.
  for (auto *WlsInstr : RevertedWhileLoops)
    Changed |= revertWhileToDoLoop(WlsInstr);

  return Changed;
}
/// Returns true if the offset BBUtils reports for \p BB is strictly smaller
/// than the one it reports for \p Other; equal offsets yield false.
bool ARMBlockPlacement::blockIsBefore(MachineBasicBlock *BB,
                                      MachineBasicBlock *Other) {
  const auto OtherOffset = BBUtils->getOffsetOf(Other);
  const auto BBOffset = BBUtils->getOffsetOf(BB);
  return BBOffset < OtherOffset;
}
/// Moves \p BB so that it is laid out immediately before \p Before without
/// changing control flow.
///
/// Only the layout changes; successor lists are left untouched. Any edge that
/// may have relied on falling through between formerly adjacent blocks is
/// made explicit with an unconditional t2B. Block numbers and the size/offset
/// tables are refreshed afterwards.
///
/// \param BB     Block to move; must have a previous block in layout order.
/// \param Before Block that \p BB is placed in front of; must also have a
///               previous block in layout order.
void ARMBlockPlacement::moveBasicBlock(MachineBasicBlock *BB,
                                       MachineBasicBlock *Before) {
  LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Moving " << BB->getName() << " before "
                    << Before->getName() << "\n");
  // Record the layout neighbours before the move; they identify the
  // fall-through edges that may need repairing afterwards.
  MachineBasicBlock *BBPrevious = BB->getPrevNode();
  assert(BBPrevious && "Cannot move the function entry basic block");
  MachineBasicBlock *BBNext = BB->getNextNode();
  MachineBasicBlock *BeforePrev = Before->getPrevNode();
  assert(BeforePrev &&
         "Cannot move the given block to before the function entry block");
  MachineFunction *F = BB->getParent();
  BB->moveBefore(Before);

  // Appends "t2B To" to From unless From already ends in an unpredicated
  // unconditional/indirect/jump-table branch or a return.
  auto FixFallthrough = [&](MachineBasicBlock *From, MachineBasicBlock *To) {
    LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Checking for fallthrough from "
                      << From->getName() << " to " << To->getName() << "\n");
    assert(From->isSuccessor(To) &&
           "'To' is expected to be a successor of 'From'");
    // An empty block has no last instruction to inspect: stepping back from
    // end() would be invalid. Such a block can only reach To by falling
    // through, so it always needs the explicit branch (with no debug
    // location to inherit).
    DebugLoc DL;
    if (!From->empty()) {
      MachineInstr &Terminator = *(--From->terminators().end());
      if (!TII->isPredicated(Terminator) &&
          (isUncondBranchOpcode(Terminator.getOpcode()) ||
           isIndirectBranchOpcode(Terminator.getOpcode()) ||
           isJumpTableBranchOpcode(Terminator.getOpcode()) ||
           Terminator.isReturn()))
        return;
      DL = Terminator.getDebugLoc();
    }
    // Control could fall off the end of From, so branch to To explicitly.
    MachineInstrBuilder MIB = BuildMI(From, DL, TII->get(ARM::t2B));
    MIB.addMBB(To);
    MIB.addImm(ARMCC::CondCodes::AL);
    MIB.addReg(ARM::NoRegister);
    LLVM_DEBUG(dbgs() << DEBUG_PREFIX << "Adding unconditional branch from "
                      << From->getName() << " to " << To->getName() << ": "
                      << *MIB.getInstr());
  };

  // Edge from the block that used to precede BB.
  if (BBPrevious->isSuccessor(BB))
    FixFallthrough(BBPrevious, BB);
  // Edge from the block that used to precede Before; BB now sits in between.
  if (BeforePrev->isSuccessor(Before))
    FixFallthrough(BeforePrev, Before);
  // Edge from BB to the block that used to follow it.
  if (BBNext && BB->isSuccessor(BBNext))
    FixFallthrough(BB, BBNext);

  F->RenumberBlocks();
  BBUtils->computeAllBlockSizes();
  // NOTE(review): offsets are refreshed starting at BB, yet BeforePrev (laid
  // out before BB) may just have grown by a branch. Confirm that
  // adjustBBOffsetsAfter(BB) leaves BB's own offset correct in that case.
  BBUtils->adjustBBOffsetsAfter(BB);
}