#ifndef LLVM_TOOLS_LLVM_PROGEN_PROFILEGENERATOR_H
#define LLVM_TOOLS_LLVM_PROGEN_PROFILEGENERATOR_H
#include "CSPreInliner.h"
#include "ErrorHandling.h"
#include "PerfReader.h"
#include "ProfiledBinary.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/ProfileData/SampleProfWriter.h"
#include <algorithm>
#include <cstdint>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
using namespace llvm;
using namespace sampleprof;
namespace llvm {
namespace sampleprof {
/// Per-probe counters: associates each decoded pseudo probe (by address) with
/// a uint64_t count.
using ProbeCounterMap = std::unordered_map<const MCDecodedPseudoProbe *,
                                           uint64_t>;
/// Common state and helpers shared by the profile generators.
///
/// A generator is built either from raw per-context sample counters
/// (\p Counters) or from an already-built sample profile map (\p Profiles).
/// Subclasses implement generateProfile() and
/// collectFunctionsFromLLVMProfile(); the other helpers are defined out of
/// line.
class ProfileGeneratorBase {
public:
  ProfileGeneratorBase(ProfiledBinary *Binary) : Binary(Binary) {}
  ProfileGeneratorBase(ProfiledBinary *Binary,
                       const ContextSampleCounterMap *Counters)
      : Binary(Binary), SampleCounters(Counters) {}
  /// Builds a generator over an existing profile map.
  /// A `const` rvalue cannot be moved from, so \p Profiles is copied here
  /// (std::move on it would still select the copy constructor). Kept for
  /// source compatibility; the non-const overload below avoids the copy.
  ProfileGeneratorBase(ProfiledBinary *Binary,
                       const SampleProfileMap &&Profiles)
      : Binary(Binary), ProfileMap(Profiles) {}
  /// Builds a generator over an existing profile map by moving its contents
  /// into ProfileMap. \p Profiles is left in a moved-from state.
  ProfileGeneratorBase(ProfiledBinary *Binary, SampleProfileMap &&Profiles)
      : Binary(Binary), ProfileMap(std::move(Profiles)) {}
  virtual ~ProfileGeneratorBase() = default;

  /// Factory functions (defined out of line); the concrete generator kind is
  /// selected by \p profileIsCS.
  static std::unique_ptr<ProfileGeneratorBase>
  create(ProfiledBinary *Binary, const ContextSampleCounterMap *Counters,
         bool profileIsCS);
  static std::unique_ptr<ProfileGeneratorBase>
  create(ProfiledBinary *Binary, SampleProfileMap &ProfileMap,
         bool profileIsCS);

  virtual void generateProfile() = 0;
  void write();

  /// Duplication factor encoded in \p Discriminator. When \p UseFSD is set,
  /// the discriminator is not decoded and the factor is always 1.
  static uint32_t
  getDuplicationFactor(unsigned Discriminator,
                       bool UseFSD = ProfileGeneratorBase::UseFSDiscriminator) {
    return UseFSD ? 1
                  : llvm::DILocation::getDuplicationFactorFromDiscriminator(
                        Discriminator);
  }

  /// Base discriminator encoded in \p Discriminator. When \p UseFSD is set,
  /// \p Discriminator is returned unchanged.
  static uint32_t
  getBaseDiscriminator(unsigned Discriminator,
                       bool UseFSD = ProfileGeneratorBase::UseFSDiscriminator) {
    return UseFSD ? Discriminator
                  : DILocation::getBaseDiscriminatorFromDiscriminator(
                        Discriminator, /*IsFSDiscriminator=*/false);
  }

  /// Whether flow-sensitive (FS) discriminators are in use; the default for
  /// the \p UseFSD argument of the helpers above. Defined out of line.
  static bool UseFSDiscriminator;

protected:
  // Writes \p ProfileMap with \p Writer. The parameter shadows the member of
  // the same name, so callers choose which map is emitted.
  void write(std::unique_ptr<SampleProfileWriter> Writer,
             SampleProfileMap &ProfileMap);

  void findDisjointRanges(RangeSample &DisjointRanges,
                          const RangeSample &Ranges);
  void extractProbesFromRange(const RangeSample &RangeCounter,
                              ProbeCounterMap &ProbeCounter,
                              bool FindDisjointRanges = true);
  void updateBodySamplesforFunctionProfile(FunctionSamples &FunctionProfile,
                                           const SampleContextFrame &LeafLoc,
                                           uint64_t Count);
  void updateFunctionSamples();
  void updateTotalSamples();
  void updateCallsiteSamples();
  StringRef getCalleeNameForOffset(uint64_t TargetOffset);
  void computeSummaryAndThreshold(SampleProfileMap &ProfileMap);
  void calculateAndShowDensity(const SampleProfileMap &Profiles);
  double calculateDensity(const SampleProfileMap &Profiles,
                          uint64_t HotCntThreshold);
  void showDensitySuggestion(double Density);
  void collectProfiledFunctions();
  bool collectFunctionsFromRawProfile(
      std::unordered_set<const BinaryFunction *> &ProfiledFunctions);
  virtual bool collectFunctionsFromLLVMProfile(
      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) = 0;

