#include "llvm/Transforms/Utils/BypassSlowDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
// Make LLVM names (Value, Instruction, IRBuilder, ...) usable unqualified in
// this file.
using namespace llvm;
// Tag identifying this file.
// NOTE(review): presumably consumed by LLVM's debug-output macros; nothing in
// this file references it directly - confirm.
#define DEBUG_TYPE "bypass-slow-division"
namespace {
/// Bundles the quotient and the remainder of one division so that both can be
/// cached and returned as a unit.
struct QuotRemPair {
  /// Value holding the quotient.
  Value *Quotient;
  /// Value holding the remainder.
  Value *Remainder;

  /// Stores the two values as given; neither is inspected here.
  QuotRemPair(Value *Q, Value *R) : Quotient(Q), Remainder(R) {}
};
/// A quotient/remainder pair plus the basic block it is associated with.
/// createDivRemPhiNodes uses BB as the incoming block for both values when it
/// builds the merging PHI nodes. All fields default to null and are filled in
/// by the code that builds the record.
struct QuotRemWithBB {
// Block from which Quotient and Remainder flow into the PHI nodes.
BasicBlock *BB = nullptr;
// Quotient value available at the end of BB.
Value *Quotient = nullptr;
// Remainder value available at the end of BB.
Value *Remainder = nullptr;
};
// Cache of already-built replacements. getReplacement keys it with a
// DivRemMapKey made from (signedness, dividend, divisor).
// NOTE(review): DivRemMapKey is declared outside this file - see its header
// for the exact hashing/equality rules.
using DivCacheTy = DenseMap<DivRemMapKey, QuotRemPair>;
// Maps the bit width of a slow div/rem to the bit width of the narrower type
// it may be bypassed with (looked up in the FastDivInsertionTask constructor).
using BypassWidthsTy = DenseMap<unsigned, unsigned>;
// Instructions already seen while isHashLikeValue walks through PHI nodes.
using VisitedSetTy = SmallPtrSet<Instruction *, 4>;
/// Classification of an operand relative to the bypass width, as returned by
/// getValueRange.
enum ValueRange {
// Known-bits analysis proves that all bits above the bypass width are zero.
VALRNG_KNOWN_SHORT,
// Neither of the other two classifications could be established.
VALRNG_UNKNOWN,
// Either the maximum possible number of leading zeros is smaller than the
// number of high bits that would have to be zero, or isHashLikeValue flagged
// the operand.
VALRNG_LIKELY_LONG
};
/// Encapsulates the work needed for a single div/rem instruction: deciding
/// whether it is eligible (constructor) and producing a replacement value for
/// it (getReplacement).
class FastDivInsertionTask {
  /// True only after the constructor accepted the instruction.
  bool IsValidTask = false;
  /// The udiv/sdiv/urem/srem instruction being processed.
  Instruction *SlowDivOrRem = nullptr;
  /// Narrow integer type in which the fast division is carried out.
  IntegerType *BypassType = nullptr;
  /// Block that contained the instruction when the task was constructed.
  BasicBlock *MainBB = nullptr;

  bool isHashLikeValue(Value *V, VisitedSetTy &Visited);
  ValueRange getValueRange(Value *Op, VisitedSetTy &Visited);
  QuotRemWithBB createSlowBB(BasicBlock *Successor);
  QuotRemWithBB createFastBB(BasicBlock *Successor);
  QuotRemPair createDivRemPhiNodes(QuotRemWithBB &LHS, QuotRemWithBB &RHS,
                                   BasicBlock *PhiBB);
  Value *insertOperandRuntimeCheck(Value *Op1, Value *Op2);
  Optional<QuotRemPair> insertFastDivAndRem();

  /// True for sdiv and srem.
  bool isSignedOp() {
    const unsigned Opcode = SlowDivOrRem->getOpcode();
    return Opcode == Instruction::SDiv || Opcode == Instruction::SRem;
  }

  /// True for sdiv and udiv, i.e. when the quotient is the wanted result.
  bool isDivisionOp() {
    const unsigned Opcode = SlowDivOrRem->getOpcode();
    return Opcode == Instruction::SDiv || Opcode == Instruction::UDiv;
  }

