#include "llvm/Transforms/IPO/SampleContextTracker.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/ProfileData/SampleProf.h"
#include <map>
#include <queue>
#include <vector>
using namespace llvm;
using namespace sampleprof;
#define DEBUG_TYPE "sample-context-tracker"
namespace llvm {
/// Look up the child of this node reached through \p CallSite for callee
/// \p CalleeName. An empty \p CalleeName means "any callee": the child at
/// \p CallSite with the largest total sample count is returned instead.
/// \returns the child node, or nullptr when no matching child exists.
ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite,
                                                  StringRef CalleeName) {
  if (!CalleeName.empty()) {
    auto Found = AllChildContext.find(
        FunctionSamples::getCallSiteHash(CalleeName, CallSite));
    return Found == AllChildContext.end() ? nullptr : &Found->second;
  }
  return getHottestChildContext(CallSite);
}
/// Return the child reached through \p CallSite whose profile has the highest
/// total sample count. Children without a profile are skipped. On ties, the
/// first child in map (hash-key) order wins.
/// \returns nullptr if no child at \p CallSite has a profile with a non-zero
/// total.
ContextTrieNode *
ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) {
  ContextTrieNode *Hottest = nullptr;
  uint64_t BestTotal = 0;
  for (auto &Entry : AllChildContext) {
    ContextTrieNode &Candidate = Entry.second;
    if (Candidate.CallSiteLoc == CallSite) {
      if (FunctionSamples *CandSamples = Candidate.getFunctionSamples()) {
        uint64_t Total = CandSamples->getTotalSamples();
        if (Total > BestTotal) {
          BestTotal = Total;
          Hottest = &Candidate;
        }
      }
    }
  }
  return Hottest;
}
/// Install \p NodeToMove as the child of \p ToNodeParent at \p CallSite, then
/// re-link the relocated subtree:
/// - every node's parent pointer is refreshed;
/// - every profile in the subtree is re-associated with its (new) node and
///   marked SyntheticContext.
/// \p NodeToMove's contents are transferred (moved, where ContextTrieNode is
/// movable), so the caller should discard it afterwards.
/// \returns the newly installed child node.
ContextTrieNode &
SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent,
                                         const LineLocation &CallSite,
                                         ContextTrieNode &&NodeToMove) {
  uint64_t Hash =
      FunctionSamples::getCallSiteHash(NodeToMove.getFuncName(), CallSite);
  std::map<uint64_t, ContextTrieNode> &AllChildContext =
      ToNodeParent.getAllChildContext();
  assert(!AllChildContext.count(Hash) &&
         "Node to move must not already exist under the new parent");

  // Move rather than copy. Copy-assignment would deep-copy the node's whole
  // child map (and recursively every descendant) only for the source to be
  // thrown away by the caller.
  ContextTrieNode &NewNode = AllChildContext[Hash];
  NewNode = std::move(NodeToMove);
  NewNode.setCallSiteLoc(CallSite);
  NewNode.setParentContext(&ToNodeParent);

  // Breadth-first walk of the relocated subtree. The direct children still
  // point at NodeToMove as their parent. Every profile must be re-mapped to
  // the node that now owns it.
  std::queue<ContextTrieNode *> NodeToUpdate;
  NodeToUpdate.push(&NewNode);
  while (!NodeToUpdate.empty()) {
    ContextTrieNode *Node = NodeToUpdate.front();
    NodeToUpdate.pop();
    if (FunctionSamples *FSamples = Node->getFunctionSamples()) {
      setContextNode(FSamples, Node);
      FSamples->getContext().setState(SyntheticContext);
    }
    for (auto &It : Node->getAllChildContext()) {
      ContextTrieNode *ChildNode = &It.second;
      ChildNode->setParentContext(Node);
      NodeToUpdate.push(ChildNode);
    }
  }
  return NewNode;
}
/// Erase the child keyed by (\p CalleeName, \p CallSite), if present.
/// Children are stored by value, so the child's whole subtree is destroyed
/// with it.
void ContextTrieNode::removeChildContext(const LineLocation &CallSite,
                                         StringRef CalleeName) {
  const uint64_t Key = FunctionSamples::getCallSiteHash(CalleeName, CallSite);
  AllChildContext.erase(Key);
}
std::map<uint64_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() {
return AllChildContext;
}
StringRef ContextTrieNode::getFuncName() const { return FuncName; }
FunctionSamples *ContextTrieNode::getFunctionSamples() const {
return FuncSamples;
}
void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) {
FuncSamples = FSamples;
}
/// Accumulated function size, or an empty Optional if no size was recorded.
Optional<uint32_t> ContextTrieNode::getFunctionSize() const {
  return FuncSize;
}

