#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include <queue>
#define DEBUG_TYPE "nvptx-lower-args"
using namespace llvm;
// Forward declaration of the registration hook for the NVPTXLowerArgs pass.
// NOTE(review): the definition is presumably generated by the INITIALIZE_PASS
// invocation further down in this file -- confirm against PassSupport.h.
namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}
namespace {
/// Function pass that lowers the pointer arguments of NVPTX functions.
///
/// Kernel functions and device functions are dispatched to separate helpers
/// (see runOnFunction). In both cases byval pointer arguments are handed to
/// handleByValParam; for kernels, non-byval pointer arguments may additionally
/// be passed to markPointerAsGlobal when a target machine is available and its
/// driver interface is CUDA.
class NVPTXLowerArgs : public FunctionPass {
public:
  /// Pass identifier handed to the FunctionPass base class.
  static char ID;

  /// \param TM Optional target machine. When null, the steps that need
  ///           target information are skipped by the implementation.
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}

  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  /// Entry point: routes \p F to the kernel or the device lowering routine.
  bool runOnFunction(Function &F) override;

  /// Lowering applied to functions for which isKernelFunction() holds.
  bool runOnKernelFunction(Function &F);

  /// Lowering applied to every other function.
  bool runOnDeviceFunction(Function &F);

  /// Rewrites the uses of the byval pointer argument \p Arg.
  void handleByValParam(Argument *Arg);

  /// Inserts an addrspacecast of \p Ptr to the global address space followed
  /// by a cast back to the original pointer type, and redirects the uses of
  /// \p Ptr to the result of that round trip.
  void markPointerAsGlobal(Value *Ptr);

