// Compiler projects using llvm
#include <cassert>
#include <utility>
#include <functional>
#include <iterator>
#include <ranges>
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"

#include "llvm/Analysis/UnitLoopInfo.h"

#define DEBUG_TYPE "unit-loop"
STATISTIC(NumLoopsDetected, "Number of loops detected");

using namespace llvm;
using namespace ece479k;

// Need to do this conversion because predecessors doesn't satisfy forward-iterator
// so we can use it in a filter view
static auto preds(BasicBlock const* block)
{
    SmallVector<BasicBlock const*> r;

    auto const p = predecessors(block);
    r.append(p.begin(), p.end());

    return r;
}

#pragma region "UnitLoopAnalysis"
/// Main function for running the Loop Identification analysis. This function
/// returns information about the loops in the function via the UnitLoopInfo
/// object
UnitLoopAnalysis::Result UnitLoopAnalysis::run(Function &F, FunctionAnalysisManager &FAM)
{
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
    auto &AA = FAM.getResult<AAManager>(F);

    Result Loops;
    // Need to have this ugliness to support both LLVM 15 and LLVM 16-19
    #if LLVM_VERSION_MAJOR > 15
    UnitLoopInfo::alias_results alias = BatchAAResults(AA);
    #else
    UnitLoopInfo::alias_results alias = AA;
    #endif

    auto const is_reachable = [&DT](BasicBlock const& bb) {
        return DT.isReachableFromEntry(&bb);
    };

    // Find all back-edges in F
    for (BasicBlock& BB: F | std::views::filter(is_reachable)) {

        // ReSharper disable once CppDeclarationHidesLocal
        auto const is_reachable = [&DT](BasicBlock const* bb) {
            return DT.isReachableFromEntry(bb);
        };

        auto const is_backedge = [&DT, &BB](BasicBlock const* bb) {
            return DT.dominates(&BB, bb);
        };

        // We need to filter out unreachble blocks twice because it's possible
        // for an unreachble block to be the predecessor of a reachable block.
        // See bad_loop1.ll for canonical case of this.

        for (BasicBlock const* P: preds(&BB)
                    | std::views::filter(is_reachable)
                    | std::views::filter(is_backedge))
        {
            LLVM_DEBUG(dbgs()
                << "UnitLoopAnalysis: Found loop in " << F.getName() << " from `"
                << BB.getName() << "` to `" << P->getName() << "`");
            Loops.emplace_back(&BB, const_cast<BasicBlock *>(P), alias, DT);
            ++NumLoopsDetected;
        }
    }
    return Loops;
}

AnalysisKey UnitLoopAnalysis::Key;
#pragma endregion

#pragma region "UnitLoopInfo"
static BasicBlock* get_preheader(BasicBlock* header, BasicBlock* end)
{
    BasicBlock* preheader;
    // LoopSimplify guarantees 1 preheader and 1 backedge
    // hence the preds of hedaer are [preheader, end]
    assert(header->hasNPredecessors(2) && "Loop header preds were not [preheader, backedge]");

    for (auto* pred : predecessors(header))
        if (pred != end)
            preheader = pred;

    assert(preheader && "Preheader not found");
    assert(preheader->isLegalToHoistInto() && "Preheader is not hoistable");
    assert(preheader->getSingleSuccessor() == header && "Preheader is invalid");

    return preheader;
}

