#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
#define DEBUG_TYPE "replace-with-veclib"
// STATISTIC counters for this pass; all three are incremented in
// replaceWithTLIFunction.
// Incremented once for every intrinsic call that was replaced.
STATISTIC(NumCallsReplaced,
"Number of calls to intrinsics that have been replaced.");
// Incremented once for every vector library declaration created in the module.
STATISTIC(NumTLIFuncDeclAdded,
"Number of vector library function declarations added.");
// Incremented once for every function appended to `@llvm.compiler.used`.
STATISTIC(NumFuncUsedAdded,
"Number of functions added to `llvm.compiler.used`");
/// Replace the call \p CI to a vector intrinsic with a call to the vector
/// library function named \p TLIName.
///
/// If the module has no function of that name yet, a declaration is created
/// with the intrinsic's function type and attributes and is appended to
/// `@llvm.compiler.used`. The new call receives the original arguments and
/// operand bundles, takes over all uses of \p CI, and inherits its fast-math
/// flags (for FP math calls). \p CI itself is left in place; the caller is
/// responsible for erasing it.
///
/// \returns true if the call was replaced; false if the module already holds
/// a function named \p TLIName whose type differs from the intrinsic's, in
/// which case the IR is left unchanged.
static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
  Module *M = CI.getModule();
  Function *OldFunc = CI.getCalledFunction();

  // Reuse an existing declaration of the vector library function if the
  // module already has one; otherwise insert it.
  Function *TLIFunc = M->getFunction(TLIName);
  if (!TLIFunc) {
    TLIFunc = Function::Create(OldFunc->getFunctionType(),
                               Function::ExternalLinkage, TLIName, *M);
    TLIFunc->copyAttributesFrom(OldFunc);
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                      << TLIName << "` of type `" << *(TLIFunc->getType())
                      << "` to module.\n");
    ++NumTLIFuncDeclAdded;

    // Record the new declaration in `@llvm.compiler.used` (presumably so
    // later passes do not drop the otherwise unreferenced declaration).
    appendToCompilerUsed(*M, {TLIFunc});
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
                      << "` to `@llvm.compiler.used`.\n");
    ++NumFuncUsedAdded;
  } else if (TLIFunc->getFunctionType() != OldFunc->getFunctionType()) {
    // The operands are typed for the intrinsic. Calling a pre-existing
    // function with a different signature would create an ill-typed call.
    // This is checked *before* building the call (previously it was only
    // asserted afterwards), so we skip the replacement instead of emitting
    // invalid IR.
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Existing function `" << TLIName
                      << "` has an unexpected type; not replacing call to `"
                      << OldFunc->getName() << "`.\n");
    return false;
  }

  // Build the replacement call right before the original one, keeping the
  // arguments and operand bundles.
  IRBuilder<> Builder(&CI);
  SmallVector<Value *> Args(CI.args());
  SmallVector<OperandBundleDef, 1> OpBundles;
  CI.getOperandBundlesAsDefs(OpBundles);
  CallInst *Replacement = Builder.CreateCall(TLIFunc, Args, OpBundles);
  CI.replaceAllUsesWith(Replacement);
  if (isa<FPMathOperator>(Replacement)) {
    // Preserve fast-math flags for FP math calls.
    Replacement->copyFastMathFlags(&CI);
  }

  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
                    << OldFunc->getName() << "` with call to `" << TLIName
                    << "`.\n");
  ++NumCallsReplaced;
  return true;
}
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
CallInst &CI) {
if (!CI.getCalledFunction()) {
return false;
}
auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
if (IntrinsicID == Intrinsic::not_intrinsic) {
return false;
}
ElementCount VF = ElementCount::getFixed(0);
SmallVector<Type *> ScalarTypes;
for (auto Arg : enumerate(CI.args())) {
auto *ArgType = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
ScalarTypes.push_back(ArgType);
} else {
auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
if (!VectorArgTy) {
return false;
}
ElementCount NumElements = VectorArgTy->getElementCount();
if (NumElements.isScalable()) {
return false;
}
if (VF.isNonZero() && VF != NumElements) {
return false;
} else {
VF = NumElements;
}
ScalarTypes.push_back(VectorArgTy->getElementType());
}
}
std::string ScalarName;
if (Intrinsic::isOverloaded(IntrinsicID)) {
ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
} else {
ScalarName = Intrinsic::getName(IntrinsicID).str();
}
if (!TLI.isFunctionVectorizable(ScalarName)) {
return false;
}
const std::string TLIName =
std::string(TLI.getVectorizedFunction(ScalarName, VF));
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
<< ScalarName << "` and vector width " << VF << ".\n");
if (!TLIName.empty()) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
<< "`.\n");
return replaceWithTLIFunction(CI, TLIName);
}
return false;
}
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
bool Changed = false;
SmallVector<CallInst *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (auto *CI = dyn_cast<CallInst>(&I)) {
if (replaceWithCallToVeclib(TLI, *CI)) {
ReplacedCalls.push_back(CI);
Changed = true;
}
}
}
for (auto *CI : ReplacedCalls) {
CI->eraseFromParent();
}
return Changed;
}
/// New pass manager entry point.
///
/// Preserves everything when nothing changed. Otherwise it reports the CFG
/// and the analyses listed below as preserved, since the pass only swaps call
/// instructions.
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                         FunctionAnalysisManager &AM) {
  const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  if (!runImpl(TLI, F))
    return PreservedAnalyses::all();

  PreservedAnalyses Kept;
  Kept.preserveSet<CFGAnalyses>();
  Kept.preserve<TargetLibraryAnalysis>();
  Kept.preserve<ScalarEvolutionAnalysis>();
  Kept.preserve<LoopAccessAnalysis>();
  Kept.preserve<DemandedBitsAnalysis>();
  Kept.preserve<OptimizationRemarkEmitterAnalysis>();
  return Kept;
}
/// Legacy pass manager entry point: fetches the TargetLibraryInfo for \p F
/// and runs the shared driver on it.
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
  auto &TLIWrapper = getAnalysis<TargetLibraryInfoWrapperPass>();
  return runImpl(TLIWrapper.getTLI(F), F);
}
/// Declares the legacy pass's dependencies: it needs TargetLibraryInfo, keeps
/// the CFG intact, and leaves the listed analyses valid.
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesCFG();
  AU.addRequired<TargetLibraryInfoWrapperPass>()
      .addPreserved<TargetLibraryInfoWrapperPass>()
      .addPreserved<ScalarEvolutionWrapperPass>()
      .addPreserved<AAResultsWrapperPass>()
      .addPreserved<LoopAccessLegacyAnalysis>()
      .addPreserved<DemandedBitsWrapperPass>()
      .addPreserved<OptimizationRemarkEmitterWrapperPass>()
      .addPreserved<GlobalsAAWrapperPass>();
}
// Definition of the legacy pass's static ID member.
char ReplaceWithVeclibLegacy::ID = 0;
// Register the legacy pass under DEBUG_TYPE ("replace-with-veclib") and
// declare its dependency on TargetLibraryInfoWrapperPass. The trailing
// `false, false` are presumably the CFG-only / is-analysis flags of the
// INITIALIZE_PASS macros — confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
"Replace intrinsics with calls to vector library", false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
"Replace intrinsics with calls to vector library", false,
false)
/// Factory for the legacy pass; returns a newly heap-allocated instance.
FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  FunctionPass *Pass = new ReplaceWithVeclibLegacy();
  return Pass;
}