#include <cassert>
#include <utility>
#include <functional>
#include <iterator>
#include <ranges>
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/Analysis/UnitLoopInfo.h"
// Debug tag for this translation unit; the LLVM_DEBUG and STATISTIC macros
// used below are keyed on DEBUG_TYPE.
#define DEBUG_TYPE "unit-loop"
// Counter incremented once for every loop recorded by UnitLoopAnalysis::run.
STATISTIC(NumLoopsDetected, "Number of loops detected");
// LLVM names and the project's own `ece479k` names are used unqualified below.
using namespace llvm;
using namespace ece479k;
/// Copies the predecessors of `block` into an owning SmallVector.
///
/// UnitLoopAnalysis::run pipes the returned container through
/// std::views::filter, so it takes a container instead of the iterator range
/// that predecessors() yields.
static auto preds(BasicBlock const* block)
{
    auto const incoming = predecessors(block);
    SmallVector<BasicBlock const*> collected(incoming.begin(), incoming.end());
    return collected;
}
#pragma region "UnitLoopAnalysis"
/// Detects the loops of `F` by looking for back edges.
///
/// For every block that is reachable from the entry, each reachable
/// predecessor that the block dominates is the source of a back edge; one
/// UnitLoopInfo is appended to the result per such edge, and
/// NumLoopsDetected is incremented.
///
/// @param F    Function to analyse.
/// @param FAM  Analysis manager supplying the dominator tree and alias analysis.
/// @return     The recorded loops, in the order their back edges were found.
///
/// NOTE(review): `alias` is a local of this function. Whether the returned
/// UnitLoopInfo objects keep referring to it depends on how
/// UnitLoopInfo::alias_results is declared -- confirm against the header.
UnitLoopAnalysis::Result UnitLoopAnalysis::run(Function &F, FunctionAnalysisManager &FAM)
{
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
    auto &AA = FAM.getResult<AAManager>(F);
    Result Loops;
#if LLVM_VERSION_MAJOR > 15
    UnitLoopInfo::alias_results alias = BatchAAResults(AA);
#else
    UnitLoopInfo::alias_results alias = AA;
#endif

    for (BasicBlock& Header : F) {
        // Blocks that cannot be reached from the entry are never loop headers.
        if (!DT.isReachableFromEntry(&Header))
            continue;

        for (BasicBlock const* Latch : preds(&Header)) {
            // A back edge needs a reachable source that the header dominates.
            if (!DT.isReachableFromEntry(Latch))
                continue;
            if (!DT.dominates(&Header, Latch))
                continue;

            // No trailing newline here: the UnitLoopInfo constructor's own
            // debug output continues this line and terminates it.
            LLVM_DEBUG(dbgs()
                << "UnitLoopAnalysis: Found loop in " << F.getName() << " from `"
                << Header.getName() << "` to `" << Latch->getName() << "`");
            Loops.emplace_back(&Header, const_cast<BasicBlock *>(Latch), alias, DT);
            ++NumLoopsDetected;
        }
    }
    return Loops;
}
// Out-of-line definition of UnitLoopAnalysis's static `Key` member.
// NOTE(review): presumably the per-analysis identity object LLVM's analysis
// manager expects -- confirm against the class declaration.
AnalysisKey UnitLoopAnalysis::Key;
#pragma endregion
#pragma region "UnitLoopInfo"
/// Finds the preheader of the loop whose header is `header` and whose back
/// edge comes from `end`.
///
/// The header is required (by assertion) to have exactly two predecessors;
/// the one that is not `end` is taken as the preheader. It must also be legal
/// to hoist into and have `header` as its single successor.
///
/// @param header  Loop header; must not be null.
/// @param end     Source block of the back edge.
/// @return        The preheader. With assertions disabled, nullptr is returned
///                when no predecessor other than `end` exists.
static BasicBlock* get_preheader(BasicBlock* header, BasicBlock* end)
{
    assert(header && "Can't look up the preheader of a non-existent BasicBlock");
    assert(header->hasNPredecessors(2) && "Loop header preds were not [preheader, backedge]");

    // Must start out null: if every predecessor equals `end` nothing is
    // assigned below, and the check that follows would otherwise read an
    // indeterminate pointer.
    BasicBlock* preheader = nullptr;
    for (BasicBlock* pred : predecessors(header))
        if (pred != end)
            preheader = pred;

    assert(preheader && "Preheader not found");
    assert(preheader->isLegalToHoistInto() && "Preheader is not hoistable");
    assert(preheader->getSingleSuccessor() == header && "Preheader is invalid");
    return preheader;
}
/// Records the loop formed by the back edge `end -> header`.
///
/// Collects the loop's blocks by walking predecessors backwards from `end`
/// until the header is reached, feeds every block into the alias-set tracker,
/// and records as exits the loop blocks that have a successor outside the
/// loop.
///
/// After construction `blocks` starts with `end` and finishes with `header`.
///
/// @param header  Loop header (target of the back edge).
/// @param end     Source block of the back edge.
/// @param AA      Alias-analysis handle used to build the AliasSetTracker.
/// @param DT      Dominator tree of the enclosing function.
///
/// NOTE(review): `AA` is a by-value parameter handed to the AliasSetTracker
/// constructor. If alias_results is not reference-like, the tracker may keep
/// referring to it after this constructor returns -- confirm against the
/// header.
UnitLoopInfo::UnitLoopInfo(BasicBlock* header, BasicBlock* end, alias_results AA, DominatorTree& DT)
    :header(header), preheader(get_preheader(header, end)), alias_sets(std::make_unique<AliasSetTracker>(AA))
{
    // get_preheader() has already dereferenced `header` in the initializer
    // list above, so these checks only restate the contract.
    assert(header && "Can't record a loop that starts from a non-existent BasicBlock");
    assert(end && "Can't record a loop that ends with a non-existent BasicBlock");
    assert(preheader && "Loop does not have a valid preheader");

    // Backward walk from the back-edge source. The header is marked visited up
    // front so the walk never continues past it, and the worklist is drained
    // completely: stopping at the first encounter of the header would drop
    // blocks still pending on other paths (e.g. one arm of a diamond).
    set.insert(header);
    SmallVector<BasicBlock*, 16> worklist;
    if (end != header)
        worklist.emplace_back(end);
    while (!worklist.empty()) {
        BasicBlock* cur = worklist.pop_back_val();
        if (!set.insert(cur).second)
            continue;
        blocks.emplace_back(cur);
        for (BasicBlock* pred : predecessors(cur))
            // Unreachable predecessors are not part of the loop; the dominance
            // queries below would accept them vacuously.
            if (DT.isReachableFromEntry(pred))
                worklist.emplace_back(pred);
    }
    // The header goes last so that reverse iteration over `blocks` starts at it.
    blocks.emplace_back(header);

    assert(!set.empty() && "Loops must contain at least one BasicBlock");
    assert(blocks.back() == header && "Loop must start at the header");
    assert(blocks.front() == end && "Loop must end at the end");
    for (BasicBlock* block: blocks) {
        assert(DT.dominates(preheader, block) && "All blocks in the loop must be dominated by the preheader");
        assert(DT.dominates(header, block) && "All blocks in the loop must be dominated by the header");
    }

    // Register memory accesses with the tracker, header first.
    for (BasicBlock* block: blocks | std::views::reverse)
        alias_sets->add(*block);

    // An exit is a loop block with at least one successor outside the loop.
    for (BasicBlock* block : blocks)
        if (any_of(successors(block), [this](BasicBlock const* bb) {
            return !this->contains(bb);
        }))
            exits.insert(block);

    // Continues the line started by UnitLoopAnalysis::run's debug output.
    LLVM_DEBUG(
        dbgs() << " with preheader `" << preheader->getName()
            << "`, header `" << header->getName()
            << "`, exits [ ";
        for (BasicBlock const* exit: exits)
            dbgs() << "`" << exit->getName() << "` ";
        dbgs() << "]\n";
    );
}
/// Reports whether `BB` is one of the blocks recorded for this loop.
/// Asserts that the loop's block set is not empty.
bool UnitLoopInfo::contains(BasicBlock const *BB) const
{
    assert(!set.empty() && "Can't check an empty loop");
    bool const is_member = set.contains(BB);
    return is_member;
}
/// Reports whether the block that holds instruction `I` belongs to this loop.
/// Asserts that the loop's block set is not empty.
bool UnitLoopInfo::contains(Instruction const *I) const
{
    assert(!set.empty() && "Can't check an empty loop");
    BasicBlock const* parent = I->getParent();
    return contains(parent);
}
/// Reference overload: reports whether the block that holds `I` belongs to
/// this loop. Asserts that the loop's block set is not empty.
bool UnitLoopInfo::contains(Instruction const& I) const
{
    assert(!set.empty() && "Can't check an empty loop");
    BasicBlock const* parent = I.getParent();
    return contains(parent);
}
/// Reports whether `V` is invariant with respect to this loop.
///
/// A value that is not an Instruction is always treated as invariant; an
/// Instruction is invariant exactly when it lies outside the loop.
bool UnitLoopInfo::is_invariant(Value const *V) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    auto const* as_inst = dyn_cast<Instruction>(V);
    return as_inst == nullptr || !contains(as_inst);
}
/// Reports whether instruction `I` is invariant with respect to this loop,
/// i.e. whether it lies outside the loop.
bool UnitLoopInfo::is_invariant(Instruction const *I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    bool const inside_loop = contains(I);
    return !inside_loop;
}
/// Reference overload: reports whether `I` lies outside this loop.
bool UnitLoopInfo::is_invariant(Instruction const& I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    bool const inside_loop = contains(I);
    return !inside_loop;
}
/// Reports whether every operand of `I` is invariant with respect to this
/// loop (see is_invariant). An instruction without operands yields true.
bool UnitLoopInfo::has_invariant_operands(Instruction const *I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    for (Value const* operand : I->operands()) {
        // The first operand defined inside the loop settles the answer.
        if (!is_invariant(operand))
            return false;
    }
    return true;
}
std::vector<Instruction*> UnitLoopInfo::instructions() const
{
std::vector<Instruction*> r;
for (BasicBlock* block : blocks | std::views::reverse)
for (Instruction& inst: *block)
r.emplace_back(&inst);
return r;
}
std::vector<Instruction*> UnitLoopInfo::instructions(std::function<bool(Instruction const&)> const& filter) const
{
std::vector<Instruction*> r;
for (BasicBlock* block : blocks | std::views::reverse)
for (Instruction& inst: *block | std::views::filter(filter))
r.emplace_back(&inst);
return r;
}
/// Returns the alias set that the tracker associates with the memory location
/// read by `inst`.
/// NOTE(review): AliasSetTracker::getAliasSetFor may create or merge sets as a
/// side effect -- confirm against the LLVM version in use.
AliasSet const& UnitLoopInfo::alias_for(LoadInst const* inst) const
{
    MemoryLocation const location = MemoryLocation::get(inst);
    return alias_sets->getAliasSetFor(location);
}
/// Returns the alias set that the tracker associates with the memory location
/// written by `inst`.
/// NOTE(review): AliasSetTracker::getAliasSetFor may create or merge sets as a
/// side effect -- confirm against the LLVM version in use.
AliasSet const& UnitLoopInfo::alias_for(StoreInst const* inst) const
{
    MemoryLocation const location = MemoryLocation::get(inst);
    return alias_sets->getAliasSetFor(location);
}
#pragma endregion