#ifndef LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
#define LLVM_TRANSFORMS_IPO_SAMPLECONTEXTTRACKER_H
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ProfileData/SampleProf.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
namespace llvm {
class CallBase;
class DILocation;
class Function;
class Instruction;
// Node of the trie used to track calling contexts of sample profiles.
// Each node records its parent (getParentContext) and the call-site location
// in that parent that leads to it (getCallSiteLoc), so the chain of nodes from
// the root down to a node describes the calling context of the node's
// FunctionSamples. A default-constructed node (no parent, empty name, no
// samples) serves as the trie root.
class ContextTrieNode {
public:
  // Every argument is optional so that a bare ContextTrieNode() can act as a
  // root node.
  ContextTrieNode(ContextTrieNode *Parent = nullptr,
                  StringRef FName = StringRef(),
                  FunctionSamples *FSamples = nullptr,
                  LineLocation CallLoc = {0, 0})
      : ParentContext(Parent), FuncName(FName), FuncSamples(FSamples),
        CallSiteLoc(CallLoc) {}

  // Returns the child reached from this node through CallSite for callee
  // ChildName. NOTE(review): presumably nullptr when no such child exists —
  // confirm in the implementation.
  ContextTrieNode *getChildContext(const LineLocation &CallSite,
                                   StringRef ChildName);
  // Returns a child at CallSite; judging by the name, the one with the most
  // samples. The selection rule lives in the implementation.
  ContextTrieNode *getHottestChildContext(const LineLocation &CallSite);
  // Like getChildContext, but a missing child is created when AllowCreate is
  // true. NOTE(review): with AllowCreate == false a missing child presumably
  // yields nullptr — confirm.
  ContextTrieNode *getOrCreateChildContext(const LineLocation &CallSite,
                                           StringRef ChildName,
                                           bool AllowCreate = true);
  // Removes the child identified by CallSite and ChildName.
  void removeChildContext(const LineLocation &CallSite, StringRef ChildName);
  // Mutable access to the child map. Children are stored by value, so erasing
  // an entry destroys that child's whole subtree and invalidates any pointers
  // to nodes inside it.
  std::map<uint64_t, ContextTrieNode> &getAllChildContext();

  // Accessors for the fields below.
  StringRef getFuncName() const;
  FunctionSamples *getFunctionSamples() const;
  void setFunctionSamples(FunctionSamples *FSamples);
  Optional<uint32_t> getFunctionSize() const;
  // Per its name, adds FSize to this context's recorded function size.
  void addFunctionSize(uint32_t FSize);
  LineLocation getCallSiteLoc() const;
  ContextTrieNode *getParentContext() const;
  void setParentContext(ContextTrieNode *Parent);
  void setCallSiteLoc(const LineLocation &Loc);

  // Debug printing; per their names, of this node alone and of this node's
  // subtree.
  void dumpNode();
  void dumpTree();

private:
  // Children keyed by a 64-bit id. NOTE(review): presumably derived from the
  // call-site location and callee name — confirm in the implementation.
  std::map<uint64_t, ContextTrieNode> AllChildContext;

  // Parent node; nullptr for a node constructed without one (e.g. the root).
  ContextTrieNode *ParentContext;

  // Non-owning view of the function name: the referenced characters must
  // outlive this node.
  StringRef FuncName;

  // Profile attached to this context; may be null.
  FunctionSamples *FuncSamples;

  // Function size for this context; empty unless set via addFunctionSize.
  Optional<uint32_t> FuncSize;

