#include "llvm/Transforms/Utils/IntegerDivision.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
using namespace llvm;
#define DEBUG_TYPE "integer-division"
/// Emits IR computing the signed remainder of \p Dividend by \p Divisor using
/// an unsigned remainder plus branch-free sign fix-ups.
///
/// Each operand is made non-negative with |x| = (x ^ s) - s, where
/// s = x >>a (BitWidth - 1) is 0 for non-negative x and all-ones otherwise.
/// The result takes the dividend's sign, so the unsigned remainder is
/// conditionally negated with the dividend's sign mask. Only 32- and 64-bit
/// integer types are accepted.
///
/// If the urem is materialized as an instruction, \p Builder's insertion point
/// is moved onto it so the caller can find and expand it. If the urem was
/// constant folded, the insertion point is left unchanged.
///
/// \returns the value holding the signed remainder.
static Value *generateSignedRemainderCode(Value *Dividend, Value *Divisor,
                                          IRBuilder<> &Builder) {
  const unsigned Width = Dividend->getType()->getIntegerBitWidth();
  if (Width != 64)
    assert(BitWidth == 32 && "Unexpected bit width");
  // Shift amount that smears the sign bit across the whole word.
  ConstantInt *SignShift =
      Width == 64 ? Builder.getInt64(63) : Builder.getInt32(31);

  // Sign masks: all-ones for negative inputs, zero otherwise.
  Value *DvdMask = Builder.CreateAShr(Dividend, SignShift);
  Value *DvsMask = Builder.CreateAShr(Divisor, SignShift);

  // Absolute values via (x ^ mask) - mask.
  Value *DvdFlipped = Builder.CreateXor(Dividend, DvdMask);
  Value *DvsFlipped = Builder.CreateXor(Divisor, DvsMask);
  Value *AbsDividend = Builder.CreateSub(DvdFlipped, DvdMask);
  Value *AbsDivisor = Builder.CreateSub(DvsFlipped, DvsMask);

  Value *UnsignedRem = Builder.CreateURem(AbsDividend, AbsDivisor);

  // Re-apply the dividend's sign to the magnitude of the remainder.
  Value *RemFlipped = Builder.CreateXor(UnsignedRem, DvdMask);
  Value *SignedRem = Builder.CreateSub(RemFlipped, DvdMask);

  if (auto *URemInst = dyn_cast<Instruction>(UnsignedRem))
    Builder.SetInsertPoint(URemInst);

  return SignedRem;
}
/// Emits IR computing the unsigned remainder as
/// Dividend - Divisor * (Dividend udiv Divisor). This leaves only a udiv that
/// needs a dedicated expansion.
///
/// If the udiv is materialized as an instruction, \p Builder's insertion point
/// is moved onto it so the caller can find and expand it. If it was constant
/// folded, the insertion point is left unchanged.
///
/// \returns the value holding the unsigned remainder.
static Value *generatedUnsignedRemainderCode(Value *Dividend, Value *Divisor,
                                             IRBuilder<> &Builder) {
  Value *Quot = Builder.CreateUDiv(Dividend, Divisor);
  Value *Rem = Builder.CreateSub(Dividend, Builder.CreateMul(Divisor, Quot));

  auto *QuotInst = dyn_cast<Instruction>(Quot);
  if (QuotInst)
    Builder.SetInsertPoint(QuotInst);
  return Rem;
}
/// Emits IR computing the signed quotient of \p Dividend by \p Divisor via an
/// unsigned division of the operands' absolute values.
///
/// Sign masks s = x >>a (BitWidth - 1) give |x| = (s ^ x) - s. The quotient is
/// negative exactly when the operand signs differ, so the unsigned quotient is
/// conditionally negated with the XOR of both sign masks. Only 32- and 64-bit
/// integer types are accepted.
///
/// If the udiv is materialized as an instruction, \p Builder's insertion point
/// is moved onto it so the caller can find and expand it. If it was constant
/// folded, the insertion point is left unchanged.
///
/// \returns the value holding the signed quotient.
static Value *generateSignedDivisionCode(Value *Dividend, Value *Divisor,
                                         IRBuilder<> &Builder) {
  const unsigned Width = Dividend->getType()->getIntegerBitWidth();
  if (Width != 64)
    assert(BitWidth == 32 && "Unexpected bit width");
  ConstantInt *SignShift =
      Width == 64 ? Builder.getInt64(63) : Builder.getInt32(31);

  Value *DvdSign = Builder.CreateAShr(Dividend, SignShift);
  Value *DvsSign = Builder.CreateAShr(Divisor, SignShift);

  // |x| = (sign ^ x) - sign
  Value *AbsDividend =
      Builder.CreateSub(Builder.CreateXor(DvdSign, Dividend), DvdSign);
  Value *AbsDivisor =
      Builder.CreateSub(Builder.CreateXor(DvsSign, Divisor), DvsSign);

  // All-ones when exactly one operand is negative.
  Value *QuotSign = Builder.CreateXor(DvsSign, DvdSign);
  Value *QuotMag = Builder.CreateUDiv(AbsDividend, AbsDivisor);
  Value *Quotient =
      Builder.CreateSub(Builder.CreateXor(QuotMag, QuotSign), QuotSign);

  if (auto *UDivInst = dyn_cast<Instruction>(QuotMag))
    Builder.SetInsertPoint(UDivInst);

  return Quotient;
}
/// Emits IR computing the unsigned quotient Dividend /u Divisor without any
/// udiv instruction, using a shift-and-subtract loop. Only 32- and 64-bit
/// integer types are accepted.
///
/// The block holding \p Builder's insertion point is split there:
///   * the leading part is renamed with an "_udiv-special-cases" suffix and
///     holds the special-case tests;
///   * the trailing part becomes "udiv-end".
/// The blocks "udiv-loop-exit", "udiv-do-while", "udiv-preheader" and
/// "udiv-bb1" are created in front of "udiv-end". On return, the insertion
/// point is at the start of "udiv-end", where a PHI merges the results.
///
/// Early exits taken from the special-cases block:
///   * quotient 0 if Divisor == 0, Dividend == 0, or
///     SR = ctlz(Divisor) - ctlz(Dividend) is >u MSB. In that last case
///     Divisor has more significant bits than Dividend.
///   * quotient Dividend if SR == MSB, which happens when Divisor == 1 and
///     Dividend has its top bit set.
/// Otherwise the loop produces one quotient bit per iteration.
///
/// \returns the PHI in "udiv-end" that holds the quotient.
static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor,
                                           IRBuilder<> &Builder) {
  // NOTE(review): Dividend and Divisor are each used several times below. If
  // they can be undef or poison, freezing them first may be needed so that
  // every use observes the same value.
  IntegerType *DivTy = cast<IntegerType>(Dividend->getType());
  unsigned BitWidth = DivTy->getBitWidth();

  ConstantInt *Zero;
  ConstantInt *One;
  ConstantInt *MSB; // Index of the most significant bit: BitWidth - 1.
  if (BitWidth == 64) {
    Zero = Builder.getInt64(0);
    One = Builder.getInt64(1);
    MSB = Builder.getInt64(63);
  } else {
    assert(BitWidth == 32 && "Unexpected bit width");
    Zero = Builder.getInt32(0);
    One = Builder.getInt32(1);
    MSB = Builder.getInt32(31);
  }
  ConstantInt *NegOne = ConstantInt::getSigned(DivTy, -1);
  ConstantInt *True = Builder.getTrue();

  BasicBlock *SpecialCases = Builder.GetInsertBlock();
  Function *F = SpecialCases->getParent();
  Function *CTLZ = Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
                                             DivTy);

  // Split at the insertion point. Everything from the insertion point onwards
  // moves to "udiv-end"; the new blocks are placed before it.
  SpecialCases->setName(Twine(SpecialCases->getName(), "_udiv-special-cases"));
  BasicBlock *End = SpecialCases->splitBasicBlock(Builder.GetInsertPoint(),
                                                  "udiv-end");
  BasicBlock *LoopExit = BasicBlock::Create(Builder.getContext(),
                                            "udiv-loop-exit", F, End);
  BasicBlock *DoWhile = BasicBlock::Create(Builder.getContext(),
                                           "udiv-do-while", F, End);
  BasicBlock *Preheader = BasicBlock::Create(Builder.getContext(),
                                             "udiv-preheader", F, End);
  BasicBlock *BB1 = BasicBlock::Create(Builder.getContext(),
                                       "udiv-bb1", F, End);

  // splitBasicBlock added an unconditional branch to End; it is replaced by
  // the conditional branch emitted below.
  SpecialCases->getTerminator()->eraseFromParent();

  // special-cases:
  Builder.SetInsertPoint(SpecialCases);
  Value *Ret0_1 = Builder.CreateICmpEQ(Divisor, Zero);
  Value *Ret0_2 = Builder.CreateICmpEQ(Dividend, Zero);
  Value *Ret0_3 = Builder.CreateOr(Ret0_1, Ret0_2);
  // ctlz is called with is_zero_poison = true. Tmp0/Tmp1, and hence SR,
  // Ret0_4 and RetDividend, are therefore poison when an operand is zero.
  // A plain `or` would push that poison into the branch condition even though
  // Ret0_3 already decides the outcome, and branching on poison is UB. The
  // select-based logical ors below avoid this: `select true, true, X` is
  // true whatever X is.
  Value *Tmp0 = Builder.CreateCall(CTLZ, {Divisor, True});
  Value *Tmp1 = Builder.CreateCall(CTLZ, {Dividend, True});
  Value *SR = Builder.CreateSub(Tmp0, Tmp1);
  Value *Ret0_4 = Builder.CreateICmpUGT(SR, MSB);
  Value *Ret0 = Builder.CreateSelect(Ret0_3, True, Ret0_4);
  Value *RetDividend = Builder.CreateICmpEQ(SR, MSB);
  Value *RetVal = Builder.CreateSelect(Ret0, Zero, Dividend);
  Value *EarlyRet = Builder.CreateSelect(Ret0, True, RetDividend);
  Builder.CreateCondBr(EarlyRet, End, BB1);

  // udiv-bb1: SR_1 = SR + 1 is the loop trip count. Q starts as the dividend
  // shifted left by (MSB - SR).
  Builder.SetInsertPoint(BB1);
  Value *SR_1 = Builder.CreateAdd(SR, One);
  Value *Tmp2 = Builder.CreateSub(MSB, SR);
  Value *Q = Builder.CreateShl(Dividend, Tmp2);
  Value *SkipLoop = Builder.CreateICmpEQ(SR_1, Zero);
  Builder.CreateCondBr(SkipLoop, LoopExit, Preheader);

  // udiv-preheader: the initial partial remainder is Dividend >> SR_1.
  // Tmp4 = Divisor - 1 is loop invariant.
  Builder.SetInsertPoint(Preheader);
  Value *Tmp3 = Builder.CreateLShr(Dividend, SR_1);
  Value *Tmp4 = Builder.CreateAdd(Divisor, NegOne);
  Builder.CreateBr(DoWhile);

  // udiv-do-while: one quotient bit per iteration. The PHI operands are filled
  // in once every incoming value exists (see below).
  Builder.SetInsertPoint(DoWhile);
  PHINode *Carry_1 = Builder.CreatePHI(DivTy, 2);
  PHINode *SR_3 = Builder.CreatePHI(DivTy, 2);
  PHINode *R_1 = Builder.CreatePHI(DivTy, 2);
  PHINode *Q_2 = Builder.CreatePHI(DivTy, 2);
  // Shift the top bit of Q_2 into the partial remainder.
  Value *Tmp5 = Builder.CreateShl(R_1, One);
  Value *Tmp6 = Builder.CreateLShr(Q_2, MSB);
  Value *Tmp7 = Builder.CreateOr(Tmp5, Tmp6);
  // Shift the previous quotient bit (Carry_1) into Q.
  Value *Tmp8 = Builder.CreateShl(Q_2, One);
  Value *Q_1 = Builder.CreateOr(Carry_1, Tmp8);
  // Tmp10 smears the sign bit of (Divisor - 1) - Tmp7 across the word. It
  // decides both the next quotient bit (Carry) and whether Divisor is
  // subtracted from the partial remainder, without branching.
  Value *Tmp9 = Builder.CreateSub(Tmp4, Tmp7);
  Value *Tmp10 = Builder.CreateAShr(Tmp9, MSB);
  Value *Carry = Builder.CreateAnd(Tmp10, One);
  Value *Tmp11 = Builder.CreateAnd(Tmp10, Divisor);
  Value *R = Builder.CreateSub(Tmp7, Tmp11);
  Value *SR_2 = Builder.CreateAdd(SR_3, NegOne);
  Value *Tmp12 = Builder.CreateICmpEQ(SR_2, Zero);
  Builder.CreateCondBr(Tmp12, LoopExit, DoWhile);

  // udiv-loop-exit: shift in the final quotient bit.
  Builder.SetInsertPoint(LoopExit);
  PHINode *Carry_2 = Builder.CreatePHI(DivTy, 2);
  PHINode *Q_3 = Builder.CreatePHI(DivTy, 2);
  Value *Tmp13 = Builder.CreateShl(Q_3, One);
  Value *Q_4 = Builder.CreateOr(Carry_2, Tmp13);
  Builder.CreateBr(End);

  // udiv-end: merge the loop result with the early-exit value.
  Builder.SetInsertPoint(End, End->begin());
  PHINode *Q_5 = Builder.CreatePHI(DivTy, 2);

  // Populate the PHIs now that all incoming values have been created.
  Carry_1->addIncoming(Zero, Preheader);
  Carry_1->addIncoming(Carry, DoWhile);
  SR_3->addIncoming(SR_1, Preheader);
  SR_3->addIncoming(SR_2, DoWhile);
  R_1->addIncoming(Tmp3, Preheader);
  R_1->addIncoming(R, DoWhile);
  Q_2->addIncoming(Q, Preheader);
  Q_2->addIncoming(Q_1, DoWhile);
  Carry_2->addIncoming(Zero, BB1);
  Carry_2->addIncoming(Carry, DoWhile);
  Q_3->addIncoming(Q, BB1);
  Q_3->addIncoming(Q_1, DoWhile);
  Q_5->addIncoming(Q_4, LoopExit);
  Q_5->addIncoming(RetVal, SpecialCases);

  return Q_5;
}
/// Expands a 32- or 64-bit scalar srem/urem into IR that contains no division
/// or remainder instruction.
///
/// The expansion proceeds in three steps:
///   1. srem is rewritten as a urem on absolute values.
///   2. The urem is rewritten as Dividend - Divisor * udiv.
///   3. That udiv is expanded with expandDivision.
/// If the builder constant-folds an intermediate step, the expansion stops
/// there. \p Rem is erased.
///
/// \returns true.
bool llvm::expandRemainder(BinaryOperator *Rem) {
  assert((Rem->getOpcode() == Instruction::SRem ||
          Rem->getOpcode() == Instruction::URem) &&
         "Trying to expand remainder from a non-remainder function");

  IRBuilder<> Builder(Rem);

  assert(!Rem->getType()->isVectorTy() && "Div over vectors not supported");
  assert((Rem->getType()->getIntegerBitWidth() == 32 ||
          Rem->getType()->getIntegerBitWidth() == 64) &&
         "Div of bitwidth other than 32 or 64 not supported");

  if (Rem->getOpcode() == Instruction::SRem) {
    Value *Remainder = generateSignedRemainderCode(Rem->getOperand(0),
                                                   Rem->getOperand(1), Builder);

    // The helper moves the insertion point onto the new urem unless the urem
    // was constant folded. Check this while Rem is still valid.
    bool IsInsertPoint = Rem->getIterator() == Builder.GetInsertPoint();
    Rem->replaceAllUsesWith(Remainder);
    Rem->dropAllReferences();
    Rem->eraseFromParent();

    // No urem instruction was generated, so there is nothing left to expand.
    if (IsInsertPoint)
      return true;

    // The insertion point is the urem created above; expand it next.
    Rem = cast<BinaryOperator>(&*Builder.GetInsertPoint());
  }

  Value *Remainder = generatedUnsignedRemainderCode(Rem->getOperand(0),
                                                    Rem->getOperand(1),
                                                    Builder);

  // If the udiv was constant folded (e.g. both operands are constants), the
  // insertion point still refers to Rem. Rem is erased below, so that
  // iterator must not be dereferenced afterwards. Record the fact first.
  bool UDivFolded = Rem->getIterator() == Builder.GetInsertPoint();
  Rem->replaceAllUsesWith(Remainder);
  Rem->dropAllReferences();
  Rem->eraseFromParent();
  if (UDivFolded)
    return true;

  // Expand the udiv that the insertion point now refers to.
  if (BinaryOperator *UDiv = dyn_cast<BinaryOperator>(Builder.GetInsertPoint())) {
    assert(UDiv->getOpcode() == Instruction::UDiv && "Non-udiv in expansion?");
    expandDivision(UDiv);
  }

  return true;
}
/// Expands a 32- or 64-bit scalar sdiv/udiv into IR that contains no division
/// instruction.
///
/// sdiv is first reduced to a udiv on absolute values. The udiv is then
/// replaced by the loop built by generateUnsignedDivisionCode. If the
/// intermediate udiv is constant folded, the expansion stops after the first
/// step. \p Div is erased.
///
/// \returns true.
bool llvm::expandDivision(BinaryOperator *Div) {
  assert((Div->getOpcode() == Instruction::SDiv ||
          Div->getOpcode() == Instruction::UDiv) &&
         "Trying to expand division from a non-division function");
  assert(!Div->getType()->isVectorTy() && "Div over vectors not supported");
  assert((Div->getType()->getIntegerBitWidth() == 32 ||
          Div->getType()->getIntegerBitWidth() == 64) &&
         "Div of bitwidth other than 32 or 64 not supported");

  IRBuilder<> Builder(Div);

  // Redirects every use of Old to New, then deletes Old.
  auto ReplaceAndErase = [](BinaryOperator *Old, Value *New) {
    Old->replaceAllUsesWith(New);
    Old->dropAllReferences();
    Old->eraseFromParent();
  };

  BinaryOperator *UnsignedDiv = Div;
  if (Div->getOpcode() == Instruction::SDiv) {
    Value *SignedQuot = generateSignedDivisionCode(Div->getOperand(0),
                                                   Div->getOperand(1), Builder);
    // The helper only moves the insertion point when it emitted a real udiv.
    // Compare before Div is erased.
    const bool NoUDivEmitted = Div->getIterator() == Builder.GetInsertPoint();
    ReplaceAndErase(Div, SignedQuot);
    if (NoUDivEmitted)
      return true;
    UnsignedDiv = dyn_cast<BinaryOperator>(Builder.GetInsertPoint());
  }

  Value *UnsignedQuot = generateUnsignedDivisionCode(
      UnsignedDiv->getOperand(0), UnsignedDiv->getOperand(1), Builder);
  ReplaceAndErase(UnsignedDiv, UnsignedQuot);
  return true;
}
/// Expands an srem/urem whose scalar integer type is at most 32 bits wide.
///
/// A 32-bit remainder is passed straight to expandRemainder. A narrower one is
/// handled in three steps:
///   1. Recompute it on operands sign-extended (srem) or zero-extended (urem)
///      to i32.
///   2. Truncate the result back to the original type.
///   3. Expand the widened remainder.
/// \p Rem is erased.
///
/// \returns the result of expandRemainder, or true if the widened remainder was
/// constant folded by the builder and nothing is left to expand.
bool llvm::expandRemainderUpTo32Bits(BinaryOperator *Rem) {
  assert((Rem->getOpcode() == Instruction::SRem ||
          Rem->getOpcode() == Instruction::URem) &&
         "Trying to expand remainder from a non-remainder function");

  Type *RemTy = Rem->getType();
  assert(!RemTy->isVectorTy() && "Div over vectors not supported");

  unsigned RemTyBitWidth = RemTy->getIntegerBitWidth();
  assert(RemTyBitWidth <= 32 &&
         "Div of bitwidth greater than 32 not supported");

  if (RemTyBitWidth == 32)
    return expandRemainder(Rem);

  // Narrower than 32 bits: widen the operands, compute in i32, truncate back.
  IRBuilder<> Builder(Rem);
  Type *Int32Ty = Builder.getInt32Ty();

  Value *ExtRem;
  if (Rem->getOpcode() == Instruction::SRem) {
    Value *ExtDividend = Builder.CreateSExt(Rem->getOperand(0), Int32Ty);
    Value *ExtDivisor = Builder.CreateSExt(Rem->getOperand(1), Int32Ty);
    ExtRem = Builder.CreateSRem(ExtDividend, ExtDivisor);
  } else {
    Value *ExtDividend = Builder.CreateZExt(Rem->getOperand(0), Int32Ty);
    Value *ExtDivisor = Builder.CreateZExt(Rem->getOperand(1), Int32Ty);
    ExtRem = Builder.CreateURem(ExtDividend, ExtDivisor);
  }
  Value *Trunc = Builder.CreateTrunc(ExtRem, RemTy);

  Rem->replaceAllUsesWith(Trunc);
  Rem->dropAllReferences();
  Rem->eraseFromParent();

  // With two constant operands, IRBuilder folds the widened remainder to a
  // Constant. A cast<BinaryOperator> would then assert, or be UB in release
  // builds. The folded value is already the final result, so stop here.
  if (auto *WideRem = dyn_cast<BinaryOperator>(ExtRem))
    return expandRemainder(WideRem);
  return true;
}
/// Expands an srem/urem whose scalar integer type is at most 64 bits wide.
///
/// A 64-bit remainder is passed straight to expandRemainder. A narrower one is
/// handled in three steps:
///   1. Recompute it on operands sign-extended (srem) or zero-extended (urem)
///      to i64.
///   2. Truncate the result back to the original type.
///   3. Expand the widened remainder.
/// \p Rem is erased.
///
/// \returns the result of expandRemainder, or true if the widened remainder was
/// constant folded by the builder and nothing is left to expand.
bool llvm::expandRemainderUpTo64Bits(BinaryOperator *Rem) {
  assert((Rem->getOpcode() == Instruction::SRem ||
          Rem->getOpcode() == Instruction::URem) &&
         "Trying to expand remainder from a non-remainder function");

  Type *RemTy = Rem->getType();
  assert(!RemTy->isVectorTy() && "Div over vectors not supported");

  unsigned RemTyBitWidth = RemTy->getIntegerBitWidth();
  assert(RemTyBitWidth <= 64 && "Div of bitwidth greater than 64 not supported");

  if (RemTyBitWidth == 64)
    return expandRemainder(Rem);

  // Narrower than 64 bits: widen the operands, compute in i64, truncate back.
  IRBuilder<> Builder(Rem);
  Type *Int64Ty = Builder.getInt64Ty();

  Value *ExtRem;
  if (Rem->getOpcode() == Instruction::SRem) {
    Value *ExtDividend = Builder.CreateSExt(Rem->getOperand(0), Int64Ty);
    Value *ExtDivisor = Builder.CreateSExt(Rem->getOperand(1), Int64Ty);
    ExtRem = Builder.CreateSRem(ExtDividend, ExtDivisor);
  } else {
    Value *ExtDividend = Builder.CreateZExt(Rem->getOperand(0), Int64Ty);
    Value *ExtDivisor = Builder.CreateZExt(Rem->getOperand(1), Int64Ty);
    ExtRem = Builder.CreateURem(ExtDividend, ExtDivisor);
  }
  Value *Trunc = Builder.CreateTrunc(ExtRem, RemTy);

  Rem->replaceAllUsesWith(Trunc);
  Rem->dropAllReferences();
  Rem->eraseFromParent();

  // With two constant operands, IRBuilder folds the widened remainder to a
  // Constant. A cast<BinaryOperator> would then assert, or be UB in release
  // builds. The folded value is already the final result, so stop here.
  if (auto *WideRem = dyn_cast<BinaryOperator>(ExtRem))
    return expandRemainder(WideRem);
  return true;
}
/// Expands an sdiv/udiv whose scalar integer type is at most 32 bits wide.
///
/// A 32-bit division is passed straight to expandDivision. A narrower one is
/// handled in three steps:
///   1. Recompute it on operands sign-extended (sdiv) or zero-extended (udiv)
///      to i32.
///   2. Truncate the result back to the original type.
///   3. Expand the widened division.
/// \p Div is erased.
///
/// \returns the result of expandDivision, or true if the widened division was
/// constant folded by the builder and nothing is left to expand.
bool llvm::expandDivisionUpTo32Bits(BinaryOperator *Div) {
  assert((Div->getOpcode() == Instruction::SDiv ||
          Div->getOpcode() == Instruction::UDiv) &&
         "Trying to expand division from a non-division function");

  Type *DivTy = Div->getType();
  assert(!DivTy->isVectorTy() && "Div over vectors not supported");

  unsigned DivTyBitWidth = DivTy->getIntegerBitWidth();
  assert(DivTyBitWidth <= 32 && "Div of bitwidth greater than 32 not supported");

  if (DivTyBitWidth == 32)
    return expandDivision(Div);

  // Narrower than 32 bits: widen the operands, compute in i32, truncate back.
  IRBuilder<> Builder(Div);
  Type *Int32Ty = Builder.getInt32Ty();

  Value *ExtDiv;
  if (Div->getOpcode() == Instruction::SDiv) {
    Value *ExtDividend = Builder.CreateSExt(Div->getOperand(0), Int32Ty);
    Value *ExtDivisor = Builder.CreateSExt(Div->getOperand(1), Int32Ty);
    ExtDiv = Builder.CreateSDiv(ExtDividend, ExtDivisor);
  } else {
    Value *ExtDividend = Builder.CreateZExt(Div->getOperand(0), Int32Ty);
    Value *ExtDivisor = Builder.CreateZExt(Div->getOperand(1), Int32Ty);
    ExtDiv = Builder.CreateUDiv(ExtDividend, ExtDivisor);
  }
  Value *Trunc = Builder.CreateTrunc(ExtDiv, DivTy);

  Div->replaceAllUsesWith(Trunc);
  Div->dropAllReferences();
  Div->eraseFromParent();

  // With two constant operands, IRBuilder folds the widened division to a
  // Constant. A cast<BinaryOperator> would then assert, or be UB in release
  // builds. The folded value is already the final result, so stop here.
  if (auto *WideDiv = dyn_cast<BinaryOperator>(ExtDiv))
    return expandDivision(WideDiv);
  return true;
}
/// Expands an sdiv/udiv whose scalar integer type is at most 64 bits wide.
///
/// A 64-bit division is passed straight to expandDivision. A narrower one is
/// handled in three steps:
///   1. Recompute it on operands sign-extended (sdiv) or zero-extended (udiv)
///      to i64.
///   2. Truncate the result back to the original type.
///   3. Expand the widened division.
/// \p Div is erased.
///
/// \returns the result of expandDivision, or true if the widened division was
/// constant folded by the builder and nothing is left to expand.
bool llvm::expandDivisionUpTo64Bits(BinaryOperator *Div) {
  assert((Div->getOpcode() == Instruction::SDiv ||
          Div->getOpcode() == Instruction::UDiv) &&
         "Trying to expand division from a non-division function");

  Type *DivTy = Div->getType();
  assert(!DivTy->isVectorTy() && "Div over vectors not supported");

  unsigned DivTyBitWidth = DivTy->getIntegerBitWidth();
  assert(DivTyBitWidth <= 64 &&
         "Div of bitwidth greater than 64 not supported");

  if (DivTyBitWidth == 64)
    return expandDivision(Div);

  // Narrower than 64 bits: widen the operands, compute in i64, truncate back.
  IRBuilder<> Builder(Div);
  Type *Int64Ty = Builder.getInt64Ty();

  Value *ExtDiv;
  if (Div->getOpcode() == Instruction::SDiv) {
    Value *ExtDividend = Builder.CreateSExt(Div->getOperand(0), Int64Ty);
    Value *ExtDivisor = Builder.CreateSExt(Div->getOperand(1), Int64Ty);
    ExtDiv = Builder.CreateSDiv(ExtDividend, ExtDivisor);
  } else {
    Value *ExtDividend = Builder.CreateZExt(Div->getOperand(0), Int64Ty);
    Value *ExtDivisor = Builder.CreateZExt(Div->getOperand(1), Int64Ty);
    ExtDiv = Builder.CreateUDiv(ExtDividend, ExtDivisor);
  }
  Value *Trunc = Builder.CreateTrunc(ExtDiv, DivTy);

  Div->replaceAllUsesWith(Trunc);
  Div->dropAllReferences();
  Div->eraseFromParent();

  // With two constant operands, IRBuilder folds the widened division to a
  // Constant. A cast<BinaryOperator> would then assert, or be UB in release
  // builds. The folded value is already the final result, so stop here.
  if (auto *WideDiv = dyn_cast<BinaryOperator>(ExtDiv))
    return expandDivision(WideDiv);
  return true;
}