#include "CSPreInliner.h"
#include "ProfiledBinary.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
#include <cstdint>
#include <queue>
#define DEBUG_TYPE "cs-preinliner"
using namespace llvm;
using namespace sampleprof;
// Counters maintained by CSPreInliner::processFunction.
// PreInlNumCSInlined / PreInlNumCSNotInlined count individual candidate
// decisions; the *HitMinLimit / *HitMaxLimit / *HitGrowthLimit counters are
// bumped once per function whose candidate queue was still non-empty when the
// size budget ran out, classified by which bound produced that budget.
STATISTIC(PreInlNumCSInlined,
"Number of functions inlined with context sensitive profile");
STATISTIC(PreInlNumCSNotInlined,
"Number of functions not inlined with context sensitive profile");
STATISTIC(PreInlNumCSInlinedHitMinLimit,
"Number of functions with FDO inline stopped due to min size limit");
STATISTIC(PreInlNumCSInlinedHitMaxLimit,
"Number of functions with FDO inline stopped due to max size limit");
STATISTIC(
PreInlNumCSInlinedHitGrowthLimit,
"Number of functions with FDO inline stopped due to growth size limit");
// Inliner tuning options defined in other translation units. The
// CSPreInliner constructor below replaces the defaults of the hot/cold
// call-site thresholds and of ProfileInlineLimitMax when those options were
// not given on the command line.
extern cl::opt<int> SampleHotCallSiteThreshold;
extern cl::opt<int> SampleColdCallSiteThreshold;
extern cl::opt<int> ProfileInlineGrowthLimit;
extern cl::opt<int> ProfileInlineLimitMin;
extern cl::opt<int> ProfileInlineLimitMax;
// When set, buildTopDownOrder re-orders the members of each SCC via
// scc_member_iterator before emitting them.
extern cl::opt<bool> SortProfiledSCC;
// Global switch for the pre-inliner. It is not read in this file; presumably
// the driver checks it before calling CSPreInliner::run() (confirm).
cl::opt<bool> EnableCSPreInliner(
"csspgo-preinliner", cl::Hidden, cl::init(true),
cl::desc("Run a global pre-inliner to merge context profile based on "
"estimated global top-down inline decisions"));
// Copied into CSPreInliner::UseContextCost by the constructor; selects whether
// getFuncSize uses ProfiledBinary's per-context size or the number of body
// sample records.
cl::opt<bool> UseContextCostForPreInliner(
"use-context-cost-for-preinliner", cl::Hidden, cl::init(true),
cl::desc("Use context-sensitive byte size cost for preinliner decisions"));
// When set, shouldInline returns the ContextWasInlined attribute recorded in
// the profile instead of evaluating size thresholds.
static cl::opt<bool> SamplePreInlineReplay(
"csspgo-replay-preinline", cl::Hidden, cl::init(false),
cl::desc(
"Replay previous inlining and adjust context profile accordingly"));
/// Create a pre-inliner working on the context trie owned by \p Tracker.
///
/// Besides capturing its collaborators, the constructor installs pre-inliner
/// defaults for three shared inliner options. It does so only for options that
/// were not explicitly specified on the command line:
///   SampleHotCallSiteThreshold  -> 1500
///   SampleColdCallSiteThreshold -> 0
///   ProfileInlineLimitMax       -> 3000
CSPreInliner::CSPreInliner(SampleContextTracker &Tracker,
                           ProfiledBinary &Binary, ProfileSummary *Summary)
    : UseContextCost(UseContextCostForPreInliner),
      ContextTracker(Tracker), Binary(Binary), Summary(Summary) {
  // Overwrite Opt with Value unless the user set Opt explicitly.
  auto DefaultIfUnset = [](auto &Opt, int Value) {
    if (Opt.getNumOccurrences() == 0)
      Opt = Value;
  };
  DefaultIfUnset(SampleHotCallSiteThreshold, 1500);
  DefaultIfUnset(SampleColdCallSiteThreshold, 0);
  DefaultIfUnset(ProfileInlineLimitMax, 3000);
}
std::vector<StringRef> CSPreInliner::buildTopDownOrder() {
std::vector<StringRef> Order;
ProfiledCallGraph ProfiledCG(ContextTracker);
scc_iterator<ProfiledCallGraph *> I = scc_begin(&ProfiledCG);
while (!I.isAtEnd()) {
auto Range = *I;
if (SortProfiledSCC) {
scc_member_iterator<ProfiledCallGraph *> SI(*I);
Range = *SI;
}
for (auto *Node : Range) {
if (Node != ProfiledCG.getEntryNode())
Order.push_back(Node->Name);
}
++I;
}
std::reverse(Order.begin(), Order.end());
return Order;
}
/// Enqueue every profiled callee context of \p CallerSamples into \p CQueue.
///
/// Candidates are the caller's children in the context trie, not the call
/// targets listed in its profile. Children without function samples are
/// skipped.
///
/// Each queued candidate is given:
///   - a count: the larger of the callee's head-sample estimate and the count
///     recorded for that callee at the call site in the caller's profile;
///   - a size: getFuncSize() of the callee node.
///
/// \returns true if at least one candidate was pushed.
bool CSPreInliner::getInlineCandidates(ProfiledCandidateQueue &CQueue,
                                       const FunctionSamples *CallerSamples) {
  assert(CallerSamples && "Expect non-null caller samples");

  ContextTrieNode *CallerNode =
      ContextTracker.getContextNodeForProfile(CallerSamples);

  bool Queued = false;
  for (auto &ChildEntry : CallerNode->getAllChildContext()) {
    ContextTrieNode *CalleeNode = &ChildEntry.second;
    FunctionSamples *CalleeSamples = CalleeNode->getFunctionSamples();
    if (CalleeSamples == nullptr)
      continue;

    uint64_t EntryCount = CalleeSamples->getHeadSamplesEstimate();

    // Look up the count the caller's profile recorded for this specific
    // callee at this call site; it stays 0 when no such record exists.
    uint64_t SiteCount = 0;
    LineLocation SiteLoc = CalleeNode->getCallSiteLoc();
    auto CallTargets = CallerSamples->findCallTargetMapAt(SiteLoc);
    if (CallTargets) {
      SampleRecord::CallTargetMap &Targets = CallTargets.get();
      auto Found = Targets.find(CalleeSamples->getName());
      if (Found != Targets.end())
        SiteCount = Found->second;
    }

    Queued = true;
    uint32_t CalleeSize = getFuncSize(CalleeNode);
    CQueue.emplace(CalleeSamples, std::max(SiteCount, EntryCount), CalleeSize);
  }
  return Queued;
}
/// Size cost of the function context represented by \p ContextNode.
///
/// With UseContextCost set, this is ProfiledBinary::getFuncSizeForContext for
/// the node. Otherwise it is the number of body sample records in the node's
/// function profile.
uint32_t CSPreInliner::getFuncSize(const ContextTrieNode *ContextNode) {
  if (!UseContextCost)
    return ContextNode->getFunctionSamples()->getBodySamples().size();
  return Binary.getFuncSizeForContext(ContextNode);
}
/// Decide whether \p Candidate should be pre-inlined into its caller.
///
/// In replay mode (SamplePreInlineReplay), the ContextWasInlined attribute
/// stored in the candidate's context is returned as-is.
///
/// Otherwise the candidate's size cost is compared against a threshold:
///   - Cold call sites (count <= the summary's cold count threshold) use
///     SampleColdCallSiteThreshold.
///   - Hotter call sites use a threshold that grows linearly with a
///     normalized hotness in [0, 1].
///
/// \returns true iff Candidate.SizeCost is strictly below the threshold.
bool CSPreInliner::shouldInline(ProfiledInlineCandidate &Candidate) {
  if (SamplePreInlineReplay)
    return Candidate.CalleeSamples->getContext().hasAttribute(
        ContextWasInlined);

  unsigned int SampleThreshold = SampleColdCallSiteThreshold;
  uint64_t ColdCountThreshold = ProfileSummaryBuilder::getColdCountThreshold(
      (Summary->getDetailedSummary()));

  if (Candidate.CallsiteCount <= ColdCountThreshold)
    SampleThreshold = SampleColdCallSiteThreshold;
  else {
    // Normalize the call site count between the cold threshold (lower bound)
    // and the MinCount of the summary entry for cutoff 100000 (upper bound).
    // The cutoff presumably means 10%, with cutoffs scaled by 1,000,000.
    double NormalizationUpperBound =
        ProfileSummaryBuilder::getEntryForPercentile(
            Summary->getDetailedSummary(), 100000 )
            .MinCount;
    double NormalizationLowerBound = ColdCountThreshold;
    // Counts at or above the upper bound are treated as maximally hot.
    // Interpolate only strictly below it. In that case
    // lower < CallsiteCount < upper, so the denominator is positive.
    // Dividing unconditionally would give a zero or negative denominator
    // whenever the cold threshold reaches the upper bound. A negative
    // hotness yields a negative threshold, and converting that to unsigned
    // is undefined behaviour.
    double NormalizedHotness = 1.0;
    if (Candidate.CallsiteCount < NormalizationUpperBound)
      NormalizedHotness =
          (Candidate.CallsiteCount - NormalizationLowerBound) /
          (NormalizationUpperBound - NormalizationLowerBound);
    // The trailing +1 keeps the threshold non-zero for hot call sites even
    // when SampleColdCallSiteThreshold is 0.
    SampleThreshold = SampleHotCallSiteThreshold * NormalizedHotness * 100 +
                      SampleColdCallSiteThreshold + 1;
  }

  return (Candidate.SizeCost < SampleThreshold);
}
void CSPreInliner::processFunction(const StringRef Name) {
FunctionSamples *FSamples = ContextTracker.getBaseSamplesFor(Name);
if (!FSamples)
return;
unsigned FuncSize =
getFuncSize(ContextTracker.getContextNodeForProfile(FSamples));
unsigned FuncFinalSize = FuncSize;
unsigned SizeLimit = FuncSize * ProfileInlineGrowthLimit;
SizeLimit = std::min(SizeLimit, (unsigned)ProfileInlineLimitMax);
SizeLimit = std::max(SizeLimit, (unsigned)ProfileInlineLimitMin);
LLVM_DEBUG(dbgs() << "Process " << Name
<< " for context-sensitive pre-inlining (pre-inline size: "
<< FuncSize << ", size limit: " << SizeLimit << ")\n");
ProfiledCandidateQueue CQueue;
getInlineCandidates(CQueue, FSamples);
while (!CQueue.empty() && FuncFinalSize < SizeLimit) {
ProfiledInlineCandidate Candidate = CQueue.top();
CQueue.pop();
bool ShouldInline = false;
if ((ShouldInline = shouldInline(Candidate))) {
++PreInlNumCSInlined;
ContextTracker.markContextSamplesInlined(Candidate.CalleeSamples);
Candidate.CalleeSamples->getContext().setAttribute(
ContextShouldBeInlined);
FuncFinalSize += Candidate.SizeCost;
getInlineCandidates(CQueue, Candidate.CalleeSamples);
} else {
++PreInlNumCSNotInlined;
}
LLVM_DEBUG(
dbgs() << (ShouldInline ? " Inlined" : " Outlined")
<< " context profile for: "
<< ContextTracker.getContextString(*Candidate.CalleeSamples)
<< " (callee size: " << Candidate.SizeCost
<< ", call count:" << Candidate.CallsiteCount << ")\n");
}
if (!CQueue.empty()) {
if (SizeLimit == (unsigned)ProfileInlineLimitMax)
++PreInlNumCSInlinedHitMaxLimit;
else if (SizeLimit == (unsigned)ProfileInlineLimitMin)
++PreInlNumCSInlinedHitMinLimit;
else
++PreInlNumCSInlinedHitGrowthLimit;
}
LLVM_DEBUG({
if (!CQueue.empty())
dbgs() << " Inline candidates ignored due to size limit (inliner "
"original size: "
<< FuncSize << ", inliner final size: " << FuncFinalSize
<< ", size limit: " << SizeLimit << ")\n";
while (!CQueue.empty()) {
ProfiledInlineCandidate Candidate = CQueue.top();
CQueue.pop();
bool WasInlined =
Candidate.CalleeSamples->getContext().hasAttribute(ContextWasInlined);
dbgs() << " "
<< ContextTracker.getContextString(*Candidate.CalleeSamples)
<< " (candidate size:" << Candidate.SizeCost
<< ", call count: " << Candidate.CallsiteCount << ", previously "
<< (WasInlined ? "inlined)\n" : "not inlined)\n");
}
});
}
void CSPreInliner::run() {
#ifndef NDEBUG
auto printProfileNames = [](SampleContextTracker &ContextTracker,
bool IsInput) {
uint32_t Size = 0;
for (auto *Node : ContextTracker) {
FunctionSamples *FSamples = Node->getFunctionSamples();
if (FSamples) {
Size++;
dbgs() << " [" << ContextTracker.getContextString(Node) << "] "
<< FSamples->getTotalSamples() << ":"
<< FSamples->getHeadSamples() << "\n";
}
}
dbgs() << (IsInput ? "Input" : "Output") << " context-sensitive profiles ("
<< Size << " total):\n";
};
#endif
LLVM_DEBUG(printProfileNames(ContextTracker, true));
for (StringRef FuncName : buildTopDownOrder()) {
processFunction(FuncName);
}
for (auto *Node : ContextTracker) {
FunctionSamples *FProfile = Node->getFunctionSamples();
if (FProfile &&
(Node->getParentContext() != &ContextTracker.getRootContext() &&
!FProfile->getContext().hasState(InlinedContext))) {
Node->setFunctionSamples(nullptr);
}
}
FunctionSamples::ProfileIsPreInlined = true;
LLVM_DEBUG(printProfileNames(ContextTracker, false));
}