#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Bitcode/BitcodeWriterPass.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/IPO.h"
#include <memory>
#include <utility>
using namespace llvm;
// Category that groups this tool's own options; main() hands it to
// cl::HideUnrelatedOptions. It is declared first because the options below
// refer to it through cl::cat(ExtractCat).
cl::OptionCategory ExtractCat("llvm-extract Options");
// InputFilename - Positional input file name; defaults to "-".
static cl::opt<std::string> InputFilename(cl::Positional,
cl::desc("<input bitcode file>"),
cl::init("-"),
cl::value_desc("filename"));
// OutputFilename - Name of the file main() opens for output; defaults to "-".
static cl::opt<std::string> OutputFilename("o",
cl::desc("Specify output filename"),
cl::value_desc("filename"),
cl::init("-"), cl::cat(ExtractCat));
// Force - When set, main() adds the bitcode writer without consulting
// CheckBitcodeOutputToConsole.
static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
cl::cat(ExtractCat));
// DeleteFn - Forwarded to createGVExtractionPass as its second argument. When
// set, main() materializes the functions that were NOT selected and skips the
// GlobalDCE pass.
static cl::opt<bool> DeleteFn("delete",
cl::desc("Delete specified Globals from Module"),
cl::cat(ExtractCat));
// KeepConstInit - Forwarded to createGVExtractionPass as its third argument.
static cl::opt<bool> KeepConstInit("keep-const-init",
cl::desc("Keep initializers of constants"),
cl::cat(ExtractCat));
// Recursive - When set, main() also selects every defined function that is
// directly called (transitively) from the selected functions.
static cl::opt<bool>
Recursive("recursive", cl::desc("Recursively extract all called functions"),
cl::cat(ExtractCat));
// ExtractFuncs - Exact names of functions to select (looked up with
// Module::getFunction).
static cl::list<std::string>
ExtractFuncs("func", cl::desc("Specify function to extract"),
cl::value_desc("function"), cl::cat(ExtractCat));
// ExtractRegExpFuncs - Regular expressions matched against the name of every
// function in the module.
static cl::list<std::string>
ExtractRegExpFuncs("rfunc",
cl::desc("Specify function(s) to extract using a "
"regular expression"),
cl::value_desc("rfunction"), cl::cat(ExtractCat));
// ExtractBlocks - "function:bb1[;bb2...]" specifications. main() splits each
// entry at the first ':' and then splits the remainder at ';'.
static cl::list<std::string> ExtractBlocks(
"bb",
cl::desc(
"Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
"Each pair will create a function.\n"
"If multiple basic blocks are specified in one pair,\n"
"the first block in the sequence should dominate the rest.\n"
"eg:\n"
" --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
" --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
"with bb2."),
cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));
// ExtractAliases - Exact names of aliases to select (looked up with
// Module::getNamedAlias).
static cl::list<std::string>
ExtractAliases("alias", cl::desc("Specify alias to extract"),
cl::value_desc("alias"), cl::cat(ExtractCat));
// ExtractRegExpAliases - Regular expressions matched against the name of
// every alias in the module.
static cl::list<std::string>
ExtractRegExpAliases("ralias",
cl::desc("Specify alias(es) to extract using a "
"regular expression"),
cl::value_desc("ralias"), cl::cat(ExtractCat));
// ExtractGlobals - Exact names of globals to select (looked up with
// Module::getNamedGlobal).
static cl::list<std::string>
ExtractGlobals("glob", cl::desc("Specify global to extract"),
cl::value_desc("global"), cl::cat(ExtractCat));
// ExtractRegExpGlobals - Regular expressions matched against the name of
// every entry of Module::globals().
static cl::list<std::string>
ExtractRegExpGlobals("rglob",
cl::desc("Specify global(s) to extract using a "
"regular expression"),
cl::value_desc("rglobal"), cl::cat(ExtractCat));
// OutputAssembly - When set, main() emits the module through
// createPrintModulePass instead of the bitcode writer. Hidden option.
static cl::opt<bool> OutputAssembly("S",
cl::desc("Write output as LLVM assembly"),
cl::Hidden, cl::cat(ExtractCat));
// PreserveBitcodeUseListOrder - Passed to createBitcodeWriterPass; defaults
// to true. Hidden option.
static cl::opt<bool> PreserveBitcodeUseListOrder(
"preserve-bc-uselistorder",
cl::desc("Preserve use-list order when writing LLVM bitcode."),
cl::init(true), cl::Hidden, cl::cat(ExtractCat));
// PreserveAssemblyUseListOrder - Passed to createPrintModulePass; defaults
// to false. Hidden option.
static cl::opt<bool> PreserveAssemblyUseListOrder(
"preserve-ll-uselistorder",
cl::desc("Preserve use-list order when writing LLVM assembly."),
cl::init(false), cl::Hidden, cl::cat(ExtractCat));
int main(int argc, char **argv) {
InitLLVM X(argc, argv);
LLVMContext Context;
cl::HideUnrelatedOptions(ExtractCat);
cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
SMDiagnostic Err;
std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
if (!M.get()) {
Err.print(argv[0], errs());
return 1;
}
SetVector<GlobalValue *> GVs;
for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
if (!GA) {
errs() << argv[0] << ": program doesn't contain alias named '"
<< ExtractAliases[i] << "'!\n";
return 1;
}
GVs.insert(GA);
}
for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
std::string Error;
Regex RegEx(ExtractRegExpAliases[i]);
if (!RegEx.isValid(Error)) {
errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
"invalid regex: " << Error;
}
bool match = false;
for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
GA != E; GA++) {
if (RegEx.match(GA->getName())) {
GVs.insert(&*GA);
match = true;
}
}
if (!match) {
errs() << argv[0] << ": program doesn't contain global named '"
<< ExtractRegExpAliases[i] << "'!\n";
return 1;
}
}
for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
if (!GV) {
errs() << argv[0] << ": program doesn't contain global named '"
<< ExtractGlobals[i] << "'!\n";
return 1;
}
GVs.insert(GV);
}
for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
std::string Error;
Regex RegEx(ExtractRegExpGlobals[i]);
if (!RegEx.isValid(Error)) {
errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
"invalid regex: " << Error;
}
bool match = false;
for (auto &GV : M->globals()) {
if (RegEx.match(GV.getName())) {
GVs.insert(&GV);
match = true;
}
}
if (!match) {
errs() << argv[0] << ": program doesn't contain global named '"
<< ExtractRegExpGlobals[i] << "'!\n";
return 1;
}
}
for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
if (!GV) {
errs() << argv[0] << ": program doesn't contain function named '"
<< ExtractFuncs[i] << "'!\n";
return 1;
}
GVs.insert(GV);
}
for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
std::string Error;
StringRef RegExStr = ExtractRegExpFuncs[i];
Regex RegEx(RegExStr);
if (!RegEx.isValid(Error)) {
errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
"invalid regex: " << Error;
}
bool match = false;
for (Module::iterator F = M->begin(), E = M->end(); F != E;
F++) {
if (RegEx.match(F->getName())) {
GVs.insert(&*F);
match = true;
}
}
if (!match) {
errs() << argv[0] << ": program doesn't contain global named '"
<< ExtractRegExpFuncs[i] << "'!\n";
return 1;
}
}
SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
for (StringRef StrPair : ExtractBlocks) {
SmallVector<StringRef, 16> BBNames;
auto BBInfo = StrPair.split(':');
Function *F = M->getFunction(BBInfo.first);
if (!F) {
errs() << argv[0] << ": program doesn't contain a function named '"
<< BBInfo.first << "'!\n";
return 1;
}
GVs.insert(F);
BBInfo.second.split(BBNames, ';', -1, false);
BBMap.push_back({F, std::move(BBNames)});
}
ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
if (Recursive) {
std::vector<llvm::Function *> Workqueue;
for (GlobalValue *GV : GVs) {
if (auto *F = dyn_cast<Function>(GV)) {
Workqueue.push_back(F);
}
}
while (!Workqueue.empty()) {
Function *F = &*Workqueue.back();
Workqueue.pop_back();
ExitOnErr(F->materialize());
for (auto &BB : *F) {
for (auto &I : BB) {
CallBase *CB = dyn_cast<CallBase>(&I);
if (!CB)
continue;
Function *CF = CB->getCalledFunction();
if (!CF)
continue;
if (CF->isDeclaration() || GVs.count(CF))
continue;
GVs.insert(CF);
Workqueue.push_back(CF);
}
}
}
}
auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
if (!DeleteFn) {
for (size_t i = 0, e = GVs.size(); i != e; ++i)
Materialize(*GVs[i]);
} else {
SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
for (auto &F : *M) {
if (!GVSet.count(&F))
Materialize(F);
}
}
{
std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
legacy::PassManager Extract;
Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit));
Extract.run(*M);
ExitOnErr(M->materializeAll());
}
if (!ExtractBlocks.empty()) {
SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs;
for (auto &P : BBMap) {
SmallVector<BasicBlock *, 16> BBs;
for (StringRef BBName : P.second) {
auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) {
return BB.getName().equals(BBName);
});
if (Res == P.first->end()) {
errs() << argv[0] << ": function " << P.first->getName()
<< " doesn't contain a basic block named '" << BBName
<< "'!\n";
return 1;
}
BBs.push_back(&*Res);
}
GroupOfBBs.push_back(BBs);
}
legacy::PassManager PM;
PM.add(createBlockExtractorPass(GroupOfBBs, true));
PM.run(*M);
}
legacy::PassManager Passes;
if (!DeleteFn)
Passes.add(createGlobalDCEPass()); Passes.add(createStripDeadDebugInfoPass()); Passes.add(createStripDeadPrototypesPass());
std::error_code EC;
ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
if (EC) {
errs() << EC.message() << '\n';
return 1;
}
if (OutputAssembly)
Passes.add(
createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
else if (Force || !CheckBitcodeOutputToConsole(Out.os()))
Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
Passes.run(*M.get());
Out.keep();
return 0;
}