/// Add \p FSize to this node's function size. An unset size counts as zero.
void ContextTrieNode::addFunctionSize(uint32_t FSize) {
  uint32_t Current = FuncSize ? *FuncSize : 0;
  FuncSize = Current + FSize;
}
LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; }
ContextTrieNode *ContextTrieNode::getParentContext() const {
return ParentContext;
}
void ContextTrieNode::setParentContext(ContextTrieNode *Parent) {
ParentContext = Parent;
}
void ContextTrieNode::setCallSiteLoc(const LineLocation &Loc) {
CallSiteLoc = Loc;
}
void ContextTrieNode::dumpNode() {
dbgs() << "Node: " << FuncName << "\n"
<< " Callsite: " << CallSiteLoc << "\n"
<< " Size: " << FuncSize << "\n"
<< " Children:\n";
for (auto &It : AllChildContext) {
dbgs() << " Node: " << It.second.getFuncName() << "\n";
}
}
void ContextTrieNode::dumpTree() {
dbgs() << "Context Profile Tree:\n";
std::queue<ContextTrieNode *> NodeQueue;
NodeQueue.push(this);
while (!NodeQueue.empty()) {
ContextTrieNode *Node = NodeQueue.front();
NodeQueue.pop();
Node->dumpNode();
for (auto &It : Node->getAllChildContext()) {
ContextTrieNode *ChildNode = &It.second;
NodeQueue.push(ChildNode);
}
}
}
/// Return the child reached through \p CallSite that represents
/// \p CalleeName. If there is none, create an empty child (no profile) when
/// \p AllowCreate is set; otherwise return nullptr.
ContextTrieNode *ContextTrieNode::getOrCreateChildContext(
    const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) {
  uint64_t Key = FunctionSamples::getCallSiteHash(CalleeName, CallSite);
  auto Existing = AllChildContext.find(Key);
  if (Existing == AllChildContext.end()) {
    if (!AllowCreate)
      return nullptr;
    ContextTrieNode &Created = AllChildContext[Key];
    Created = ContextTrieNode(this, CalleeName, nullptr, CallSite);
    return &Created;
  }
  // Two different callees hashing to the same key would silently share a node.
  assert(Existing->second.getFuncName() == CalleeName &&
         "Hash collision for child context node");
  return &Existing->second;
}
/// Build the context trie from \p Profiles. For each profile, the full path
/// for its context is created as needed and the profile is attached to the
/// innermost node. Afterwards the profile->node and function->profiles maps
/// are populated.
SampleContextTracker::SampleContextTracker(
    SampleProfileMap &Profiles,
    const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap)
    : GUIDToFuncNameMap(GUIDToFuncNameMap) {
  for (auto &Entry : Profiles) {
    const SampleContext &Ctx = Entry.first;
    LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Ctx.toString()
                      << "\n");
    ContextTrieNode *Leaf = getOrCreateContextPath(Ctx, true);
    assert(!Leaf->getFunctionSamples() &&
           "New node can't have sample profile");
    Leaf->setFunctionSamples(&Entry.second);
  }
  populateFuncToCtxtMap();
}
void SampleContextTracker::populateFuncToCtxtMap() {
for (auto *Node : *this) {
FunctionSamples *FSamples = Node->getFunctionSamples();
if (FSamples) {
FSamples->getContext().setState(RawContext);
setContextNode(FSamples, Node);
FuncToCtxtProfiles[Node->getFuncName()].push_back(FSamples);
}
}
}
/// Return the profile of callee \p CalleeName in the context where it is
/// called from \p Inst (the child of Inst's context at Inst's call site).
/// \p CalleeName is canonicalized, and converted to MD5 form when MD5 names
/// are in use, before the lookup.
/// \returns nullptr if \p Inst has no debug location, no such context exists,
/// or the context carries no profile.
FunctionSamples *
SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
                                                 StringRef CalleeName) {
  LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n");
  DILocation *DIL = Inst.getDebugLoc();
  if (!DIL)
    return nullptr;

  // FGUID is the buffer getRepInFormat may fill with the MD5 form, so it has
  // to stay alive while CalleeName is in use.
  std::string FGUID;
  CalleeName = getRepInFormat(FunctionSamples::getCanonicalFnName(CalleeName),
                              FunctionSamples::UseMD5, FGUID);

  ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName);
  if (!CalleeContext)
    return nullptr;

  FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
  LLVM_DEBUG(if (FSamples) {
    dbgs() << " Callee context found: " << getContextString(CalleeContext)
           << "\n";
  });
  return FSamples;
}
/// Collect the profiles of all callee contexts reached from \p DIL's context
/// through DIL's call site, e.g. every recorded target of an indirect call.
/// \returns an empty vector if \p DIL is null, if DIL's inline chain has no
/// path in the trie, or if no child at the call site carries a profile.
std::vector<const FunctionSamples *>
SampleContextTracker::getIndirectCalleeContextSamplesFor(
    const DILocation *DIL) {
  std::vector<const FunctionSamples *> R;
  if (!DIL)
    return R;

  // getContextFor yields nullptr when some frame of DIL's inline chain is
  // missing from the trie. In that case there are no callee contexts.
  ContextTrieNode *CallerNode = getContextFor(DIL);
  if (!CallerNode)
    return R;

  LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
  for (auto &It : CallerNode->getAllChildContext()) {
    ContextTrieNode &ChildNode = It.second;
    if (ChildNode.getCallSiteLoc() != CallSite)
      continue;
    if (FunctionSamples *CalleeSamples = ChildNode.getFunctionSamples())
      R.push_back(CalleeSamples);
  }
  return R;
}
/// Return the profile for the (possibly inlined) context that \p DIL belongs
/// to, or nullptr if the context or its profile is missing. Side effect: a
/// profile whose node is not a direct child of the root is marked
/// InlinedContext.
FunctionSamples *
SampleContextTracker::getContextSamplesFor(const DILocation *DIL) {
  assert(DIL && "Expect non-null location");
  ContextTrieNode *Node = getContextFor(DIL);
  FunctionSamples *Samples = Node ? Node->getFunctionSamples() : nullptr;
  if (!Samples)
    return nullptr;
  if (Node->getParentContext() != &RootContext)
    Samples->getContext().setState(InlinedContext);
  return Samples;
}
/// Return the profile attached to the node for exact context \p Context, or
/// nullptr if that context is not in the trie or carries no profile.
FunctionSamples *
SampleContextTracker::getContextSamplesFor(const SampleContext &Context) {
  if (ContextTrieNode *Node = getContextFor(Context))
    return Node->getFunctionSamples();
  return nullptr;
}
/// All context profiles recorded for \p Func, looked up by its canonical
/// function name.
SampleContextTracker::ContextSamplesTy &
SampleContextTracker::getAllContextSamplesFor(const Function &Func) {
  return getAllContextSamplesFor(FunctionSamples::getCanonicalFnName(Func));
}