UnitLoopInfo::UnitLoopInfo(BasicBlock* header, BasicBlock* end, alias_results AA, DominatorTree& DT)
    :header(header), preheader(get_preheader(header, end)), alias_sets(std::make_unique<AliasSetTracker>(AA))
{
    assert(header && "Can't record a loop that starts from a non-existent BasicBlock");
    assert(end && "Can't record a loop that ends with a non-existent BasicBlock");
    assert(preheader && "Loop does not have a valid preheader");

    SmallVector<BasicBlock*, 16> worklist;
    worklist.emplace_back(end);

    // Populate loop body by walking up from the end to the header
    while (!worklist.empty()) {
        BasicBlock* cur = worklist.back(); worklist.pop_back();

        // Make sure that we don't add the preds of the header
        if (cur == header) {
            set.insert(header);
            blocks.emplace_back(header);
            break;
        }

        if (set.insert(cur).second) {
            blocks.emplace_back(cur);
            auto preds = predecessors(cur);
            worklist.append(preds.begin(), preds.end());
        }
    }

    assert(!set.empty() && "Loops must contain at least one BasicBlock");
    assert(blocks.back() == header && "Loop must start at the header");
    assert(blocks.front() == end && "Loop must end at the end");

    for (BasicBlock* block: blocks) {
        assert(DT.dominates(preheader, block) && "All blocks in the loop must be dominated by the preheader");
        assert(DT.dominates(header, block) && "All blocks in the loop must be dominated by the header");
    }

    // Populate AliasSets for this loop
    for (BasicBlock* block: blocks | std::views::reverse)
        alias_sets->add(*block);

    // Populate the loop exits
    for (BasicBlock* block : blocks)
        // if any of the successors of a block point outside the loop, it is an exit from the loop
        if (any_of(successors(block), [end, this](BasicBlock const* bb) {
            return !this->contains(bb);
        }))
            exits.insert(block);

    LLVM_DEBUG(
        dbgs()  << " with preheader `" << preheader->getName()
                << "`, header `" << header->getName()
                << "`, exits [ ";
        for (BasicBlock const* exit: exits)
            dbgs() << "`" << exit->getName() << "` ";

        dbgs() << "]\n";
    );
}


bool UnitLoopInfo::contains(BasicBlock const *BB) const
{
    assert(!set.empty() && "Can't check an empty loop");
    // assert(BB && "Can't look-up a non-existent BasicBlock");
    return set.contains(BB);
}

bool UnitLoopInfo::contains(Instruction const *I) const
{
    assert(!set.empty() && "Can't check an empty loop");
    // assert(I && "Can't look-up a non-existent Instruction");
    return contains(I->getParent());
}

bool UnitLoopInfo::contains(Instruction const& I) const
{
    assert(!set.empty() && "Can't check an empty loop");
    // assert(I && "Can't look-up a non-existent Instruction");
    return contains(I.getParent());
}

bool UnitLoopInfo::is_invariant(Value const *V) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    if (auto const* i = dyn_cast<Instruction>(V))
        return !contains(i);
    return true; 
}

bool UnitLoopInfo::is_invariant(Instruction const *I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    return !contains(I);
}

bool UnitLoopInfo::is_invariant(Instruction const& I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    return !contains(I);
}

bool UnitLoopInfo::has_invariant_operands(Instruction const *I) const
{
    assert(!set.empty() && "Can't be invariant to an empty loop");
    return std::ranges::all_of(I->operands(), [this](Value *v){ return is_invariant(v); });
}

std::vector<Instruction*> UnitLoopInfo::instructions() const
{
    std::vector<Instruction*> r;

    for (BasicBlock* block : blocks | std::views::reverse)
        for (Instruction& inst: *block)
            r.emplace_back(&inst);

    return r;
}

std::vector<Instruction*> UnitLoopInfo::instructions(std::function<bool(Instruction const&)> const& filter) const
{
    std::vector<Instruction*> r;

    for (BasicBlock* block : blocks | std::views::reverse)
        for (Instruction& inst: *block | std::views::filter(filter))
            r.emplace_back(&inst);

    return r;
}

AliasSet const& UnitLoopInfo::alias_for(LoadInst const* inst) const
{
    return alias_sets->getAliasSetFor(MemoryLocation::get(inst));
}

AliasSet const& UnitLoopInfo::alias_for(StoreInst const* inst) const
{
    return alias_sets->getAliasSetFor(MemoryLocation::get(inst));
}
#pragma endregion