// Compiler projects using llvm
#include <cassert>
#include <utility>
#include <functional>
#include <iterator>
#include <vector>
#include <ranges>
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/Analysis/ValueTracking.h"

#include "llvm/Transforms/UnitLICM.h"
#include "llvm/Analysis/UnitLoopInfo.h"

// Name under which this file's LLVM_DEBUG output and statistics are registered.
#define DEBUG_TYPE "unit-licm"
// Pass statistics. The *Hoisted counters are incremented in UnitLICM::run, once per
// instruction that is actually moved; NumVisited is incremented in can_hoist_impl,
// once per instruction that is analyzed (cached verdicts do not count again).
STATISTIC(NumInstsHoisted, "Number of instructions hoisted");
STATISTIC(NumStoresHoisted, "Number of stores hoisted");
STATISTIC(NumLoadsHoisted, "Number of loads hoisted");
STATISTIC(NumComputationalHoisted, "Number of computational instructions (excluding casts and GEPs) hoisted");
STATISTIC(NumVisited, "Number of instructions considered for hoisting");

using namespace llvm;
using namespace ece479k;

/// Since all symbols are exported from the SO, we try to mark all symbols local
/// to this translation unit to improve diagnostics and improve performance.
/// Likewise, we try to keep functions pure.
///
/// To reduce confusion when describing LICM, we classify two types of "invariants":
///     - Invariant: a Value that is already invariant to the loop - it is outside the loop or not an Instruction
///     - Hoistable: a Value that can be hoisted outside the loop to become an invariant
///
/// (Un)Hoistable detection is implemented as DFS on every use-def chain in the loop,
/// stopping if we end up at an invariant or condition that prevents hoisting,
/// propagating the result back up the chain to every Value visited.
///
/// This is implemented as a series of tail-recursive calls to check each
/// instruction for any disqualifying cases or an invariant and then recursively
/// on each of its operands (the part that makes it a DFS).
///
/// Optimization misses are documented in debug output. For meaningful names,
/// make sure to run the `instnamer` pass before running `unit-licm`.
/// Example output:
///     UnitLICM: Can't hoist instruction: Depends on unsupported phi `p.0`:   %idxprom = sext i32 %p.0 to i64
///     UnitLICM: Can't hoist instruction: Depends on not-hoistable sext `idxprom`:   %arrayidx18 = getelementptr inbounds [8 x [3 x double]], ptr %position, i64 0, i64 %idxprom
///     UnitLICM: Can't hoist instruction: Depends on not-hoistable getelementptr `arrayidx18`:   %arraydecay19 = getelementptr inbounds [3 x double], ptr %arrayidx18, i64 0, i64 0

// Opcode classifiers. Both look only at the instruction's opcode and touch no other
// state, which is what makes the `pure` attribute appropriate for them.
//   is_supported:     the opcode is one this pass is willing to analyze for hoisting
//   is_computational: the opcode counts toward the NumComputationalHoisted statistic
static bool is_supported(Instruction const&) __attribute((pure));
static bool is_computational(Instruction const&) __attribute((pure));

namespace {
    /// State threaded through the hoistability checks while a single loop is processed.
    ///
    /// The struct is brace-initialized positionally in UnitLICM::run, so the member
    /// order below is part of its contract.
    struct HoistArgs {
        /// The loop currently being optimized.
        UnitLoopInfo const& loop;
        /// Memoized can_hoist() verdicts, keyed by instruction (true = hoistable).
        DenseMap<Instruction const*, bool> cache;
        /// Dominator tree, used to test whether an instruction dominates the loop exits.
        DominatorTree const& DT;
    };
}

// Decides whether an instruction may be moved out of the loop described by HoistArgs.
// Deliberately NOT marked `pure`: the check memoizes its verdict in HoistArgs::cache,
// bumps the NumVisited statistic and emits debug output. A `pure` function must have no
// effect other than its return value, so the attribute would allow the compiler to
// merge or drop calls and silently skip those updates.
static bool can_hoist(Instruction const&, HoistArgs &);
// Moves an instruction in front of the terminator of the loop's preheader.
static void hoist(Instruction&, HoistArgs &);