  /// Result type of the original (wide) instruction.
  Type *getSlowType() { return SlowDivOrRem->getType(); }

public:
  FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths);
  Value *getReplacement(DivCacheTy &Cache);
};
}
/// Checks whether \p I can be bypassed and, if so, records what is needed to
/// do it.
///
/// The task is marked valid only when \p I is a udiv/sdiv/urem/srem whose
/// result is a scalar integer type and whose bit width has an entry in
/// \p BypassWidths. In every other case the task stays invalid and
/// getReplacement will return null.
FastDivInsertionTask::FastDivInsertionTask(Instruction *I,
                                           const BypassWidthsTy &BypassWidths) {
  const unsigned Opcode = I->getOpcode();
  const bool IsDivOrRem =
      Opcode == Instruction::UDiv || Opcode == Instruction::SDiv ||
      Opcode == Instruction::URem || Opcode == Instruction::SRem;
  if (!IsDivOrRem)
    return;
  SlowDivOrRem = I;

  // A non-integer result type (dyn_cast fails) is not handled.
  auto *WideTy = dyn_cast<IntegerType>(SlowDivOrRem->getType());
  if (WideTy == nullptr)
    return;

  // Only widths explicitly listed by the caller are bypassed.
  const auto WidthIt = BypassWidths.find(WideTy->getBitWidth());
  if (WidthIt == BypassWidths.end())
    return;

  BypassType = IntegerType::get(I->getContext(), WidthIt->second);
  MainBB = I->getParent();
  IsValidTask = true;
}
/// Returns the value that should replace the div/rem instruction, or null if
/// the task is invalid or no fast replacement could be built.
///
/// \param Cache Replacements already built, keyed by (signedness, dividend,
///              divisor). A hit is reused; a miss triggers
///              insertFastDivAndRem and the result is stored.
Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) {
  if (!IsValidTask)
    return nullptr;

  Value *const Dividend = SlowDivOrRem->getOperand(0);
  Value *const Divisor = SlowDivOrRem->getOperand(1);
  DivRemMapKey Key(isSignedOp(), Dividend, Divisor);

  auto Entry = Cache.find(Key);
  if (Entry == Cache.end()) {
    // Nothing cached for these operands yet: try to build the pair now.
    Optional<QuotRemPair> Fresh = insertFastDivAndRem();
    if (!Fresh)
      return nullptr;
    Entry = Cache.insert({Key, *Fresh}).first;
  }

  // The cache always holds both results; hand back the one this opcode needs.
  QuotRemPair &Pair = Entry->second;
  if (isDivisionOp())
    return Pair.Quotient;
  return Pair.Remainder;
}
/// Heuristic: does \p V look like a value that is unlikely to fit in
/// BypassType?
///
/// Recognised shapes:
///  - any xor;
///  - a mul whose second operand is a constant (directly, or wrapped in a
///    bitcast instruction) that needs more signed bits than BypassType has;
///  - a PHI all of whose incoming values are either classified as
///    VALRNG_LIKELY_LONG by getValueRange or are undef.
///
/// \param Visited PHI nodes already entered during this walk; also bounds the
///                walk (it gives up once 16 entries have been recorded).
bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) {
  auto *Inst = dyn_cast<Instruction>(V);
  if (Inst == nullptr)
    return false;

  const unsigned Opcode = Inst->getOpcode();
  if (Opcode == Instruction::Xor)
    return true;

  if (Opcode == Instruction::Mul) {
    Value *RHS = Inst->getOperand(1);
    auto *Const = dyn_cast<ConstantInt>(RHS);
    // The constant may be hidden behind a bitcast instruction; look through.
    if (Const == nullptr && isa<BitCastInst>(RHS))
      Const = dyn_cast<ConstantInt>(cast<BitCastInst>(RHS)->getOperand(0));
    if (Const == nullptr)
      return false;
    return Const->getValue().getMinSignedBits() > BypassType->getBitWidth();
  }

  if (Opcode != Instruction::PHI)
    return false;

  // Give up on very large PHI webs.
  if (Visited.size() >= 16)
    return false;
  // A PHI reached a second time contributes no counter-evidence.
  if (!Visited.insert(Inst).second)
    return true;

  for (Value *Incoming : cast<PHINode>(Inst)->incoming_values()) {
    // getValueRange is evaluated first for every incoming value, undef or not.
    const bool LikelyLong =
        getValueRange(Incoming, Visited) == VALRNG_LIKELY_LONG;
    if (!LikelyLong && !isa<UndefValue>(Incoming))
      return false;
  }
  return true;
}
/// Classifies \p V by whether it fits in BypassType.
///
/// \param V       Integer value whose type is wider than BypassType.
/// \param Visited Forwarded to isHashLikeValue to track PHIs already walked.
/// \returns VALRNG_KNOWN_SHORT when known-bits analysis proves the high bits
///          are zero, VALRNG_LIKELY_LONG when it rules that out or the value
///          looks hash-like, and VALRNG_UNKNOWN otherwise.
ValueRange FastDivInsertionTask::getValueRange(Value *V,
                                               VisitedSetTy &Visited) {
  const unsigned ShortLen = BypassType->getBitWidth();
  const unsigned LongLen = V->getType()->getIntegerBitWidth();
  assert(LongLen > ShortLen && "Value type must be wider than BypassType");

  // Number of top bits that must all be zero for V to fit in BypassType.
  const unsigned RequiredZeros = LongLen - ShortLen;

  KnownBits Bits(LongLen);
  computeKnownBits(V, Bits, SlowDivOrRem->getModule()->getDataLayout());

  if (Bits.countMinLeadingZeros() >= RequiredZeros)
    return VALRNG_KNOWN_SHORT;

  // The hash-likeness heuristic is consulted only if known bits are
  // inconclusive.
  const bool TooFewLeadingZeros = Bits.countMaxLeadingZeros() < RequiredZeros;
  if (TooFewLeadingZeros || isHashLikeValue(V, Visited))
    return VALRNG_LIKELY_LONG;

  return VALRNG_UNKNOWN;
}
/// Creates a new block, placed before \p SuccessorBB in the function, that
/// performs the division and the remainder at the original width and then
/// branches to \p SuccessorBB.
///
/// \returns The new block together with the quotient and remainder it
///          computes.
QuotRemWithBB FastDivInsertionTask::createSlowBB(BasicBlock *SuccessorBB) {
  auto *ParentFn = MainBB->getParent();

  QuotRemWithBB Slow;
  Slow.BB =
      BasicBlock::Create(ParentFn->getContext(), "", ParentFn, SuccessorBB);

  IRBuilder<> B(Slow.BB, Slow.BB->begin());
  B.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc());

  Value *const LHS = SlowDivOrRem->getOperand(0);
  Value *const RHS = SlowDivOrRem->getOperand(1);

  // Quotient is emitted before remainder, with signedness matching the
  // original instruction.
  const bool Signed = isSignedOp();
  Slow.Quotient = Signed ? B.CreateSDiv(LHS, RHS) : B.CreateUDiv(LHS, RHS);
  Slow.Remainder = Signed ? B.CreateSRem(LHS, RHS) : B.CreateURem(LHS, RHS);

  B.CreateBr(SuccessorBB);
  return Slow;
}
/// Creates a new block, placed before \p SuccessorBB in the function, that
/// truncates both operands to BypassType, performs an unsigned division and
/// remainder at that width, zero-extends the results back to the original
/// type, and branches to \p SuccessorBB.
///
/// \returns The new block together with the widened quotient and remainder.
QuotRemWithBB FastDivInsertionTask::createFastBB(BasicBlock *SuccessorBB) {
  auto *ParentFn = MainBB->getParent();

  QuotRemWithBB Fast;
  Fast.BB =
      BasicBlock::Create(ParentFn->getContext(), "", ParentFn, SuccessorBB);

  IRBuilder<> B(Fast.BB, Fast.BB->begin());
  B.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc());

  Value *const WideDividend = SlowDivOrRem->getOperand(0);
  Value *const WideDivisor = SlowDivOrRem->getOperand(1);

  // Divisor is narrowed first, then the dividend.
  Value *NarrowDivisor = B.CreateTrunc(WideDivisor, BypassType);
  Value *NarrowDividend = B.CreateTrunc(WideDividend, BypassType);

  // The narrow operations are always unsigned, whatever the original opcode.
  Value *NarrowQuot = B.CreateUDiv(NarrowDividend, NarrowDivisor);
  Value *NarrowRem = B.CreateURem(NarrowDividend, NarrowDivisor);

  Type *WideTy = getSlowType();
  Fast.Quotient = B.CreateZExt(NarrowQuot, WideTy);
  Fast.Remainder = B.CreateZExt(NarrowRem, WideTy);

  B.CreateBr(SuccessorBB);
  return Fast;
}
/// Inserts two PHI nodes at the start of \p PhiBB that merge the quotients and
/// the remainders arriving from \p LHS and \p RHS.
///
/// \param LHS   First incoming block and its values.
/// \param RHS   Second incoming block and its values.
/// \param PhiBB Block that receives the PHI nodes.
/// \returns The quotient PHI and the remainder PHI.
QuotRemPair FastDivInsertionTask::createDivRemPhiNodes(QuotRemWithBB &LHS,
                                                       QuotRemWithBB &RHS,
                                                       BasicBlock *PhiBB) {
  IRBuilder<> B(PhiBB, PhiBB->begin());
  B.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc());

  // Builds one two-way PHI; the LHS edge is always registered first.
  auto MergeFromBoth = [&](Value *FromLHS, Value *FromRHS) {
    PHINode *Phi = B.CreatePHI(getSlowType(), 2);
    Phi->addIncoming(FromLHS, LHS.BB);
    Phi->addIncoming(FromRHS, RHS.BB);
    return Phi;
  };

  PHINode *QuotientPhi = MergeFromBoth(LHS.Quotient, RHS.Quotient);
  PHINode *RemainderPhi = MergeFromBoth(LHS.Remainder, RHS.Remainder);
  return QuotRemPair(QuotientPhi, RemainderPhi);
}
/// Emits, at the end of MainBB, a runtime test that the supplied operand(s)
/// fit in BypassType.
///
/// \param Op1 First operand to test, or null to leave it out.
/// \param Op2 Second operand to test, or null to leave it out. At least one
///            of the two must be non-null.
/// \returns The result of comparing the masked high bits against zero; it is
///          true when no bit above the bypass width is set in any supplied
///          operand.
Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) {
  assert((Op1 || Op2) && "Nothing to check");
  IRBuilder<> Builder(MainBB, MainBB->end());
  Builder.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc());

  // One test covers both operands: OR them so that a high bit set in either
  // one survives into the masked value.
  Value *OrV;
  if (Op1 && Op2)
    OrV = Builder.CreateOr(Op1, Op2);
  else
    OrV = Op1 ? Op1 : Op2;

  // Build the mask at the full width of the slow type. A 64-bit host integer
  // mask would be zero-extended to the slow type, so for types wider than 64
  // bits the topmost bits would silently escape the check (and with a 64-bit
  // bypass type the mask would collapse to zero). For widths of 64 bits or
  // less this yields exactly the same constant as inverting the bypass mask.
  const unsigned LongLen = getSlowType()->getIntegerBitWidth();
  const unsigned ShortLen = BypassType->getBitWidth();
  assert(LongLen > ShortLen && "Value type must be wider than BypassType");
  const APInt HighBitsMask = APInt::getHighBitsSet(LongLen, LongLen - ShortLen);
  Value *AndV = Builder.CreateAnd(OrV, HighBitsMask);

  // The operands are short exactly when every masked bit is zero.
  Value *ZeroV = ConstantInt::getSigned(getSlowType(), 0);
  return Builder.CreateICmpEQ(AndV, ZeroV);
}
/// Tries to build a cheaper replacement for SlowDivOrRem and returns the
/// resulting quotient/remainder pair, or None when no replacement is made.
///
/// Three shapes of replacement are produced:
///  1. Both operands known short: a narrow udiv/urem is emitted in place, with
///     no new control flow.
///  2. Unsigned operation with a known-short dividend: a runtime compare of
///     dividend against divisor selects between a narrow division and the
///     constant answer (quotient 0, remainder = dividend).
///  3. Otherwise: a runtime check on the operands' high bits selects between a
///     narrow (fast) block and a full-width (slow) block.
/// In cases 2 and 3 MainBB is split before SlowDivOrRem and the results are
/// merged by PHI nodes at the start of the successor block.
Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() {
Value *Dividend = SlowDivOrRem->getOperand(0);
Value *Divisor = SlowDivOrRem->getOperand(1);
// Classify each operand independently; each gets its own visited set so the
// PHI walk for one does not affect the other. Give up as soon as either one
// is likely to need the full width.
VisitedSetTy SetL;
ValueRange DividendRange = getValueRange(Dividend, SetL);
if (DividendRange == VALRNG_LIKELY_LONG)
return None;
VisitedSetTy SetR;
ValueRange DivisorRange = getValueRange(Divisor, SetR);
if (DivisorRange == VALRNG_LIKELY_LONG)
return None;
bool DividendShort = (DividendRange == VALRNG_KNOWN_SHORT);
bool DivisorShort = (DivisorRange == VALRNG_KNOWN_SHORT);
if (DividendShort && DivisorShort) {
// Case 1: no runtime check is needed. Emit the narrow operations right
// before the original instruction and widen the results back.
IRBuilder<> Builder(SlowDivOrRem);
Value *TruncDividend = Builder.CreateTrunc(Dividend, BypassType);
Value *TruncDivisor = Builder.CreateTrunc(Divisor, BypassType);
Value *TruncDiv = Builder.CreateUDiv(TruncDividend, TruncDivisor);
Value *TruncRem = Builder.CreateURem(TruncDividend, TruncDivisor);
Value *ExtDiv = Builder.CreateZExt(TruncDiv, getSlowType());
Value *ExtRem = Builder.CreateZExt(TruncRem, getSlowType());
return QuotRemPair(ExtDiv, ExtRem);
}
// From here on new control flow would be introduced. That is not done when
// the divisor is a constant.
// NOTE(review): presumably because division by a constant is lowered to
// cheaper code later anyway - confirm against the code generator.
if (isa<ConstantInt>(Divisor)) {
return None;
}
// Same bail-out for a constant hidden behind a bitcast instruction that lives
// in the same block as the division.
if (auto *BCI = dyn_cast<BitCastInst>(Divisor))
if (BCI->getParent() == SlowDivOrRem->getParent() &&
isa<ConstantInt>(BCI->getOperand(0)))
return None;
// This builder appends to the end of MainBB; it is used below to emit the
// compare and the conditional branch after the block has been split.
IRBuilder<> Builder(MainBB, MainBB->end());
Builder.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc());
if (DividendShort && !isSignedOp()) {
// Case 2: unsigned with a short dividend. If Dividend >= Divisor then the
// divisor is no larger than a short value, so the narrow division is valid.
// Otherwise the quotient is 0 and the remainder is the dividend itself, so
// no division is needed on that path.
// Split so that SlowDivOrRem starts the successor block.
BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem);
// Drop the last instruction of MainBB (the branch left behind by the split)
// so that a conditional branch can be appended instead.
MainBB->getInstList().back().eraseFromParent();
// The "long" path has no block of its own: its values flow straight from
// MainBB into the PHI nodes.
QuotRemWithBB Long;
Long.BB = MainBB;
Long.Quotient = ConstantInt::get(getSlowType(), 0);
Long.Remainder = Dividend;
QuotRemWithBB Fast = createFastBB(SuccessorBB);
QuotRemPair Result = createDivRemPhiNodes(Fast, Long, SuccessorBB);
Value *CmpV = Builder.CreateICmpUGE(Dividend, Divisor);
Builder.CreateCondBr(CmpV, Fast.BB, SuccessorBB);
return Result;
} else {
// Case 3: general case. Build both a fast and a slow block and pick one at
// runtime based on the high bits of whichever operands are not already
// known to be short.
BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem);
// Drop the last instruction of MainBB (the branch left behind by the split).
MainBB->getInstList().back().eraseFromParent();
QuotRemWithBB Fast = createFastBB(SuccessorBB);
QuotRemWithBB Slow = createSlowBB(SuccessorBB);
QuotRemPair Result = createDivRemPhiNodes(Fast, Slow, SuccessorBB);
// Operands already proven short are passed as null so they are left out of
// the runtime check.
Value *CmpV = insertOperandRuntimeCheck(DividendShort ? nullptr : Dividend,
DivisorShort ? nullptr : Divisor);
Builder.CreateCondBr(CmpV, Fast.BB, Slow.BB);
return Result;
}
}
/// Walks the instructions starting at the beginning of \p BB and replaces each
/// eligible div/rem with the value produced by FastDivInsertionTask.
///
/// \param BB           Block to start from.
/// \param BypassWidths Maps the bit width of a slow operation to the narrower
///                     width it may be bypassed with.
/// \returns true if at least one instruction was replaced.
bool llvm::bypassSlowDivision(BasicBlock *BB,
                              const BypassWidthsTy &BypassWidths) {
  DivCacheTy Cache;
  bool Changed = false;

  for (Instruction *Cur = &*BB->begin(), *Following = nullptr; Cur != nullptr;
       Cur = Following) {
    // Fetch the successor up front: Cur may be erased below, and anything
    // inserted next to it should not be revisited.
    Following = Cur->getNextNode();

    // Instructions without users are left alone.
    if (Cur->hasNUses(0))
      continue;

    FastDivInsertionTask Task(Cur, BypassWidths);
    Value *Replacement = Task.getReplacement(Cache);
    if (Replacement == nullptr)
      continue;

    Cur->replaceAllUsesWith(Replacement);
    Cur->eraseFromParent();
    Changed = true;
  }

  // Quotient and remainder are always created together; remove whichever of
  // the two ended up without users.
  for (auto &Entry : Cache) {
    RecursivelyDeleteTriviallyDeadInstructions(Entry.second.Quotient);
    RecursivelyDeleteTriviallyDeadInstructions(Entry.second.Remainder);
  }

  return Changed;
}