  // Hot/cold count thresholds. Zero-initialized so they hold a defined value
  // before they are computed (NOTE(review): presumably set by
  // computeSummaryAndThreshold(); confirm in the .cpp).
  uint64_t HotCountThreshold = 0;
  uint64_t ColdCountThreshold = 0;
  ProfiledBinary *Binary = nullptr;
  std::unique_ptr<ProfileSummary> Summary;
  SampleProfileMap ProfileMap;
  // Non-owning; null when the generator was built from a profile map.
  const ContextSampleCounterMap *SampleCounters = nullptr;
};
/// Generator used when the profile is not context-sensitive (CSProfileGenerator
/// is the context-sensitive counterpart). Supports both line-number-based and
/// pseudo-probe-based profiles.
class ProfileGenerator : public ProfileGeneratorBase {
public:
  ProfileGenerator(ProfiledBinary *Binary,
                   const ContextSampleCounterMap *Counters)
      : ProfileGeneratorBase(Binary, Counters) {}
  /// Builds from an existing profile map. A `const` rvalue cannot be moved
  /// from, so the base copies \p Profiles. Kept for source compatibility.
  ProfileGenerator(ProfiledBinary *Binary, const SampleProfileMap &&Profiles)
      : ProfileGeneratorBase(Binary, std::move(Profiles)) {}
  /// Builds from an existing profile map, forwarding \p Profiles as a
  /// non-const rvalue so the base class may move from it instead of copying.
  /// Callers should treat \p Profiles as moved-from afterwards.
  ProfileGenerator(ProfiledBinary *Binary, SampleProfileMap &&Profiles)
      : ProfileGeneratorBase(Binary, std::move(Profiles)) {}
  void generateProfile() override;

private:
  // Separate generation paths for line-number-based and probe-based profiles.
  void generateLineNumBasedProfile();
  void generateProbeBasedProfile();
  RangeSample preprocessRangeCounter(const RangeSample &RangeCounter);
  FunctionSamples &getTopLevelFunctionProfile(StringRef FuncName);
  // NOTE(review): per its name, returns the profile for the leaf frame of
  // FrameVec while adding Count to total samples; confirm in the .cpp.
  FunctionSamples &
  getLeafProfileAndAddTotalSamples(const SampleContextFrameVector &FrameVec,
                                   uint64_t Count);
  void populateBodySamplesForAllFunctions(const RangeSample &RangeCounter);
  void
  populateBoundarySamplesForAllFunctions(const BranchSample &BranchCounters);
  void
  populateBodySamplesWithProbesForAllFunctions(const RangeSample &RangeCounter);
  void populateBoundarySamplesWithProbesForAllFunctions(
      const BranchSample &BranchCounters);
  void postProcessProfiles();
  void trimColdProfiles(const SampleProfileMap &Profiles,
                        uint64_t ColdCntThreshold);
  bool collectFunctionsFromLLVMProfile(
      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
};
/// Generator for context-sensitive (CS) profiles. Context nodes live in a
/// trie owned by ContextTracker (see getRootContext()); NOTE(review): the trie
/// is presumably flattened into ProfileMap by convertToProfileMap(), which is
/// defined out of line.
class CSProfileGenerator : public ProfileGeneratorBase {
public:
  CSProfileGenerator(ProfiledBinary *Binary,
                     const ContextSampleCounterMap *Counters)
      : ProfileGeneratorBase(Binary, Counters){};
  // Builds from an existing profile map. Unlike the counter-based constructor,
  // the base class receives neither counters nor a profile map here:
  // \p Profiles is handed only to ContextTracker.
  CSProfileGenerator(ProfiledBinary *Binary, SampleProfileMap &Profiles)
      : ProfileGeneratorBase(Binary), ContextTracker(Profiles, nullptr){};
  void generateProfile() override;

  /// Truncates \p S in place to its last \p Depth elements, preserving their
  /// order. \p S is left unchanged when \p Depth is negative or not smaller
  /// than S.size(); a \p Depth of 0 empties it.
  template <typename T>
  static void trimContext(SmallVectorImpl<T> &S, int Depth = MaxContextDepth) {
    if (Depth < 0 || static_cast<size_t>(Depth) >= S.size())
      return;
    // Shift the kept tail to the front. The source range starts after the
    // destination, so a forward std::copy is valid despite the overlap.
    std::copy(S.begin() + S.size() - static_cast<size_t>(Depth), S.end(),
              S.begin());
    S.resize(Depth);
  }