/// All context profiles recorded for function \p Name. The lookup goes
/// through the map's operator[], so an empty list is created for unknown
/// names.
SampleContextTracker::ContextSamplesTy &
SampleContextTracker::getAllContextSamplesFor(StringRef Name) {
  ContextSamplesTy &Profiles = FuncToCtxtProfiles[Name];
  return Profiles;
}
/// Convenience overload: base profile for \p Func, looked up via its
/// canonical name.
FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func,
                                                         bool MergeContext) {
  return getBaseSamplesFor(FunctionSamples::getCanonicalFnName(Func),
                           MergeContext);
}
/// Return the base (top-level) profile for function \p Name, or nullptr if
/// there is none. When \p MergeContext is set, each recorded context profile
/// of the function is first promoted and merged into the top-level node.
/// Profiles already marked InlinedContext or MergedContext are skipped, as is
/// the one already living at the top-level node.
FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
                                                         bool MergeContext) {
  LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n");
  // FGUID backs Name when it is converted to MD5 form.
  std::string FGUID;
  Name = getRepInFormat(Name, FunctionSamples::UseMD5, FGUID);
  ContextTrieNode *Node = getTopLevelContextNode(Name);

  if (MergeContext) {
    LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name
                      << "\n");
    for (auto *CSamples : FuncToCtxtProfiles[Name]) {
      SampleContext &Ctx = CSamples->getContext();
      bool AlreadyConsumed =
          Ctx.hasState(InlinedContext) || Ctx.hasState(MergedContext);
      if (AlreadyConsumed)
        continue;
      ContextTrieNode *FromNode = getContextNodeForProfile(CSamples);
      if (FromNode == Node)
        continue;
      ContextTrieNode &Promoted = promoteMergeContextSamplesTree(*FromNode);
      assert((!Node || Node == &Promoted) && "Expect only one base profile");
      Node = &Promoted;
    }
  }

  return Node ? Node->getFunctionSamples() : nullptr;
}
void SampleContextTracker::markContextSamplesInlined(
const FunctionSamples *InlinedSamples) {
assert(InlinedSamples && "Expect non-null inlined samples");
LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
<< getContextString(*InlinedSamples) << "\n");
InlinedSamples->getContext().setState(InlinedContext);
}
/// The synthetic root of the context trie. Top-level contexts are its
/// children.
ContextTrieNode &SampleContextTracker::getRootContext() {
  return RootContext;
}
void SampleContextTracker::promoteMergeContextSamplesTree(
const Instruction &Inst, StringRef CalleeName) {
LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
<< Inst << "\n");
DILocation *DIL = Inst.getDebugLoc();
ContextTrieNode *CallerNode = getContextFor(DIL);
if (!CallerNode)
return;
LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
if (CalleeName.empty()) {
for (auto &It : CallerNode->getAllChildContext()) {
ContextTrieNode *NodeToPromo = &It.second;
if (CallSite != NodeToPromo->getCallSiteLoc())
continue;
FunctionSamples *FromSamples = NodeToPromo->getFunctionSamples();
if (FromSamples && FromSamples->getContext().hasState(InlinedContext))
continue;
promoteMergeContextSamplesTree(*NodeToPromo);
}
return;
}
ContextTrieNode *NodeToPromo =
CallerNode->getChildContext(CallSite, CalleeName);
if (!NodeToPromo)
return;
promoteMergeContextSamplesTree(*NodeToPromo);
}
/// Promote the context tree rooted at \p NodeToPromo to the top level of the
/// trie, merging it into an existing top-level node for the same function if
/// one exists. The node must carry a profile that is not InlinedContext.
/// \returns the resulting top-level node.
ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
    ContextTrieNode &NodeToPromo) {
  // FromSamples is only read by assertions and debug output.
  FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples();
  (void)FromSamples;
  assert(FromSamples && "Shouldn't promote a context without profile");
  LLVM_DEBUG({
    dbgs() << " Found context tree root to promote: "
           << getContextString(&NodeToPromo) << "\n";
  });
  assert(!FromSamples->getContext().hasState(InlinedContext) &&
         "Shouldn't promote inlined context profile");
  ContextTrieNode &TopLevel = RootContext;
  return promoteMergeContextSamplesTree(NodeToPromo, TopLevel);
}
#ifndef NDEBUG
/// Debug helper: context string of the trie node associated with \p FSamples.
std::string
SampleContextTracker::getContextString(const FunctionSamples &FSamples) const {
  auto *Node = getContextNodeForProfile(&FSamples);
  return getContextString(Node);
}
/// Debug helper: render the path from the trie root down to \p Node as a
/// context string, outermost caller first. The root itself renders as "".
std::string
SampleContextTracker::getContextString(ContextTrieNode *Node) const {
  if (Node == &RootContext)
    return std::string();

  SampleContextFrameVector Frames;
  // The innermost frame has no outgoing call site.
  Frames.emplace_back(Node->getFuncName(), LineLocation(0, 0));
  // Walk towards the root. Each ancestor frame records the call site through
  // which its child on the path is reached.
  for (ContextTrieNode *Child = Node, *Parent = Node->getParentContext();
       Parent && Parent != &RootContext;
       Child = Parent, Parent = Parent->getParentContext())
    Frames.emplace_back(Parent->getFuncName(), Child->getCallSiteLoc());

  std::reverse(Frames.begin(), Frames.end());
  return SampleContext::getContextString(Frames);
}
#endif
void SampleContextTracker::dump() { RootContext.dumpTree(); }
/// Return the original function name for \p Node.
/// - Without MD5 names, the node's name is returned unchanged.
/// - With MD5 names, the node's name is parsed as a decimal GUID and
///   translated through GUIDToFuncNameMap, which must be set. DenseMap::lookup
///   yields an empty StringRef for unknown GUIDs.
StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const {
  if (!FunctionSamples::UseMD5)
    return Node->getFuncName();

  assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be populated first");
  // StringRef::data() is not guaranteed to be null-terminated. Materialize
  // the name so std::stoull parses only this node's characters.
  return GUIDToFuncNameMap->lookup(std::stoull(Node->getFuncName().str()));
}
/// Find the node for exact context \p Context without creating any nodes.
/// \returns nullptr if the context is not in the trie.
ContextTrieNode *
SampleContextTracker::getContextFor(const SampleContext &Context) {
  const bool AllowCreate = false;
  return getOrCreateContextPath(Context, AllowCreate);
}
/// Find the callee node reached from \p DIL's context through DIL's call
/// site. An empty \p CalleeName selects the hottest callee at that site.
/// \returns nullptr when either the caller context or the callee is missing.
ContextTrieNode *
SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
                                          StringRef CalleeName) {
  assert(DIL && "Expect non-null location");
  ContextTrieNode *CallerContext = getContextFor(DIL);
  if (!CallerContext)
    return nullptr;
  LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
  return CallerContext->getChildContext(CallSite, CalleeName);
}
/// Map a debug location to its node in the context trie. DIL's inline chain
/// is followed from the outermost function inwards. Each frame's name is its
/// scope's subprogram linkage name, falling back to the plain name, and is
/// converted to MD5 form when MD5 names are in use.
/// \returns nullptr if any frame of the chain is missing from the trie.
ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
  assert(DIL && "Expect non-null location");

  auto SubprogramName = [](const DILocation *Loc) {
    DISubprogram *SP = Loc->getScope()->getSubprogram();
    StringRef Name = SP->getLinkageName();
    return Name.empty() ? SP->getName() : Name;
  };

  // Frames are collected innermost-first as (call site in caller, callee).
  SmallVector<std::pair<LineLocation, StringRef>, 10> Frames;
  const DILocation *Callee = DIL;
  for (const DILocation *Caller = DIL->getInlinedAt(); Caller;
       Caller = Caller->getInlinedAt()) {
    StringRef CalleeName = SubprogramName(Callee);
    Frames.push_back(std::make_pair(
        FunctionSamples::getCallSiteIdentifier(Caller), CalleeName));
    Callee = Caller;
  }
  // The outermost function is a top-level context with no incoming call site.
  Frames.push_back(std::make_pair(LineLocation(0, 0), SubprogramName(Callee)));

  // Backing storage for MD5 names. std::list keeps each string at a stable
  // address, so the StringRefs stored in Frames stay valid.
  std::list<std::string> MD5Names;
  if (FunctionSamples::UseMD5) {
    for (auto &Frame : Frames) {
      MD5Names.emplace_back();
      std::string &Buf = MD5Names.back();
      getRepInFormat(Frame.second, FunctionSamples::UseMD5, Buf);
      Frame.second = Buf;
    }
  }

  ContextTrieNode *Node = &RootContext;
  for (auto It = Frames.rbegin(), End = Frames.rend(); It != End; ++It) {
    Node = Node->getChildContext(It->first, It->second);
    if (!Node)
      return nullptr;
  }
  return Node;
}
/// Walk the trie path for \p Context starting at the root, building missing
/// nodes if \p AllowCreate is set. Each frame's node hangs off its parent at
/// the call site recorded in the parent's frame. The outermost frame uses
/// LineLocation(0, 0).
/// \returns the node for the innermost frame, or nullptr when \p AllowCreate
/// is false and some frame of the path is missing.
ContextTrieNode *
SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
                                             bool AllowCreate) {
  ContextTrieNode *ContextNode = &RootContext;
  LineLocation CallSiteLoc(0, 0);
  for (auto &Callsite : Context.getContextFrames()) {
    if (AllowCreate) {
      ContextNode =
          ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.FuncName);
    } else {
      ContextNode =
          ContextNode->getChildContext(CallSiteLoc, Callsite.FuncName);
      // A missing frame means the whole path is absent. Stop here rather than
      // dereference the null node on the next iteration.
      if (!ContextNode)
        return nullptr;
    }
    CallSiteLoc = Callsite.Location;
  }
  assert((!AllowCreate || ContextNode) &&
         "Node must exist if creation is allowed");
  return ContextNode;
}
/// Top-level node (direct child of the root) for function \p FName, or
/// nullptr if there is none. \p FName must be non-empty.
ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
  assert(!FName.empty() && "Top level node query must provide valid name");
  const LineLocation NoCallSite(0, 0);
  return RootContext.getChildContext(NoCallSite, FName);
}
/// Create and return a new, empty top-level node for \p FName. The node must
/// not already exist.
ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
  assert(!getTopLevelContextNode(FName) && "Node to add must not exist");
  ContextTrieNode *Created =
      RootContext.getOrCreateChildContext(LineLocation(0, 0), FName);
  return *Created;
}
/// Merge \p FromNode's profile into \p ToNode. Children are not touched.
/// - Both have profiles: From's samples are merged into To. To becomes
///   SyntheticContext, From becomes MergedContext, and the
///   ContextShouldBeInlined attribute is propagated from From to To.
/// - Only From has a profile: the profile is re-attached to To and marked
///   SyntheticContext. FromNode's own pointer is left unchanged.
/// - From has no profile: nothing happens.
void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
                                            ContextTrieNode &ToNode) {
  FunctionSamples *FromSamples = FromNode.getFunctionSamples();
  if (!FromSamples)
    return;

  FunctionSamples *ToSamples = ToNode.getFunctionSamples();
  if (!ToSamples) {
    ToNode.setFunctionSamples(FromSamples);
    setContextNode(FromSamples, &ToNode);
    FromSamples->getContext().setState(SyntheticContext);
    return;
  }

  ToSamples->merge(*FromSamples);
  ToSamples->getContext().setState(SyntheticContext);
  FromSamples->getContext().setState(MergedContext);
  if (FromSamples->getContext().hasAttribute(ContextShouldBeInlined))
    ToSamples->getContext().setAttribute(ContextShouldBeInlined);
}
/// Promote the context subtree rooted at \p FromNode so that it hangs under
/// \p ToNodeParent, merging into an existing node wherever one already exists.
/// When \p ToNodeParent is the root, the subtree becomes a top-level context:
/// it is placed at LineLocation(0, 0), and FromNode is then erased from its
/// old parent. Otherwise (the recursive calls for children), FromNode keeps
/// its original call site and is left in its parent; the calling frame clears
/// those children afterwards.
/// \returns the node under \p ToNodeParent that now holds the promoted
/// context.
ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent) {
// Capture the old location and parent up front. FromNode may be moved from
// or emptied below, and may be erased at the end.
LineLocation NewCallSiteLoc = LineLocation(0, 0);
LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc();
ContextTrieNode &FromNodeParent = *FromNode.getParentContext();
ContextTrieNode *ToNode = nullptr;
// Top-level contexts are keyed with a zero call-site location. Nested
// promotions keep the original call site.
bool MoveToRoot = (&ToNodeParent == &RootContext);
if (!MoveToRoot) {
NewCallSiteLoc = OldCallSiteLoc;
}
// Look for an existing node for the same function at the destination.
ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName());
if (!ToNode) {
// No counterpart: relocate FromNode's whole subtree under ToNodeParent.
ToNode =
&moveContextSamples(ToNodeParent, NewCallSiteLoc, std::move(FromNode));
LLVM_DEBUG({
dbgs() << " Context promoted and merged to: " << getContextString(ToNode)
<< "\n";
});
} else {
// Counterpart exists: merge this node's profile into it, then recursively
// promote each child under ToNode (these calls do not erase anything).
mergeContextNode(FromNode, *ToNode);
LLVM_DEBUG({
if (ToNode->getFunctionSamples())
dbgs() << " Context promoted and merged to: "
<< getContextString(ToNode) << "\n";
});
for (auto &It : FromNode.getAllChildContext()) {
ContextTrieNode &FromChildNode = It.second;
promoteMergeContextSamplesTree(FromChildNode, *ToNode);
}
// Every child has been promoted into ToNode; drop the stale originals.
FromNode.getAllChildContext().clear();
}
// Only a top-level promotion detaches FromNode from its old parent. ToNode's
// name is used because FromNode may have been moved from; the two names
// match, since ToNode was found or created under FromNode's name.
if (MoveToRoot)
FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName());
return *ToNode;
}
/// Flatten the trie into \p ContextLessProfiles. The profiles of all nodes
/// that share a function name are merged into one entry keyed by that name.
void SampleContextTracker::createContextLessProfileMap(
    SampleProfileMap &ContextLessProfiles) {
  for (auto *Node : *this) {
    if (FunctionSamples *Profile = Node->getFunctionSamples())
      ContextLessProfiles[Node->getFuncName()].merge(*Profile);
  }
}
}