/// Runs loop-invariant code motion over every loop that UnitLoopAnalysis reports for \p F.
///
/// For each loop, every instruction that is supported and passes can_hoist() is moved
/// to the loop preheader, and the hoisting statistics are updated.
///
/// \param F   function being transformed
/// \param FAM analysis manager providing UnitLoopAnalysis and DominatorTreeAnalysis
/// \returns PreservedAnalyses::all() if no instruction was moved (the IR is untouched),
///          PreservedAnalyses::none() otherwise.
// ReSharper disable once CppMemberFunctionMayBeStatic
PreservedAnalyses UnitLICM::run(Function &F, FunctionAnalysisManager &FAM) // NOLINT(*-convert-member-functions-to-static)
{
    auto& loops = FAM.getResult<UnitLoopAnalysis>(F);
    // The dominator tree belongs to the function, so fetch it once rather than per loop.
    auto const& dom_tree = FAM.getResult<DominatorTreeAnalysis>(F);

    bool changed = false;

    for (auto const& loop: loops) {

        // Each loop gets its own memoization cache: a verdict is only meaningful
        // relative to the loop it was computed for.
        auto hoist_args = HoistArgs{
            loop,
            DenseMap<Instruction const*, bool>(32),
            dom_tree
        };

        auto const is_hoistable = [&hoist_args](Instruction const& inst) {
            return is_supported(inst) && can_hoist(inst, hoist_args);
        };

        // Hoist everything we can.
        // NOTE(review): instructions are moved while this range is being iterated;
        // confirm that UnitLoopInfo::instructions() tolerates that.
        for (auto i: loop.instructions(is_hoistable))
        {
            hoist(*i, hoist_args);
            changed = true;

            if (isa<LoadInst>(i))
                ++NumLoadsHoisted;
            else if (isa<StoreInst>(i))
                ++NumStoresHoisted;
            else if (is_computational(*i))
                ++NumComputationalHoisted;

            ++NumInstsHoisted;
        }
    }

    // Only invalidate cached analyses when the IR was actually modified.
    return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

// Whitelist of opcodes the pass is willing to analyze. It is written as one explicit
// switch rather than a chain of isXXX() queries so the whole set is visible at a
// glance; the codegen is marginally better too, not that it matters in practice.
static bool is_supported(Instruction const& inst)
{
    switch (inst.getOpcode()) {
    // floating-point arithmetic (unary and binary)
    case Instruction::FNeg:
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul:
    case Instruction::FDiv:
    case Instruction::FRem:
    // integer arithmetic
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
    // shifts and bitwise logic
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    // memory access and address computation
    case Instruction::Load:
    case Instruction::Store:
    case Instruction::GetElementPtr:
    // integer width casts
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
    // integer <-> floating-point casts
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::UIToFP:
    case Instruction::SIToFP:
    // floating-point width casts
    case Instruction::FPTrunc:
    case Instruction::FPExt:
    // pointer and bit-pattern casts
    case Instruction::PtrToInt:
    case Instruction::IntToPtr:
    case Instruction::BitCast:
    case Instruction::AddrSpaceCast:
    // comparisons and select
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::Select:
        return true;

    // everything else (phis, calls, terminators, ...) is left alone
    default:
        return false;
    }
}


// True for opcodes that perform an actual computation: arithmetic, shifts, bitwise
// logic, comparisons and select. Memory operations, GEPs and casts are excluded,
// as is every opcode that is_supported() rejects.
static bool is_computational(Instruction const& inst)
{
    switch (inst.getOpcode()) {
    // floating-point arithmetic (unary and binary)
    case Instruction::FNeg:
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul:
    case Instruction::FDiv:
    case Instruction::FRem:
    // integer arithmetic
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
    // shifts and bitwise logic
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    // comparisons and select
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::Select:
        return true;

    default:
        return false;
    }
}

/// Moves \p inst out of the loop by placing it immediately before the terminator of
/// the loop's preheader block.
///
/// Precondition (asserted):  args.loop does not yet consider \p inst invariant.
/// Postcondition (asserted): args.loop considers \p inst invariant after the move.
static void hoist(Instruction& inst, HoistArgs &args)
{
    assert(!args.loop.is_invariant(inst) && "Cannot hoist an instruction that is already invariant");

    // moveBefore needs a mutable Instruction*; presumably the preheader is only reachable
    // through a const path here, hence the const_cast.
    auto* insertion_point = const_cast<Instruction *>(args.loop.preheader->getTerminator());
    inst.moveBefore(insertion_point);

    assert(args.loop.is_invariant(inst) && "Hoist did not make invariant");
}

// The individual steps of the hoistability DFS. They are called in a tail-recursive (ish) manner.
//
// None of them carries the `pure` attribute: every one can reach can_hoist(), which
// writes to HoistArgs::cache, and several also bump NumVisited or emit debug output.
// `pure` promises the absence of such effects and would let the compiler merge or
// drop calls.
static bool can_hoist_impl(Instruction const&, HoistArgs &);
static bool can_hoist(Value const&, HoistArgs &);
static bool can_hoist_udiv(Instruction const&, HoistArgs &);
static bool can_hoist_sdiv(Instruction const&, HoistArgs &);
static bool can_hoist_store(Instruction const&, HoistArgs &);
static bool can_hoist_load(Instruction const&, HoistArgs &);
static bool can_hoist_load_impl(Instruction const&, HoistArgs &);
static bool can_hoist_sidefx(Instruction const&, HoistArgs&);
static bool can_hoist_operands(Instruction const&, HoistArgs &);


/// Memoized hoistability query for a single instruction.
///
/// The checks run in this order:
///   1. already invariant to the loop  -> true
///   2. unsupported opcode             -> false (logged, not cached)
///   3. verdict already in the cache   -> cached verdict (logged when negative)
///   4. otherwise                      -> can_hoist_impl(), whose result is cached
static bool can_hoist(Instruction const& inst, HoistArgs &args)
{
    if (args.loop.is_invariant(inst)) {
        return true;
    }

    if (!is_supported(inst)) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist instruction: Depends on unsupported " << inst.getOpcodeName() << " `" << inst.getName() << "`: ");
        return false;
    }

    if (auto const cached = args.cache.find(&inst); cached != args.cache.end()) {
        bool const verdict = cached->second;
        LLVM_DEBUG(if (!verdict)
            dbgs() << "UnitLICM: Can't hoist instruction: Depends on not-hoistable " << inst.getOpcodeName() << " `" << inst.getName() << "`: ");
        return verdict;
    }

    // First visit: run the actual analysis and remember the outcome.
    bool const verdict = can_hoist_impl(inst, args);
    args.cache.insert({&inst, verdict});
    return verdict;
}