  /// Target machine supplied at construction; may be null.
  const NVPTXTargetMachine *TM;
};
} // end anonymous namespace
// Storage for the pass identifier that the NVPTXLowerArgs constructor hands to
// the FunctionPass base class.
char NVPTXLowerArgs::ID = 1;
// Registers the pass under the argument name "nvptx-lower-args" with the
// description "Lower arguments (NVPTX)".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS -- confirm in
// PassSupport.h.
INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)
/// Re-roots \p OldUser, and transitively every instruction consuming its
/// result, on the parameter-space pointer \p Param.
///
/// Supported instructions are loads, GEPs, bitcasts and addrspacecasts whose
/// destination is the param address space; anything else is unreachable.
/// Loads are patched in place; GEPs and bitcasts are re-created in front of
/// the original; addrspacecasts are dropped in favour of the new pointer.
/// Superseded instructions are erased at the end.
///
/// \param OldUser Must be an Instruction.
/// \param Param   Replacement pointer for the operand \p OldUser consumed.
static void convertToParamAS(Value *OldUser, Value *Param) {
  Instruction *I = dyn_cast<Instruction>(OldUser);
  assert(I && "OldUser must be an instruction");

  // One pending rewrite: an instruction still rooted at the old pointer,
  // paired with the new pointer it has to be rebuilt on.
  struct IP {
    Instruction *OldInstruction;
    Value *NewParam;
  };

  // Builds the counterpart of Old on top of NewPtr. Returns Old itself when
  // the instruction was updated in place.
  auto RebuildOnNewPointer = [](Instruction *Old, Value *NewPtr) -> Value * {
    switch (Old->getOpcode()) {
    case Instruction::Load: {
      // A load only needs its pointer operand swapped.
      auto *OldLoad = cast<LoadInst>(Old);
      OldLoad->setOperand(0, NewPtr);
      return OldLoad;
    }
    case Instruction::GetElementPtr: {
      auto *OldGEP = cast<GetElementPtrInst>(Old);
      SmallVector<Value *, 4> Indices(OldGEP->indices());
      auto *Replacement = GetElementPtrInst::Create(
          OldGEP->getSourceElementType(), NewPtr, Indices, OldGEP->getName(),
          OldGEP);
      Replacement->setIsInBounds(OldGEP->isInBounds());
      return Replacement;
    }
    case Instruction::BitCast: {
      auto *OldCast = cast<BitCastInst>(Old);
      // Same pointee as before, but living in the param address space.
      auto *ParamPtrTy = PointerType::getWithSamePointeeType(
          cast<PointerType>(OldCast->getType()), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(OldCast->getOpcode(), NewPtr, ParamPtrTy,
                                 OldCast->getName(), OldCast);
    }
    case Instruction::AddrSpaceCast:
      assert(cast<AddrSpaceCastInst>(Old)->getDestAddressSpace() ==
             ADDRESS_SPACE_PARAM);
      // The cast is redundant now: its users consume the new pointer directly.
      return NewPtr;
    default:
      llvm_unreachable("Unsupported instruction");
    }
  };

  SmallVector<IP> Pending = {{I, Param}};
  SmallVector<Instruction *> Superseded;
  while (!Pending.empty()) {
    const IP Item = Pending.pop_back_val();
    Value *Rebuilt = RebuildOnNewPointer(Item.OldInstruction, Item.NewParam);

    // Updated in place: no users to re-root and nothing to erase.
    if (!Rebuilt || Rebuilt == Item.OldInstruction)
      continue;

    // The old instruction is still referenced by its users, so it can only be
    // scheduled for removal here; its users are queued for conversion first.
    for (Value *U : Item.OldInstruction->users())
      Pending.push_back({cast<Instruction>(U), Rebuilt});
    Superseded.push_back(Item.OldInstruction);
  }

  // Erase in reverse discovery order so that an instruction is removed only
  // after the later-discovered instructions that still use it.
  while (!Superseded.empty())
    Superseded.pop_back_val()->eraseFromParent();
}
/// Raises the alignment of byval argument \p Arg to the target's optimized
/// parameter alignment and propagates the gain to the loads that read it.
///
/// Nothing happens when the current alignment attribute is already at least
/// as large as the optimized one. Otherwise the attribute is replaced and
/// every load reachable from \p ArgInParamAS through bitcasts and
/// constant-offset GEPs gets its alignment raised (never lowered).
///
/// \param Arg          The byval argument whose attribute is updated.
/// \param ArgInParamAS The param-address-space view of \p Arg; its users must
///                     be loads, bitcasts or GEPs only.
/// \param TLI          Target lowering used to query the optimized alignment.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  // Borrow the module's data layout rather than constructing a private copy;
  // this matches how handleByValParam obtains it.
  const DataLayout &DL = Func->getParent()->getDataLayout();
  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();
  if (CurArgAlign >= NewArgAlign)
    return;
  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');
  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  // A load together with its constant byte offset from the argument start.
  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };
  // A pointer derived from the argument and the offset accumulated so far.
  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };
  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();
    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }
      // A bitcast does not move the pointer, so the offset carries over.
      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        Worklist.push({I, Ctx.Offset});
        continue;
      }
      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
        // With a non-constant offset nothing can be said about the alignment
        // of anything below this GEP, so the whole subtree is left untouched.
        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;
        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }
      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr.");
    }
  }
  for (Load &CurLoad : Loads) {
    // The alignment guaranteed at a given offset is the GCD of the base
    // alignment and that offset (offset 0 yields the base alignment itself).
    Align NewLoadAlign(greatestCommonDivisor(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    // Keep whichever is larger so an existing alignment is never reduced.
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}
/// Lowers the byval pointer argument \p Arg.
///
/// If every transitive use of \p Arg is a GEP, bitcast, load, or an
/// addrspacecast into the param address space, those uses are rewritten to
/// read straight from a param-space view of the argument (and, when a target
/// machine is available, the argument alignment is tuned afterwards).
/// Otherwise the argument is copied into an alloca at function entry and all
/// uses are redirected to that copy.
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  // True when Start and everything reachable through its users (stopping at
  // loads) consists solely of instructions that can be re-rooted on a
  // param-space pointer.
  auto IsALoadChain = [&](Value *Start) {
    SmallVector<Value *, 16> Pending = {Start};
    while (!Pending.empty()) {
      Value *Cur = Pending.pop_back_val();

      bool Convertible = isa<GetElementPtrInst>(Cur) ||
                         isa<BitCastInst>(Cur) || isa<LoadInst>(Cur);
      if (!Convertible) {
        // Casts into the param address space are acceptable as well.
        if (auto *ASC = dyn_cast<AddrSpaceCastInst>(Cur))
          Convertible = ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM;
      }
      if (!Convertible) {
        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *Cur
                          << "\n");
        (void)Arg;
        return false;
      }

      // A load ends the chain: its result is data, not an address.
      if (isa<LoadInst>(Cur))
        continue;
      llvm::append_range(Pending, Cur->users());
    }
    return true;
  };

  if (!llvm::all_of(Arg->users(), IsALoadChain)) {
    // Some use cannot be converted: materialize a local copy of the argument.
    const DataLayout &DL = Func->getParent()->getDataLayout();
    unsigned AS = DL.getAllocaAddrSpace();
    AllocaInst *AllocA =
        new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
    // Use the parameter's declared alignment when there is one, otherwise the
    // preferred alignment of the byval type.
    const auto DeclaredAlign = Func->getParamAlign(Arg->getArgNo());
    AllocA->setAlignment(
        DeclaredAlign.value_or(DL.getPrefTypeAlign(StructType)));
    // Redirect existing uses first, so that the cast created next remains the
    // only user of the argument itself.
    Arg->replaceAllUsesWith(AllocA);

    Value *ArgInParam = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
        FirstInst);
    // The load carries the alloca's alignment explicitly.
    LoadInst *LI =
        new LoadInst(StructType, ArgInParam, Arg->getName(),
                     /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
    new StoreInst(LI, AllocA, FirstInst);
    return;
  }

  // Every use is a load chain. Snapshot the users before creating the cast,
  // which would otherwise show up among them.
  SmallVector<User *, 16> UsersToUpdate(Arg->users());
  Value *ArgInParamAS = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  for (Value *U : UsersToUpdate)
    convertToParamAS(U, ArgInParamAS);
  LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");

  // The alignment tuning below needs target lowering information.
  if (!TM)
    return;
  const auto *TLI =
      cast<NVPTXTargetLowering>(TM->getSubtargetImpl()->getTargetLowering());
  adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
}
/// Wraps \p Ptr in an addrspacecast to the global address space followed by a
/// cast back to its original type, and redirects all other uses of \p Ptr to
/// the result of that round trip.
///
/// \param Ptr A pointer-typed Argument or non-terminator Instruction. Pointers
///            that are not in the generic address space are left untouched.
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  // Only a generic pointer may be re-labelled as global. A pointer that is
  // already global needs no cast, and one that lives in any other specific
  // address space must not be cast to global, since that would mislabel the
  // address space it actually points into.
  if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
    return;

  // Decide where the pair of casts goes.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Arguments: at the very start of the entry block.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Instructions: immediately after the defining instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr,
      PointerType::getWithSamePointeeType(cast<PointerType>(Ptr->getType()),
                                          ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // RAUW also rewrites the operand of PtrInGlobal (creating a cycle), so the
  // original pointer is restored as its operand right afterwards.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}
/// Lowers the arguments of kernel function \p F.
///
/// With a CUDA driver interface, pointer values loaded from a byval argument
/// and non-byval pointer arguments are passed to markPointerAsGlobal. Byval
/// pointer arguments always go through handleByValParam.
///
/// \returns true unconditionally.
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  const bool IsCUDA = TM && TM->getDrvInterface() == NVPTX::CUDA;

  if (IsCUDA) {
    // Look for pointer-typed loads whose address is derived from a byval
    // argument.
    for (BasicBlock &BB : F) {
      for (Instruction &Inst : BB) {
        auto *LI = dyn_cast<LoadInst>(&Inst);
        if (!LI || !LI->getType()->isPointerTy())
          continue;

        Value *Base = getUnderlyingObject(LI->getPointerOperand());
        auto *SourceArg = dyn_cast<Argument>(Base);
        if (SourceArg && SourceArg->hasByValAttr())
          markPointerAsGlobal(LI);
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (!Arg.getType()->isPointerTy())
      continue;

    if (Arg.hasByValAttr()) {
      handleByValParam(&Arg);
      continue;
    }
    if (IsCUDA)
      markPointerAsGlobal(&Arg);
  }
  return true;
}
/// Lowers the arguments of non-kernel function \p F: every byval pointer
/// argument is handed to handleByValParam; other arguments are left alone.
///
/// \returns true unconditionally.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    const bool IsByValPointer =
        Arg.getType()->isPointerTy() && Arg.hasByValAttr();
    if (!IsByValPointer)
      continue;
    handleByValParam(&Arg);
  }
  return true;
}
/// Pass entry point: kernels and device functions are lowered by separate
/// routines, selected via isKernelFunction().
bool NVPTXLowerArgs::runOnFunction(Function &F) {
  if (isKernelFunction(F))
    return runOnKernelFunction(F);
  return runOnDeviceFunction(F);
}
/// Factory for the NVPTXLowerArgs pass.
///
/// \param TM Target machine forwarded to the pass constructor.
/// \returns A newly heap-allocated pass instance.
FunctionPass *llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  FunctionPass *LowerArgs = new NVPTXLowerArgs(TM);
  return LowerArgs;
}