#include "PPC.h"
#include "PPCSubtarget.h"
#include "PPCTargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#define DEBUG_TYPE "ppc-lower-massv-entries"
using namespace llvm;
namespace {
// Table of names consulted by PPCLowerMASSVEntries::isMASSVFunc(). Its
// contents come from VecFuncs.def, included with
// TLI_DEFINE_MASSV_VECFUNCS_NAMES defined.
// NOTE(review): presumably that macro makes the .def expand to the list of
// generic MASSV vector function names -- confirm against VecFuncs.def.
static StringRef MASSVFuncs[] = {
#define TLI_DEFINE_MASSV_VECFUNCS_NAMES
#include "llvm/Analysis/VecFuncs.def"
};
/// Module pass that retargets calls to generic MASSV function declarations:
/// each call is redirected either to a declaration whose name carries a
/// CPU-specific suffix (e.g. "_P9") or, for certain constant-exponent pow
/// calls, to the llvm.pow intrinsic.
class PPCLowerMASSVEntries : public ModulePass {
public:
  /// Pass identification; its address is what identifies the pass.
  static char ID;
  PPCLowerMASSVEntries() : ModulePass(ID) {}
  /// Walks the MASSV declarations in \p M and lowers their call sites.
  /// \returns true if any call site was rewritten.
  bool runOnModule(Module &M) override;
  StringRef getPassName() const override { return "PPC Lower MASS Entries"; }
  /// Declares TargetTransformInfoWrapperPass as a required analysis.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }
private:
  /// \returns true if \p Name is one of the entries of MASSVFuncs.
  static bool isMASSVFunc(StringRef Name);
  /// \returns the name suffix ("_P10", "_P9", "_P8", "_P7", or "" when
  /// \p Subtarget is null) selected from the subtarget's features.
  static StringRef getCPUSuffix(const PPCSubtarget *Subtarget);
  /// \returns the name of \p Func with the CPU suffix for \p Subtarget
  /// appended.
  static std::string createMASSVFuncName(Function &Func,
                                         const PPCSubtarget *Subtarget);
  /// Redirects qualifying __powf4/__powd2 calls to the llvm.pow intrinsic.
  /// \returns true if \p CI was redirected.
  bool handlePowSpecialCases(CallInst *CI, Function &Func, Module &M);
  /// Redirects \p CI to the CPU-specific counterpart of \p Func (or to
  /// llvm.pow for the special pow cases).
  /// \returns true if \p CI was rewritten.
  bool lowerMASSVCall(CallInst *CI, Function &Func, Module &M,
                      const PPCSubtarget *Subtarget);
};
} // end anonymous namespace
/// Reports whether \p Name matches one of the entries of the MASSVFuncs
/// table.
bool PPCLowerMASSVEntries::isMASSVFunc(StringRef Name) {
  return llvm::any_of(MASSVFuncs,
                      [Name](StringRef Entry) { return Entry == Name; });
}
/// Picks the function-name suffix that corresponds to \p Subtarget.
///
/// The checks run from the newest vector feature to the oldest:
///   - null subtarget            -> "" (no suffix)
///   - AIX ABI with P10 vector   -> "_P10"
///   - P9 vector                 -> "_P9"
///   - P8 vector                 -> "_P8"
///   - AIX ABI, none of the above -> "_P7"
/// Any other subtarget is rejected with a fatal error.
StringRef PPCLowerMASSVEntries::getCPUSuffix(const PPCSubtarget *Subtarget) {
  // No subtarget information: leave the name unsuffixed.
  if (!Subtarget)
    return "";

  const bool IsAIX = Subtarget->isAIXABI();

  // "_P10" is only selected under the AIX ABI.
  if (IsAIX && Subtarget->hasP10Vector())
    return "_P10";
  if (Subtarget->hasP9Vector())
    return "_P9";
  if (Subtarget->hasP8Vector())
    return "_P8";

  // Below P8 vector support only the AIX ABI has a fallback ("_P7").
  // NOTE(review): the message spelling ("Mininum") is kept verbatim because
  // it is runtime output that external checks may match.
  if (!IsAIX)
    report_fatal_error(
        "Mininum subtarget for -vector-library=MASSV option is Power8 on Linux "
        "and Power7 on AIX when vectorization is not disabled.");

  return "_P7";
}
/// Builds the CPU-specific entry name for \p Func: the function's own name
/// followed by the suffix getCPUSuffix() selects for \p Subtarget.
std::string
PPCLowerMASSVEntries::createMASSVFuncName(Function &Func,
                                          const PPCSubtarget *Subtarget) {
  // Resolve the suffix first; getCPUSuffix() aborts on unsupported
  // subtargets, so nothing is built in that case.
  const StringRef CPUSuffix = getCPUSuffix(Subtarget);

  std::string EntryName = Func.getName().str();
  EntryName.append(CPUSuffix.begin(), CPUSuffix.end());
  return EntryName;
}
/// Redirects a call to __powf4 / __powd2 to the llvm.pow intrinsic when its
/// exponent is a constant splat of 0.25 or 0.75 and the call carries the
/// required fast-math flags.
///
/// Requirements checked, in order:
///   - \p Func is named "__powf4" or "__powd2";
///   - operand 1 of \p CI is a Constant whose splat value is a ConstantFP;
///   - \p CI has both the no-infs and approx-func flags;
///   - the splat value is exactly 0.75 or exactly 0.25;
///   - for 0.25, \p CI additionally has the no-signed-zeros flag.
///
/// \returns true if \p CI's callee was replaced by llvm.pow, false if the
/// call was left untouched.
bool PPCLowerMASSVEntries::handlePowSpecialCases(CallInst *CI, Function &Func,
                                                 Module &M) {
  const StringRef CalleeName = Func.getName();
  if (CalleeName != "__powf4" && CalleeName != "__powd2")
    return false;

  // The exponent must be a compile-time constant ...
  auto *ExpConst = dyn_cast<Constant>(CI->getArgOperand(1));
  if (!ExpConst)
    return false;

  // ... with the same floating-point value in every lane.
  auto *SplatFP = dyn_cast_or_null<ConstantFP>(ExpConst->getSplatValue());
  if (!SplatFP)
    return false;

  if (!(CI->hasNoInfs() && CI->hasApproxFunc()))
    return false;

  const bool IsQuarter = SplatFP->isExactlyValue(0.25);
  const bool IsThreeQuarters = SplatFP->isExactlyValue(0.75);
  if (!IsQuarter && !IsThreeQuarters)
    return false;

  // The 0.25 exponent additionally needs no-signed-zeros.
  if (IsQuarter && !CI->hasNoSignedZeros())
    return false;

  Function *PowIntrinsic =
      Intrinsic::getDeclaration(&M, Intrinsic::pow, CI->getType());
  CI->setCalledFunction(PowIntrinsic);
  return true;
}
/// Retargets the call \p CI, which calls the generic MASSV declaration
/// \p Func, to its lowered form.
///
/// A call whose result has no uses is left alone. Otherwise the call is
/// either turned into llvm.pow by handlePowSpecialCases(), or redirected to
/// a function in \p M named by createMASSVFuncName(), created on demand with
/// the same function type and attributes as \p Func.
///
/// \returns true if \p CI was rewritten, false if it was skipped.
bool PPCLowerMASSVEntries::lowerMASSVCall(CallInst *CI, Function &Func,
                                          Module &M,
                                          const PPCSubtarget *Subtarget) {
  // Calls whose value is never used are not lowered.
  if (CI->use_empty())
    return false;

  // The pow special cases take precedence over the CPU-specific entry.
  if (!handlePowSpecialCases(CI, Func, M)) {
    const std::string LoweredName = createMASSVFuncName(Func, Subtarget);
    FunctionCallee LoweredCallee = M.getOrInsertFunction(
        LoweredName, Func.getFunctionType(), Func.getAttributes());
    CI->setCalledFunction(LoweredCallee);
  }

  return true;
}
/// Lowers every direct call to a generic MASSV declaration in \p M.
///
/// For each function in the module that is a declaration and whose name is a
/// MASSV entry (see isMASSVFunc()), all call instructions that call it are
/// handed to lowerMASSVCall() together with the PowerPC subtarget of the
/// function containing the call.
///
/// The pass does nothing when no TargetPassConfig is available, because the
/// target machine (and hence the subtarget) is obtained through it.
///
/// \returns true if at least one call site was rewritten.
bool PPCLowerMASSVEntries::runOnModule(Module &M) {
  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  auto &TM = TPC->getTM<PPCTargetMachine>();
  bool Changed = false;

  for (Function &Func : M) {
    if (!Func.isDeclaration() || !isMASSVFunc(Func.getName()))
      continue;

    // Lowering retargets call sites, which edits Func's use list. Work on a
    // snapshot so that the list is not modified while being iterated.
    SmallVector<User *, 4> MASSVUsers(Func.users());

    for (User *U : MASSVUsers) {
      auto *CI = dyn_cast<CallInst>(U);
      // Only lower calls whose callee is Func. A call that merely passes
      // Func as an argument is also a user of Func; overwriting its callee
      // would redirect an unrelated call.
      if (!CI || CI->getCalledFunction() != &Func)
        continue;

      // The subtarget is a per-function property: query it for the function
      // that contains the call site.
      const PPCSubtarget *Subtarget =
          &TM.getSubtarget<PPCSubtarget>(*CI->getParent()->getParent());
      Changed |= lowerMASSVCall(CI, Func, M, Subtarget);
    }
  }

  return Changed;
}
// Definition of the pass's static ID member declared in the class.
char PPCLowerMASSVEntries::ID = 0;
// Binds the llvm::PPCLowerMASSVEntriesID reference to this pass's ID.
char &llvm::PPCLowerMASSVEntriesID = PPCLowerMASSVEntries::ID;
// Registers the pass under the DEBUG_TYPE string ("ppc-lower-massv-entries")
// with the description "Lower MASSV entries".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS -- confirm.
INITIALIZE_PASS(PPCLowerMASSVEntries, DEBUG_TYPE, "Lower MASSV entries", false,
                false)
/// Factory for the pass: returns a newly heap-allocated PPCLowerMASSVEntries
/// instance as a ModulePass pointer.
/// NOTE(review): ownership presumably passes to the pass manager the result
/// is added to -- confirm at the call site.
ModulePass *llvm::createPPCLowerMASSVEntriesPass() {
  return new PPCLowerMASSVEntries();
}