/// Hoistability query for an arbitrary operand value.
///
/// A value that is not an Instruction never blocks hoisting; an Instruction is
/// forwarded to the Instruction overload of can_hoist().
static bool can_hoist(Value const& val, HoistArgs &args)
{
    auto const* as_inst = dyn_cast<Instruction>(&val);
    return as_inst == nullptr || can_hoist(*as_inst, args);
}

/// In effect a specialization of isSafeToSpeculativelyExecute for this pass:
///     - stores are supported
///     - an instruction with side effects is not rejected outright; it may still be
///       hoisted if it dominates all loop exits
///     - unsupported instructions are not re-checked here (asserted instead)
///
/// Opcodes that can have side effects, or that need alias information, are routed to
/// their dedicated check. Every other opcode goes straight to the next DFS step,
/// i.e. the operand check.
static bool can_hoist_impl(Instruction const& inst, HoistArgs &args)
{
    assert(is_supported(inst) && "Instruction is not supported");

    ++NumVisited;

    auto const opcode = inst.getOpcode();

    // Needs alias analysis
    if (opcode == Instruction::Load)
        return can_hoist_load(inst, args);

    // Always has side effects, and needs alias analysis
    if (opcode == Instruction::Store)
        return can_hoist_store(inst, args);

    // Division by zero can throw
    if (opcode == Instruction::UDiv || opcode == Instruction::URem)
        return can_hoist_udiv(inst, args);

    // Division by zero or of INT_MIN can throw
    if (opcode == Instruction::SDiv || opcode == Instruction::SRem)
        return can_hoist_sdiv(inst, args);

    // No edge cases for this opcode
    return can_hoist_operands(inst, args);
}

