#include <cassert>
#include <utility>
#include <functional>
#include <iterator>
#include <vector>
#include <ranges>
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Transforms/UnitLICM.h"
#include "llvm/Analysis/UnitLoopInfo.h"
#define DEBUG_TYPE "unit-licm"
STATISTIC(NumInstsHoisted, "Number of instructions hoisted");
STATISTIC(NumStoresHoisted, "Number of stores hoisted");
STATISTIC(NumLoadsHoisted, "Number of loads hoisted");
STATISTIC(NumComputationalHoisted, "Number of computational instructions (excluding casts and GEPs) hoisted");
STATISTIC(NumVisited, "Number of instructions considered for hoisting");
using namespace llvm;
using namespace ece479k;
/// Since all symbols are exported from the SO, we try to mark all symbols local
/// to this translation unit to improve diagnostics and improve performance.
/// Likewise, we try to keep functions pure.
///
/// To reduce confusion when describing LICM, we classify two types of "invariants":
/// - Invariant: a Value that is already invariant to the loop - it is outside the loop or not an Instruction
/// - Hoistable: a Value that can be hoisted outside the loop to become an invariant
///
/// (Un)Hoistable detection is implemented as DFS on every use-def chain in the loop,
/// stopping if we end up at an invariant or condition that prevents hoisting,
/// propagating the result back up the chain to every Value visited.
///
/// This is implemented as a series of tail-recursive calls to check each
/// instruction for any disqualifying cases or an invariant and then recursively
/// on each of its operands (the part that makes it a DFS).
///
/// Optimization misses are documented in debug output. For meaningful names,
/// make sure to run the `instnamer` pass before running `unit-licm`.
/// Example output:
/// UnitLICM: Can't hoist instruction: Depends on unsupported phi `p.0`: %idxprom = sext i32 %p.0 to i64
/// UnitLICM: Can't hoist instruction: Depends on not-hoistable sext `idxprom`: %arrayidx18 = getelementptr inbounds [8 x [3 x double]], ptr %position, i64 0, i64 %idxprom
/// UnitLICM: Can't hoist instruction: Depends on not-hoistable getelementptr `arrayidx18`: %arraydecay19 = getelementptr inbounds [3 x double], ptr %arrayidx18, i64 0, i64 0
// Opcode classifiers. Both only inspect the opcode of the instruction they are
// given, so declaring them `pure` is accurate.
static bool is_supported(Instruction const& inst) __attribute__((pure));
static bool is_computational(Instruction const& inst) __attribute__((pure));
namespace {
/// State threaded through every step of the hoistability DFS.
///
/// Member order matters: the struct is aggregate-initialised positionally.
struct HoistArgs {
    /// Loop whose instructions are being examined.
    UnitLoopInfo const& loop;
    /// Memoised verdicts per instruction: `true` = hoistable, `false` = not.
    DenseMap<Instruction const*, bool> cache;
    /// Dominator tree used for the "dominates every exit" check.
    DominatorTree const& DT;
};
} // namespace
// Decides whether an instruction can be moved out of `args.loop`.
// Deliberately NOT marked `pure`: the query memoises its verdicts in
// `HoistArgs::cache`, bumps statistics counters and writes debug output. A
// `pure` function must have no such effects, otherwise the compiler is allowed
// to drop or merge calls and to assume the cache is unchanged across them.
static bool can_hoist(Instruction const&, HoistArgs &);
// Moves an instruction in front of the terminator of the loop preheader.
static void hoist(Instruction&, HoistArgs &);
/// Hoists loop-invariant instructions out of every loop that UnitLoopAnalysis
/// reports for `F`.
///
/// For each loop, every instruction that is supported and passes `can_hoist`
/// is moved into the loop preheader, and the hoisting statistics are updated.
///
/// @param F   Function to transform.
/// @param FAM Analysis manager providing UnitLoopAnalysis and DominatorTreeAnalysis.
/// @return `PreservedAnalyses::all()` when no instruction was moved, otherwise
///         `PreservedAnalyses::none()`.
// ReSharper disable once CppMemberFunctionMayBeStatic
PreservedAnalyses UnitLICM::run(Function &F, FunctionAnalysisManager &FAM) // NOLINT(*-convert-member-functions-to-static)
{
    bool changed = false;
    for (auto const& loop: FAM.getResult<UnitLoopAnalysis>(F)) {
        // A fresh (and therefore already empty) memo table per loop: a verdict
        // is only meaningful relative to the loop it was computed for.
        auto hoist_args = HoistArgs{
            loop,
            DenseMap<Instruction const*, bool>(32),
            FAM.getResult<DominatorTreeAnalysis>(F)
        };
        auto const is_hoistable = [&hoist_args](Instruction const& inst) {
            return is_supported(inst) && can_hoist(inst, hoist_args);
        };
        // Hoist everything we can
        for (auto i: loop.instructions(is_hoistable))
        {
            hoist(*i, hoist_args);
            if (isa<LoadInst>(i))
                ++NumLoadsHoisted;
            else if (isa<StoreInst>(i))
                ++NumStoresHoisted;
            else if (is_computational(*i))
                ++NumComputationalHoisted;
            ++NumInstsHoisted;
            changed = true;
        }
    }
    // Moving instructions invalidates analyses computed on the old IR; if
    // nothing moved, the IR is untouched and everything is still valid.
    return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
/// Returns true when the instruction's opcode is one this pass is prepared to
/// reason about; every other opcode is reported as unsupported.
///
/// Written as one flat switch on purpose rather than a chain of isXXX() calls:
/// the full list of accepted opcodes is visible at a glance and the codegen is
/// marginally better (not that it matters in practice).
static bool is_supported(Instruction const& inst)
{
    switch (inst.getOpcode()) {
    // Unary arithmetic
    case Instruction::FNeg:
    // Binary arithmetic
    case Instruction::Add:  case Instruction::FAdd:
    case Instruction::Sub:  case Instruction::FSub:
    case Instruction::Mul:  case Instruction::FMul:
    case Instruction::UDiv: case Instruction::SDiv: case Instruction::FDiv:
    case Instruction::URem: case Instruction::SRem: case Instruction::FRem:
    // Shifts and bitwise logic
    case Instruction::Shl:  case Instruction::LShr: case Instruction::AShr:
    case Instruction::And:  case Instruction::Or:   case Instruction::Xor:
    // Memory access and addressing
    case Instruction::Load:
    case Instruction::Store:
    case Instruction::GetElementPtr:
    // Casts
    case Instruction::Trunc:    case Instruction::ZExt:     case Instruction::SExt:
    case Instruction::FPToUI:   case Instruction::FPToSI:
    case Instruction::UIToFP:   case Instruction::SIToFP:
    case Instruction::FPTrunc:  case Instruction::FPExt:
    case Instruction::PtrToInt: case Instruction::IntToPtr:
    case Instruction::BitCast:  case Instruction::AddrSpaceCast:
    // Selection and comparison
    case Instruction::Select:
    case Instruction::ICmp:
    case Instruction::FCmp:
        return true;
    default:
        return false;
    }
}
/// Returns true for the "computational" opcodes: arithmetic, shifts, bitwise
/// logic, select and comparisons. Loads, stores, GEPs, casts and every other
/// opcode yield false. Used only to pick the statistics counter to bump.
static bool is_computational(Instruction const& inst)
{
    switch (inst.getOpcode()) {
    // Unary arithmetic
    case Instruction::FNeg:
    // Binary arithmetic
    case Instruction::Add:  case Instruction::FAdd:
    case Instruction::Sub:  case Instruction::FSub:
    case Instruction::Mul:  case Instruction::FMul:
    case Instruction::UDiv: case Instruction::SDiv: case Instruction::FDiv:
    case Instruction::URem: case Instruction::SRem: case Instruction::FRem:
    // Shifts and bitwise logic
    case Instruction::Shl:  case Instruction::LShr: case Instruction::AShr:
    case Instruction::And:  case Instruction::Or:   case Instruction::Xor:
    // Selection and comparison
    case Instruction::Select:
    case Instruction::ICmp:
    case Instruction::FCmp:
        return true;
    default:
        return false;
    }
}
/// Moves `inst` out of the loop by placing it immediately before the
/// terminator of the loop's preheader.
///
/// In assert-enabled builds this checks that `inst` was not invariant before
/// the move and that the loop reports it as invariant afterwards.
static void hoist(Instruction& inst, HoistArgs &args)
{
    auto const& loop = args.loop;
    assert(!loop.is_invariant(inst) && "Cannot hoist variant");
    // The terminator is reached through const loop info, but moveBefore needs
    // a mutable instruction pointer - hence the const_cast.
    auto const* preheader_terminator = loop.preheader->getTerminator();
    inst.moveBefore(const_cast<Instruction *>(preheader_terminator));
    assert(loop.is_invariant(inst) && "Hoist did not make invariant");
}
// These are called in a tail-recursive (ish) manner.
//
// None of them is marked `pure`: each one can (directly or through the
// functions it dispatches to) insert into `HoistArgs::cache`, increment a
// statistics counter or write to the debug stream. Declaring such functions
// `pure` would permit the compiler to elide or merge calls and to assume the
// cache is unchanged across them.
static bool can_hoist_impl(Instruction const&, HoistArgs &);
static bool can_hoist(Value const&, HoistArgs &);
static bool can_hoist_udiv(Instruction const&, HoistArgs &);
static bool can_hoist_sdiv(Instruction const&, HoistArgs &);
static bool can_hoist_store(Instruction const&, HoistArgs &);
static bool can_hoist_load(Instruction const&, HoistArgs &);
static bool can_hoist_load_impl(Instruction const&, HoistArgs &);
static bool can_hoist_sidefx(Instruction const&, HoistArgs&);
static bool can_hoist_operands(Instruction const&, HoistArgs &);
/// Entry point of the hoistability query for a single instruction.
///
/// Order of checks:
///  1. already invariant to the loop -> true;
///  2. unsupported opcode            -> false;
///  3. verdict already memoised      -> the cached verdict;
///  4. otherwise run the DFS step and memoise its result.
///
/// The debug messages emitted here deliberately end without a newline: the
/// operand walk that called us appends the dependent instruction.
static bool can_hoist(Instruction const& inst, HoistArgs &args)
{
    if (args.loop.is_invariant(inst))
        return true;

    if (!is_supported(inst)) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist instruction: Depends on unsupported " << inst.getOpcodeName() << " `" << inst.getName() << "`: ");
        return false;
    }

    if (auto const cached = args.cache.find(&inst); cached != args.cache.end()) {
        bool const verdict = cached->second;
        LLVM_DEBUG(if (!verdict)
            dbgs() << "UnitLICM: Can't hoist instruction: Depends on not-hoistable " << inst.getOpcodeName() << " `" << inst.getName() << "`: ");
        return verdict;
    }

    // The lookup iterator is not reused: the recursive step below may insert
    // into the cache before we record our own verdict.
    bool const verdict = can_hoist_impl(inst, args);
    args.cache.insert({&inst, verdict});
    return verdict;
}
/// Value-level overload used while walking operands.
///
/// A value that is not an Instruction is treated as invariant and never blocks
/// hoisting; an Instruction is forwarded to the instruction overload.
static bool can_hoist(Value const& val, HoistArgs &args)
{
    auto const* as_inst = dyn_cast<Instruction>(&val);
    return as_inst == nullptr || can_hoist(*as_inst, args);
}
/// One DFS step: effectively a specialisation of isSafeToSpeculativelyExecute
/// tailored to this pass:
/// - stores are supported;
/// - an instruction with side effects is not rejected outright - it is still
///   hoisted when it dominates every loop exit;
/// - opcodes the pass does not support are never looked at here.
///
/// Opcodes that may have side effects, or that need alias information, are
/// routed to their dedicated check. Everything else goes straight to the
/// operand walk, which is the recursive part of the DFS.
static bool can_hoist_impl(Instruction const& inst, HoistArgs &args)
{
    assert(is_supported(inst) && "Instruction is not supported");
    ++NumVisited;

    switch (inst.getOpcode()) {
    // Signed division/remainder: zero divisor, or INT_MIN by -1, is undefined
    case Instruction::SDiv:
    case Instruction::SRem:
        return can_hoist_sdiv(inst, args);
    // Unsigned division/remainder: zero divisor is undefined
    case Instruction::UDiv:
    case Instruction::URem:
        return can_hoist_udiv(inst, args);
    // Always side-effectful, and needs alias analysis
    case Instruction::Store:
        return can_hoist_store(inst, args);
    // Needs alias analysis
    case Instruction::Load:
        return can_hoist_load(inst, args);
    // No special cases for the remaining opcodes
    default:
        return can_hoist_operands(inst, args);
    }
}
/// Hoistability of a signed division / remainder (logic adapted from
/// isSafeToSpeculativelyExecute).
///
/// `x / y` is undefined when `y == 0`, or when `x == INT_MIN` and `y == -1`.
/// If constants prove neither can happen, only the operands need checking.
/// Otherwise the instruction is treated as side-effectful and must dominate
/// every loop exit.
static bool can_hoist_sdiv(Instruction const& inst, HoistArgs &args)
{
    using namespace llvm::PatternMatch;

    const APInt *divisor = nullptr;
    bool const divisor_is_constant = match(inst.getOperand(1), m_APInt(divisor));

    // Unknown divisor, or a literal zero: executing it may be undefined.
    if (!divisor_is_constant || *divisor == 0)
        return can_hoist_sidefx(inst, args);

    // Any constant divisor other than 0 and -1 is always safe.
    if (!divisor->isAllOnes())
        return can_hoist_operands(inst, args);

    // Divisor is -1: safe only if the dividend is a constant other than INT_MIN.
    const APInt *dividend = nullptr;
    if (match(inst.getOperand(0), m_APInt(dividend)) && !dividend->isMinSignedValue())
        return can_hoist_operands(inst, args);

    // The dividend is INT_MIN, or might be.
    return can_hoist_sidefx(inst, args);
}
/// Hoistability of an unsigned division / remainder (logic adapted from
/// isSafeToSpeculativelyExecute).
///
/// `x / y` is undefined when `y == 0`. With a constant non-zero divisor only
/// the operands need checking. With a zero or non-constant divisor the
/// instruction is treated as side-effectful and must dominate every loop exit.
/// This matches how `can_hoist_sdiv` treats an unknown divisor; previously the
/// unknown-divisor case was rejected outright and without any diagnostic.
static bool can_hoist_udiv(Instruction const& inst, HoistArgs &args)
{
    using namespace llvm::PatternMatch;

    const APInt *divisor = nullptr;
    if (match(inst.getOperand(1), m_APInt(divisor)) && *divisor != 0)
        return can_hoist_operands(inst, args);

    // Divisor is zero or not known at compile time.
    return can_hoist_sidefx(inst, args);
}
/// First stage of the load check (logic adapted from
/// isSafeToSpeculativelyExecute): rejects loads that must never be moved,
/// then defers to the alias-based check.
///
/// Rejected here:
/// - anything that is not actually a LoadInst;
/// - volatile or ordered-atomic loads, since moving them changes the number
///   and ordering of observable accesses;
/// - loads for which `mustSuppressSpeculation` returns true.
static bool can_hoist_load(Instruction const& inst, HoistArgs &args)
{
    const auto *LI = dyn_cast<LoadInst>(&inst);
    if (!LI) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected load): " << inst << "\n");
        return false;
    }
    // isUnordered() is false for volatile loads and for atomics stronger than
    // `unordered`.
    if (!LI->isUnordered()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Load is volatile or ordered atomic:" << inst << "\n");
        return false;
    }
    if (mustSuppressSpeculation(*LI)) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Suppressed by sanitizer:" << inst << "\n");
        return false;
    }
    return can_hoist_load_impl(inst, args);
}
/// Second stage of the load check: consults the loop's alias information for
/// the load and rejects it when that information reports a modification.
/// Otherwise the verdict is that of the load's operands.
///
/// NOTE(review): nothing on this path checks that the pointer is
/// dereferenceable or that the load dominates the loop exits - confirm that
/// speculating the load in the preheader cannot fault.
static bool can_hoist_load_impl(Instruction const& inst, HoistArgs &args)
{
    if (auto const *load = dyn_cast<LoadInst>(&inst)) {
        auto const& alias_info = args.loop.alias_for(load);
        // The loaded location may be written inside the loop, so its value can
        // differ between iterations.
        if (alias_info.isMod()) {
            LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist load: Load aliases modified memory:" << inst << "\n");
            return false;
        }
        return can_hoist_operands(inst, args);
    }

    LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected load): " << inst << "\n");
    return false;
}
/// Hoistability of a store.
///
/// A store is rejected when:
/// - it is not actually a StoreInst;
/// - it is volatile or an ordered atomic (moving it changes the number and
///   ordering of observable accesses);
/// - the loop's alias information for it is not must-alias;
/// - the loop's alias information for it reports a read.
/// A store that survives is still side-effectful, so it must additionally
/// dominate every loop exit.
static bool can_hoist_store(Instruction const& inst, HoistArgs &args)
{
    auto const *SI = dyn_cast<StoreInst>(&inst);
    if (!SI) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Invalid instruction (expected store):\t" << inst << "\n");
        return false;
    }
    // isUnordered() is false for volatile stores and for atomics stronger than
    // `unordered`. Previously only volatile was rejected, and silently.
    if (!SI->isUnordered()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store is volatile or ordered atomic:" << inst << "\n");
        return false;
    }
    auto const& aa = args.loop.alias_for(SI);
    // If the store always modifies the same mem loc and it is never read, we can hoist it
    if (!aa.isMustAlias()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store does not always write the same memory:" << inst << "\n");
        return false;
    }
    if (aa.isRef()) {
        LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist store: Store aliases read memory:" << inst << "\n");
        return false;
    }
    // Writing to memory is side-effectful
    return can_hoist_sidefx(inst, args);
}
/// We can only hoist expressions with side effects if they dominate all exits.
///
/// Returns false (with a debug message) as soon as one entry of
/// `args.loop.exits` is not dominated by `inst`; otherwise the verdict is that
/// of the instruction's operands.
static bool can_hoist_sidefx(Instruction const& inst, HoistArgs& args)
{
    auto const dominated_by_inst = [&inst, &args](BasicBlock const* exit) {
        return args.DT.dominates(&inst, exit);
    };

    if (all_of(args.loop.exits, dominated_by_inst))
        return can_hoist_operands(inst, args);

    LLVM_DEBUG(dbgs() << "UnitLICM: Can't hoist inst w/ sidefx: Does not dominate all exits:" << inst << "\n");
    return false;
}
/// Recursive step of the DFS: `inst` is hoistable only if every one of its
/// operands is. Stops at the first operand that is not.
///
/// On failure the instruction itself is printed, completing the debug line the
/// failing operand's check started.
static bool can_hoist_operands(Instruction const& inst, HoistArgs &args)
{
    for (Use const& operand : inst.operands()) {
        if (!can_hoist(*operand.get(), args)) {
            LLVM_DEBUG(dbgs() << inst << "\n");
            return false;
        }
    }
    return true;
}