  /// Collapses adjacent repeated subsequences of \p Context in place, e.g.
  /// [a, a, b, b] -> [a, b] and [a, b, c, b, c] -> [a, b, c].
  ///
  /// Repetitions of length I are removed in successive passes
  /// I = 1, 2, ... . The largest length tried is min(\p CSize, size / 2),
  /// or size / 2 when \p CSize is -1, and the bound shrinks to half the
  /// remaining size after every pass. NOTE: any other negative \p CSize also
  /// acts as unbounded, because it is converted to a large uint32_t.
  template <typename T>
  static void compressRecursionContext(SmallVectorImpl<T> &Context,
                                       int32_t CSize = MaxCompressionSize) {
    uint32_t I = 1;
    uint32_t HS = static_cast<uint32_t>(Context.size() / 2);
    // Longest repetition length to look for.
    uint32_t MaxDedupSize =
        CSize == -1 ? HS : std::min(static_cast<uint32_t>(CSize), HS);
    auto BeginIter = Context.begin();
    // Write cursor of the in-place compaction: after a pass, Context[0, End)
    // holds that pass's output.
    uint32_t End = 0;
    // One pass per repetition length I.
    while (I <= MaxDedupSize) {
      // Each step compares a window of 2*I elements: the left half ends at
      // Right, the right half is [Right + 1, Right + I]. The first I elements
      // are never overwritten (all writes go to End >= I), so they are kept.
      int32_t Right = I - 1;
      End = I;
      // Positions before LeftBoundary were already compared in an earlier
      // step and need not be rechecked.
      int32_t LeftBoundary = 0;
      while (Right + I < Context.size()) {
        // Walk backwards from Right to find the longest common suffix of the
        // two halves; Left stops on the first mismatching position.
        int32_t Left = Right;
        while (Left >= LeftBoundary && Context[Left] == Context[Left + I]) {
          Left--;
        }
        // Every unchecked position matched: the right half repeats the left.
        bool DuplicationFound = (Left < LeftBoundary);
        LeftBoundary = Right + 1;
        if (DuplicationFound) {
          // Drop the repeated right half by skipping it without copying.
          Right += I;
        } else {
          // Keep the right half up to and including the mismatch, and slide
          // the window only that far, so the matched suffix becomes the
          // common prefix of the next candidate pair.
          std::copy(BeginIter + Right + 1, BeginIter + Left + I + 1,
                    BeginIter + End);
          End += Left + I - Right;
          Right = Left + I;
        }
      }
      // Append whatever lies beyond the last compared window.
      std::copy(BeginIter + Right + 1, Context.end(), BeginIter + End);
      End += Context.size() - Right - 1;
      I++;
      Context.resize(End);
      // A repetition cannot be longer than half of what remains.
      MaxDedupSize = std::min(static_cast<uint32_t>(End / 2), MaxDedupSize);
    }
  }

private:
  void generateLineNumBasedProfile();
  FunctionSamples *getOrCreateFunctionSamples(ContextTrieNode *ContextNode,
                                              bool WasLeafInlined = false);
  ContextTrieNode *getOrCreateContextNode(const SampleContextFrames Context,
                                          bool WasLeafInlined = false);
  void computeSizeForProfiledFunctions();
  void postProcessProfiles();
  void populateBodySamplesForFunction(FunctionSamples &FunctionProfile,
                                      const RangeSample &RangeCounters);
  void populateBoundarySamplesForFunction(ContextTrieNode *CallerNode,
                                          const BranchSample &BranchCounters);
  void populateInferredFunctionSamples(ContextTrieNode &Node);
  // NOTE: hides (does not override) the non-virtual base-class method of the
  // same signature.
  void updateFunctionSamples();
  void generateProbeBasedProfile();
  void populateBodySamplesWithProbes(const RangeSample &RangeCounter,
                                     SampleContextFrames ContextStack);
  void populateBoundarySamplesWithProbes(const BranchSample &BranchCounter,
                                         SampleContextFrames ContextStack);
  ContextTrieNode *
  getContextNodeForLeafProbe(SampleContextFrames ContextStack,
                             const MCDecodedPseudoProbe *LeafProbe);
  FunctionSamples &
  getFunctionProfileForLeafProbe(SampleContextFrames ContextStack,
                                 const MCDecodedPseudoProbe *LeafProbe);
  // Recursive overload takes the trie node and the context accumulated so far.
  void convertToProfileMap(ContextTrieNode &Node,
                           SampleContextFrameVector &Context);
  void convertToProfileMap();
  // NOTE: hides the base-class computeSummaryAndThreshold(SampleProfileMap &);
  // calling the base version from here needs explicit qualification.
  void computeSummaryAndThreshold();
  bool collectFunctionsFromLLVMProfile(
      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
  ContextTrieNode &getRootContext() { return ContextTracker.getRootContext(); };
  // Storage for FunctionSamples objects. std::list keeps element addresses
  // stable across insertions; presumably trie nodes point into it (confirm).
  std::list<FunctionSamples> FSamplesList;
  // NOTE(review): set of distinct context frame vectors; confirm its exact
  // role in the .cpp.
  std::unordered_set<SampleContextFrameVector, SampleContextFrameHash> Contexts;
  SampleContextTracker ContextTracker;
  // NOTE(review): presumably cleared once the trie no longer reflects the
  // final profile; confirm in the .cpp.
  bool IsProfileValidOnTrie = true;

public:
  // Defaults for compressRecursionContext() and trimContext(); these static
  // members are defined out of line.
  static int32_t MaxCompressionSize;
  static int MaxContextDepth;
};
} } 
#endif