/// Hoistability check for signed division/remainder (logic adapted from
/// isSafeToSpeculativelyExecute).
///
/// x / y is undefined if y == 0, or if x == INT_MIN and y == -1. When the constant
/// operands rule both cases out, only the operands need to be hoistable. In every
/// other case the instruction is treated as side-effecting and goes through
/// can_hoist_sidefx().
static bool can_hoist_sdiv(Instruction const& inst, HoistArgs &args)
{
    using namespace llvm::PatternMatch;

    APInt const* divisor = nullptr;
    bool const divisor_is_constant = match(inst.getOperand(1), m_APInt(divisor));

    // Unknown divisor, or a divisor known to be zero: the division may trap.
    if (!divisor_is_constant || *divisor == 0)
        return can_hoist_sidefx(inst, args);

    // A constant divisor other than 0 and -1 can never trap.
    if (!divisor->isAllOnes())
        return can_hoist_operands(inst, args);

    // The divisor is -1: safe only if the dividend is a constant other than INT_MIN.
    // A non-constant dividend *might* be INT_MIN, so it takes the side-effect path too.
    APInt const* dividend = nullptr;
    bool const dividend_is_safe =
        match(inst.getOperand(0), m_APInt(dividend)) && !dividend->isMinSignedValue();

    return dividend_is_safe ? can_hoist_operands(inst, args)
                            : can_hoist_sidefx(inst, args);
}

/// Hoistability check for unsigned division/remainder (logic adapted from
/// isSafeToSpeculativelyExecute).
///
/// x / y is undefined if y == 0. A constant non-zero divisor can never trap, so only
/// the operands need to be hoistable. A zero or non-constant divisor may trap, so the
/// instruction is treated as side-effecting and goes through can_hoist_sidefx(), the
/// same policy can_hoist_sdiv() applies to a non-constant divisor.
static bool can_hoist_udiv(Instruction const& inst, HoistArgs &args)
{
    using namespace llvm::PatternMatch;

    APInt const* divisor = nullptr;
    if (match(inst.getOperand(1), m_APInt(divisor))) {
        if (*divisor == 0) {
            // Known division by zero: only movable if it would have executed anyway.
            return can_hoist_sidefx(inst, args);
        }

        return can_hoist_operands(inst, args);
    }

    // The divisor is not a constant, so it *might* be zero. That is no worse than a
    // divisor known to be zero, so it takes the same path instead of being rejected
    // outright (and silently).
    return can_hoist_sidefx(inst, args);
}