  // Location in the parent context of the call site leading to this node;
  // {0, 0} by default.
  LineLocation CallSiteLoc;
};
class SampleContextTracker {
public:
using ContextSamplesTy = std::vector<FunctionSamples *>;
SampleContextTracker() = default;
SampleContextTracker(SampleProfileMap &Profiles,
const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap);
void populateFuncToCtxtMap();
FunctionSamples *getCalleeContextSamplesFor(const CallBase &Inst,
StringRef CalleeName);
std::vector<const FunctionSamples *>
getIndirectCalleeContextSamplesFor(const DILocation *DIL);
FunctionSamples *getContextSamplesFor(const DILocation *DIL);
FunctionSamples *getContextSamplesFor(const SampleContext &Context);
ContextSamplesTy &getAllContextSamplesFor(const Function &Func);
ContextSamplesTy &getAllContextSamplesFor(StringRef Name);
ContextTrieNode *getOrCreateContextPath(const SampleContext &Context,
bool AllowCreate);
FunctionSamples *getBaseSamplesFor(const Function &Func,
bool MergeContext = true);
FunctionSamples *getBaseSamplesFor(StringRef Name, bool MergeContext = true);
ContextTrieNode *getContextFor(const SampleContext &Context);
StringRef getFuncNameFor(ContextTrieNode *Node) const;
void markContextSamplesInlined(const FunctionSamples *InlinedSamples);
ContextTrieNode &getRootContext();
void promoteMergeContextSamplesTree(const Instruction &Inst,
StringRef CalleeName);
void createContextLessProfileMap(SampleProfileMap &ContextLessProfiles);
ContextTrieNode *
getContextNodeForProfile(const FunctionSamples *FSamples) const {
auto I = ProfileToNodeMap.find(FSamples);
if (I == ProfileToNodeMap.end())
return nullptr;
return I->second;
}
StringMap<ContextSamplesTy> &getFuncToCtxtProfiles() {
return FuncToCtxtProfiles;
}
class Iterator : public std::iterator<std::forward_iterator_tag,
const ContextTrieNode *> {
std::queue<ContextTrieNode *> NodeQueue;
public:
explicit Iterator() = default;
explicit Iterator(ContextTrieNode *Node) { NodeQueue.push(Node); }
Iterator &operator++() {
assert(!NodeQueue.empty() && "Iterator already at the end");
ContextTrieNode *Node = NodeQueue.front();
NodeQueue.pop();
for (auto &It : Node->getAllChildContext())
NodeQueue.push(&It.second);
return *this;
}
Iterator operator++(int) {
assert(!NodeQueue.empty() && "Iterator already at the end");
Iterator Ret = *this;
++(*this);
return Ret;
}
bool operator==(const Iterator &Other) const {
if (NodeQueue.empty() && Other.NodeQueue.empty())
return true;
if (NodeQueue.empty() || Other.NodeQueue.empty())
return false;
return NodeQueue.front() == Other.NodeQueue.front();
}
bool operator!=(const Iterator &Other) const { return !(*this == Other); }
ContextTrieNode *operator*() const {
assert(!NodeQueue.empty() && "Invalid access to end iterator");
return NodeQueue.front();
}
};
Iterator begin() { return Iterator(&RootContext); }
Iterator end() { return Iterator(); }
#ifndef NDEBUG
std::string getContextString(const FunctionSamples &FSamples) const;
std::string getContextString(ContextTrieNode *Node) const;
#endif
void dump();
private:
ContextTrieNode *getContextFor(const DILocation *DIL);
ContextTrieNode *getCalleeContextFor(const DILocation *DIL,
StringRef CalleeName);
ContextTrieNode *getTopLevelContextNode(StringRef FName);
ContextTrieNode &addTopLevelContextNode(StringRef FName);
ContextTrieNode &promoteMergeContextSamplesTree(ContextTrieNode &NodeToPromo);
void mergeContextNode(ContextTrieNode &FromNode, ContextTrieNode &ToNode);
ContextTrieNode &
promoteMergeContextSamplesTree(ContextTrieNode &FromNode,
ContextTrieNode &ToNodeParent);
ContextTrieNode &moveContextSamples(ContextTrieNode &ToNodeParent,
const LineLocation &CallSite,
ContextTrieNode &&NodeToMove);
void setContextNode(const FunctionSamples *FSample, ContextTrieNode *Node) {
ProfileToNodeMap[FSample] = Node;
}
StringMap<ContextSamplesTy> FuncToCtxtProfiles;
std::unordered_map<const FunctionSamples *, ContextTrieNode *>
ProfileToNodeMap;
const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap;
ContextTrieNode RootContext;
};
} #endif