/// First stage of the load check: rejects loads that must not be moved regardless of
/// aliasing, then defers to can_hoist_load_impl() for the alias-based decision.
///
/// Rejected here:
///   - instructions that are not loads (defensive; logged)
///   - volatile and ordered-atomic loads, whose number and order of executions must
///     be preserved
///   - loads for which mustSuppressSpeculation() is true
static bool can_hoist_load(Instruction const& inst, HoistArgs &args)
{
    // code adapted from isSafeToSpeculativelyExecute
    const auto *LI = dyn_cast<LoadInst>(&inst);
    if (!LI) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected load): " << inst << "\n");
        return false;
    }
    // Hoisting would turn one access per iteration into a single access, which is
    // not allowed for volatile or ordered-atomic loads.
    if (!LI->isUnordered()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Load is volatile or ordered atomic:" << inst << "\n");
        return false;
    }
    if (mustSuppressSpeculation(*LI)) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Suppressed by sanitizer:" << inst << "\n");
        return false;
    }

    return can_hoist_load_impl(inst, args);
}

/// Second stage of the load check: consults the loop's alias information for the load.
///
/// A load whose alias information reports a modification (isMod()) is rejected;
/// otherwise the decision is left to the operand check.
///
/// NOTE(review): nothing on this path checks that the pointer is dereferenceable when
/// the load is not guaranteed to execute - confirm that speculatively hoisting such
/// loads is intended.
static bool can_hoist_load_impl(Instruction const& inst, HoistArgs &args)
{
    if (!isa<LoadInst>(inst)) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected load): " << inst << "\n");
        return false;
    }

    auto const& alias_info = args.loop.alias_for(cast<LoadInst>(&inst));

    // Only a load that always reads the same, unmodified memory location can be hoisted
    bool const may_be_written = alias_info.isMod();
    if (may_be_written) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Load aliases modified memory:" << inst << "\n");
        return false;
    }

    return can_hoist_operands(inst, args);
}

/// Hoistability check for stores.
///
/// A store is only considered when all of the following hold:
///   - it is neither volatile nor an ordered atomic
///   - the loop's alias information for it is must-alias (it always writes the same
///     location)
///   - that alias information reports no reads (isRef() is false)
/// Because writing memory is a side effect, a store that passes these checks must
/// additionally satisfy can_hoist_sidefx().
static bool can_hoist_store(Instruction const& inst, HoistArgs &args)
{
    auto const *SI = dyn_cast<StoreInst>(&inst);
    if (!SI) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected store):\t" << inst << "\n");
        return false;
    }

    // Hoisting would turn one write per iteration into a single write, which is not
    // allowed for volatile or ordered-atomic stores.
    if (!SI->isUnordered()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store is volatile or ordered atomic:" << inst << "\n");
        return false;
    }

    auto const& aa = args.loop.alias_for(SI);

    // If the store always modifies the same mem loc and it is never read, we can hoist it

    if (!aa.isMustAlias()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store does not always write the same memory:" << inst << "\n");
        return false;
    }

    if (aa.isRef()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store aliases read memory:" << inst << "\n");
        return false;
    }

    // Writing to memory is side-effectful
    return can_hoist_sidefx(inst, args);
}

/// We can only hoist expressions with side effects if they dominate all exits.
///
/// Checks \p inst against every entry of args.loop.exits using the dominator tree.
/// The first exit that is not dominated rejects the instruction; if all are
/// dominated, the decision is left to the operand check.
static bool can_hoist_sidefx(Instruction const& inst, HoistArgs& args)
{
    for (BasicBlock const* exit : args.loop.exits) {
        if (!args.DT.dominates(&inst, exit)) {
            LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist inst w/ sidefx: Does not dominate all exits:" << inst << "\n");
            return false;
        }
    }

    return can_hoist_operands(inst, args);
}

/// Recursive step of the DFS: \p inst passes only if every one of its operands does.
///
/// Operands are checked in order and the walk stops at the first one that fails, at
/// which point \p inst is printed to the debug stream to finish the diagnostic line
/// started further down the chain.
static bool can_hoist_operands(Instruction const& inst, HoistArgs &args)
{
    for (Use const& operand : inst.operands()) {
        if (!can_hoist(*operand.get(), args)) {
            LLVM_DEBUG(dbgs() << inst << "\n");
            return false;
        }
    }
    return true;
}