#include "AArch64InstrInfo.h"
#include "AArch64Subtarget.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include <cassert>
using namespace llvm;
// Debug-output category for the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "aarch64-speculation-hardening"
// Human-readable pass name, used by getPassName() and INITIALIZE_PASS below.
#define AARCH64_SPECULATION_HARDENING_NAME "AArch64 speculation hardening pass"
// Hidden option (default: on). When set, runOnMachineFunction() runs
// slhLoads() to insert SpeculationSafeValue pseudos for every load.
static cl::opt<bool> HardenLoads("aarch64-slh-loads", cl::Hidden,
cl::desc("Sanitize loads from memory."),
cl::init(true));
namespace {
class AArch64SpeculationHardening : public MachineFunctionPass {
public:
const TargetInstrInfo *TII;
const TargetRegisterInfo *TRI;
static char ID;
AArch64SpeculationHardening() : MachineFunctionPass(ID) {
initializeAArch64SpeculationHardeningPass(*PassRegistry::getPassRegistry());
}
bool runOnMachineFunction(MachineFunction &Fn) override;
StringRef getPassName() const override {
return AARCH64_SPECULATION_HARDENING_NAME;
}
private:
unsigned MisspeculatingTaintReg;
unsigned MisspeculatingTaintReg32Bit;
bool UseControlFlowSpeculationBarrier;
BitVector RegsNeedingCSDBBeforeUse;
BitVector RegsAlreadyMasked;
bool functionUsesHardeningRegister(MachineFunction &MF) const;
bool instrumentControlFlow(MachineBasicBlock &MBB,
bool &UsesFullSpeculationBarrier);
bool endsWithCondControlFlow(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
MachineBasicBlock *&FBB,
AArch64CC::CondCode &CondCode) const;
void insertTrackingCode(MachineBasicBlock &SplitEdgeBB,
AArch64CC::CondCode &CondCode, DebugLoc DL) const;
void insertSPToRegTaintPropagation(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) const;
void insertRegToSPTaintPropagation(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
unsigned TmpReg) const;
void insertFullSpeculationBarrier(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
DebugLoc DL) const;
bool slhLoads(MachineBasicBlock &MBB);
bool makeGPRSpeculationSafe(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
MachineInstr &MI, unsigned Reg);
bool lowerSpeculationSafeValuePseudos(MachineBasicBlock &MBB,
bool UsesFullSpeculationBarrier);
bool expandSpeculationSafeValue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool UsesFullSpeculationBarrier);
bool insertCSDB(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
DebugLoc DL);
};
}
// Unique pass identifier (the legacy pass manager identifies passes by the
// address of this object).
char AArch64SpeculationHardening::ID = 0;
// Register the pass under the command-line name
// "aarch64-speculation-hardening"; it is neither CFG-only nor an analysis.
INITIALIZE_PASS(AArch64SpeculationHardening, "aarch64-speculation-hardening",
AARCH64_SPECULATION_HARDENING_NAME, false, false)
/// Determine whether \p MBB ends in a conditional branch whose two outgoing
/// edges lead to different blocks.
///
/// On success \p TBB receives the taken destination, \p FBB the not-taken
/// destination (the layout fall-through when there is no explicit second
/// branch), and \p CondCode the branch condition.
/// \returns false if the terminators cannot be analyzed, if the block ends in
/// an unconditional branch or falls through, or if both edges reach the same
/// block (speculation then cannot change which code runs).
bool AArch64SpeculationHardening::endsWithCondControlFlow(
    MachineBasicBlock &MBB, MachineBasicBlock *&TBB, MachineBasicBlock *&FBB,
    AArch64CC::CondCode &CondCode) const {
  SmallVector<MachineOperand, 1> BranchCond;
  const bool Unanalyzable =
      TII->analyzeBranch(MBB, TBB, FBB, BranchCond, /*AllowModify=*/false);
  // No analyzable condition means there is nothing to track.
  if (Unanalyzable || BranchCond.empty())
    return false;

  assert(TBB != nullptr);
  // A lone conditional branch leaves FBB unset; its false edge is the
  // fall-through block.
  if (!FBB)
    FBB = MBB.getFallThrough();

  // Both edges go to the same place: the correct code executes either way.
  if (FBB == TBB)
    return false;

  assert(MBB.succ_size() == 2);
  assert(BranchCond.size() == 1 && "unknown Cond array format");
  CondCode = static_cast<AArch64CC::CondCode>(BranchCond[0].getImm());
  return true;
}
/// Emit a full control-flow speculation barrier, `DSB SY` followed by
/// `ISB SY`, immediately before \p MBBI in \p MBB.
void AArch64SpeculationHardening::insertFullSpeculationBarrier(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    DebugLoc DL) const {
  // 0xf selects the full-system ("SY") option for both barriers.
  const unsigned FullSystemOption = 0xf;
  const unsigned BarrierOpcodes[] = {AArch64::DSB, AArch64::ISB};
  for (unsigned Opc : BarrierOpcodes)
    BuildMI(MBB, MBBI, DL, TII->get(Opc)).addImm(FullSystemOption);
}
/// Instrument \p SplitEdgeBB, a block reached only through a branch edge that
/// is architecturally taken when \p CondCode holds.
///
/// Normally this emits `CSEL taint, taint, XZR, CondCode` at the top of the
/// block. If the block is entered while the condition is false (a mispredicted
/// edge), the taint register is zeroed. When the function uses full barriers
/// instead of taint tracking, a DSB SY + ISB is emitted there instead.
void AArch64SpeculationHardening::insertTrackingCode(
    MachineBasicBlock &SplitEdgeBB, AArch64CC::CondCode &CondCode,
    DebugLoc DL) const {
  MachineBasicBlock::iterator Top = SplitEdgeBB.begin();

  if (UseControlFlowSpeculationBarrier) {
    insertFullSpeculationBarrier(SplitEdgeBB, Top, DL);
    return;
  }

  BuildMI(SplitEdgeBB, Top, DL, TII->get(AArch64::CSELXr))
      .addDef(MisspeculatingTaintReg)
      .addUse(MisspeculatingTaintReg)
      .addUse(AArch64::XZR)
      .addImm(CondCode);
  // The CSEL reads the condition flags, so they must be live into the block.
  SplitEdgeBB.addLiveIn(AArch64::NZCV);
}
/// Instrument \p MBB for control-flow misspeculation tracking.
///  - If MBB ends in a two-way conditional branch, both outgoing edges are
///    split and each new edge block gets tracking code (see
///    insertTrackingCode).
///  - Before every return the taint is folded into SP. Around every call it
///    is folded into SP before the call and recreated from SP after it.
/// Folding the taint into SP needs a free scratch GPR just before the
/// call/return. If any call or return in the block lacks one, a single
/// DSB SY + ISB is emitted at the start of the block instead, and
/// \p UsesFullSpeculationBarrier is set to true.
/// \returns true if anything was inserted.
bool AArch64SpeculationHardening::instrumentControlFlow(
MachineBasicBlock &MBB, bool &UsesFullSpeculationBarrier) {
LLVM_DEBUG(dbgs() << "Instrument control flow tracking on MBB: " << MBB);
bool Modified = false;
MachineBasicBlock *TBB = nullptr;
MachineBasicBlock *FBB = nullptr;
AArch64CC::CondCode CondCode;
if (!endsWithCondControlFlow(MBB, TBB, FBB, CondCode)) {
LLVM_DEBUG(dbgs() << "... doesn't end with CondControlFlow\n");
} else {
// The TBB edge is architecturally taken when CondCode holds, the FBB edge
// when its inverse holds; each edge block checks its own condition.
AArch64CC::CondCode InvCondCode = AArch64CC::getInvertedCondCode(CondCode);
// Give each edge its own block so the tracking code only runs when that
// specific edge is followed. Both splits are expected to succeed.
MachineBasicBlock *SplitEdgeTBB = MBB.SplitCriticalEdge(TBB, *this);
MachineBasicBlock *SplitEdgeFBB = MBB.SplitCriticalEdge(FBB, *this);
assert(SplitEdgeTBB != nullptr);
assert(SplitEdgeFBB != nullptr);
// Attribute the tracking code to MBB's last instruction (typically a
// branch), if the block has any instructions.
DebugLoc DL;
if (MBB.instr_end() != MBB.instr_begin())
DL = (--MBB.instr_end())->getDebugLoc();
insertTrackingCode(*SplitEdgeTBB, CondCode, DL);
insertTrackingCode(*SplitEdgeFBB, InvCondCode, DL);
LLVM_DEBUG(dbgs() << "SplitEdgeTBB: " << *SplitEdgeTBB << "\n");
LLVM_DEBUG(dbgs() << "SplitEdgeFBB: " << *SplitEdgeFBB << "\n");
Modified = true;
}
// Returns and calls in MBB, each paired with a scratch GPR that is free just
// before the instruction (0 when none was found).
SmallVector<std::pair<MachineInstr *, unsigned>, 4> ReturnInstructions;
SmallVector<std::pair<MachineInstr *, unsigned>, 4> CallInstructions;
// Becomes true if at least one call/return has no free scratch register.
bool TmpRegisterNotAvailableEverywhere = false;
RegScavenger RS;
RS.enterBasicBlock(MBB);
for (MachineBasicBlock::iterator I = MBB.begin(); I != MBB.end(); I++) {
MachineInstr &MI = *I;
if (!MI.isReturn() && !MI.isCall())
continue;
// The scavenger state describes availability *after* the instruction it was
// last advanced to. Advance it to the instruction preceding MI so it
// reflects the point just before MI. At the block start, the state from
// enterBasicBlock is used unchanged.
if (I != MBB.begin())
RS.forward(std::prev(I));
// Any unused register is accepted; no attempt is made to pick the one that
// best preserves scheduling freedom.
Register TmpReg = RS.FindUnusedReg(&AArch64::GPR64commonRegClass);
LLVM_DEBUG(dbgs() << "RS finds "
<< ((TmpReg == 0) ? "no register " : "register ");
if (TmpReg != 0) dbgs() << printReg(TmpReg, TRI) << " ";
dbgs() << "to be available at MI " << MI);
if (TmpReg == 0)
TmpRegisterNotAvailableEverywhere = true;
if (MI.isReturn())
ReturnInstructions.push_back({&MI, TmpReg});
else if (MI.isCall())
CallInstructions.push_back({&MI, TmpReg});
}
if (TmpRegisterNotAvailableEverywhere) {
// A single full barrier at the top of the block makes the per-call/return
// taint hand-off in this block unnecessary. MBB.begin() is dereferenceable
// here: the flag is only set after visiting a call or return in MBB.
insertFullSpeculationBarrier(MBB, MBB.begin(),
(MBB.begin())->getDebugLoc());
UsesFullSpeculationBarrier = true;
Modified = true;
} else {
for (auto MI_Reg : ReturnInstructions) {
assert(MI_Reg.second != 0);
LLVM_DEBUG(
dbgs()
<< " About to insert Reg to SP taint propagation with temp register "
<< printReg(MI_Reg.second, TRI)
<< " on instruction: " << *MI_Reg.first);
// Fold the taint into SP just before the return.
insertRegToSPTaintPropagation(MBB, MI_Reg.first, MI_Reg.second);
Modified = true;
}
for (auto MI_Reg : CallInstructions) {
assert(MI_Reg.second != 0);
LLVM_DEBUG(dbgs() << " About to insert Reg to SP and back taint "
"propagation with temp register "
<< printReg(MI_Reg.second, TRI)
<< " around instruction: " << *MI_Reg.first);
// Just after the call: recreate the taint register from SP.
insertSPToRegTaintPropagation(
MBB, std::next((MachineBasicBlock::iterator)MI_Reg.first));
// Just before the call: fold the taint into SP.
insertRegToSPTaintPropagation(MBB, MI_Reg.first, MI_Reg.second);
Modified = true;
}
}
return Modified;
}
/// Recreate the taint register from SP just before \p MBBI. This is used at
/// the function entry, at landing pads and right after calls.
///
/// A zero SP encodes misspeculation, so this emits
/// `CMP SP, #0; CSETM taint, NE`: the taint becomes all-ones when SP is
/// non-zero and zero otherwise. When full barriers are used, a DSB SY + ISB
/// is emitted instead to stop any misspeculation already in flight.
void AArch64SpeculationHardening::insertSPToRegTaintPropagation(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
  if (!UseControlFlowSpeculationBarrier) {
    // CMP SP, #0  ==  SUBS XZR, SP, #0 (no shift)
    BuildMI(MBB, MBBI, DebugLoc(), TII->get(AArch64::SUBSXri))
        .addDef(AArch64::XZR)
        .addUse(AArch64::SP)
        .addImm(0)
        .addImm(0);
    // CSETM taint, NE  ==  CSINV taint, XZR, XZR, EQ
    BuildMI(MBB, MBBI, DebugLoc(), TII->get(AArch64::CSINVXr))
        .addDef(MisspeculatingTaintReg)
        .addUse(AArch64::XZR)
        .addUse(AArch64::XZR)
        .addImm(AArch64CC::EQ);
    return;
  }

  insertFullSpeculationBarrier(MBB, MBBI, DebugLoc());
}
/// Fold the taint register into SP just before \p MBBI. This is used before
/// returns and before calls.
///
/// Emits `MOV Xtmp, SP; AND Xtmp, Xtmp, taint; MOV SP, Xtmp`. On a
/// misspeculated path the taint is zero, so SP becomes zero. The code
/// inserted by insertSPToRegTaintPropagation on the other side turns that
/// back into a zero taint. \p TmpReg must be a 64-bit GPR that is free at
/// \p MBBI. Nothing is emitted when full control-flow barriers are used.
void AArch64SpeculationHardening::insertRegToSPTaintPropagation(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    unsigned TmpReg) const {
  if (UseControlFlowSpeculationBarrier)
    return;

  const MCInstrDesc &AddImmDesc = TII->get(AArch64::ADDXri);

  // MOV Xtmp, SP  ==  ADD Xtmp, SP, #0 (no shift)
  BuildMI(MBB, MBBI, DebugLoc(), AddImmDesc)
      .addDef(TmpReg)
      .addUse(AArch64::SP)
      .addImm(0)
      .addImm(0);

  // AND Xtmp, Xtmp, taint (shifted-register form, shift amount 0)
  BuildMI(MBB, MBBI, DebugLoc(), TII->get(AArch64::ANDXrs))
      .addDef(TmpReg, RegState::Renamable)
      .addUse(TmpReg, RegState::Kill | RegState::Renamable)
      .addUse(MisspeculatingTaintReg, RegState::Kill)
      .addImm(0);

  // MOV SP, Xtmp  ==  ADD SP, Xtmp, #0 (no shift)
  BuildMI(MBB, MBBI, DebugLoc(), AddImmDesc)
      .addDef(AArch64::SP)
      .addUse(TmpReg, RegState::Kill)
      .addImm(0)
      .addImm(0);
}
/// \returns true if any non-call instruction in \p MF reads or writes the
/// taint register. TRI is passed, so overlapping registers (e.g. the 32-bit
/// view) count too.
///
/// Calls are skipped. Presumably this is because the taint value does not
/// have to survive a call: it is recreated from SP right after each call.
bool AArch64SpeculationHardening::functionUsesHardeningRegister(
    MachineFunction &MF) const {
  for (MachineBasicBlock &Block : MF)
    for (MachineInstr &Inst : Block)
      if (!Inst.isCall() &&
          (Inst.readsRegister(MisspeculatingTaintReg, TRI) ||
           Inst.modifiesRegister(MisspeculatingTaintReg, TRI)))
        return true;
  return false;
}
/// Insert a SpeculationSafeValue pseudo for \p Reg immediately before
/// \p MBBI. The pseudo masks \p Reg with the taint and is 64- or 32-bit to
/// match \p Reg. \p MI (the load being hardened) supplies the debug location.
///
/// SP/WSP are never masked, and a register already masked earlier in the
/// current block (tracked in RegsAlreadyMasked) is not masked again.
/// \returns true if a pseudo was inserted.
bool AArch64SpeculationHardening::makeGPRSpeculationSafe(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MachineInstr &MI,
    unsigned Reg) {
  assert(AArch64::GPR32allRegClass.contains(Reg) ||
         AArch64::GPR64allRegClass.contains(Reg));

  // Loads do not place loaded data in SP/WSP. When they appear here, it is
  // as the stack-based address (or its writeback), which is left unmasked.
  if (Reg == AArch64::SP || Reg == AArch64::WSP)
    return false;

  // Already masked and not redefined since: nothing to do.
  if (RegsAlreadyMasked[Reg])
    return false;

  const bool Is64Bit = AArch64::GPR64allRegClass.contains(Reg);
  // Print the register by name (as the other debug output in this pass does)
  // rather than as a raw register number.
  LLVM_DEBUG(dbgs() << "About to harden register : " << printReg(Reg, TRI)
                    << "\n");
  BuildMI(MBB, MBBI, MI.getDebugLoc(),
          TII->get(Is64Bit ? AArch64::SpeculationSafeValueX
                           : AArch64::SpeculationSafeValueW))
      .addDef(Reg)
      .addUse(Reg);
  RegsAlreadyMasked.set(Reg);
  return true;
}
/// Insert SpeculationSafeValue pseudos for every load in \p MBB.
///
/// Loads whose defs are all GPRs get each non-dead loaded register masked,
/// right after the load. Any other load (e.g. into a floating-point or vector
/// register) instead gets its GPR operands, i.e. the address, masked right
/// before the load. A register is not masked twice unless it is redefined in
/// between.
/// \returns true if any pseudo was inserted.
bool AArch64SpeculationHardening::slhLoads(MachineBasicBlock &MBB) {
  bool Changed = false;

  LLVM_DEBUG(dbgs() << "slhLoads running on MBB: " << MBB);

  RegsAlreadyMasked.reset();

  auto IsGPR = [](Register R) {
    return AArch64::GPR32allRegClass.contains(R) ||
           AArch64::GPR64allRegClass.contains(R);
  };

  const MachineBasicBlock::iterator End = MBB.end();
  MachineBasicBlock::iterator Cur = MBB.begin();
  while (Cur != End) {
    MachineInstr &MI = *Cur;
    // Taken before any insertion: pseudos placed before After sit between MI
    // and After, so the walk never revisits them.
    const MachineBasicBlock::iterator After = std::next(Cur);

    if (MI.mayLoad()) {
      LLVM_DEBUG(dbgs() << "About to harden: " << MI);

      // Masking the loaded value lets the load itself still execute
      // speculatively, but is only done for GPR destinations.
      const bool MaskLoadedValue =
          llvm::all_of(MI.defs(), [&](MachineOperand &Op) {
            return Op.isReg() && IsGPR(Op.getReg());
          });

      // Registers written by MI now hold fresh, unmasked values.
      for (const MachineOperand &Def : MI.defs())
        for (MCRegAliasIterator AI(Def.getReg(), TRI, /*IncludeSelf=*/true);
             AI.isValid(); ++AI)
          RegsAlreadyMasked.reset(*AI);

      if (MaskLoadedValue) {
        for (const MachineOperand &Def : MI.defs())
          if (!Def.isDead())
            Changed |= makeGPRSpeculationSafe(MBB, After, MI, Def.getReg());
      } else {
        // Non-GPR uses (e.g. implicit super-register operands of partial
        // FP/vector loads) are not address registers; skip them.
        for (const MachineOperand &Use : MI.uses())
          if (Use.isReg() && IsGPR(Use.getReg()))
            Changed |= makeGPRSpeculationSafe(MBB, Cur, MI, Use.getReg());
      }
    }

    Cur = After;
  }
  return Changed;
}
/// Lower the instruction at \p MBBI if it is a SpeculationSafeValueX/W
/// pseudo.
///
/// The pseudo becomes `AND Dst, Src, taint` (64- or 32-bit), and the
/// destination plus all its aliases are recorded as needing a CSDB before
/// their next use. If misspeculation is already blocked by full barriers
/// (function-wide, or in this block per \p UsesFullSpeculationBarrier), the
/// pseudo is deleted without replacement.
/// \returns true if \p MBBI was a pseudo (it has then been erased), false
/// otherwise.
bool AArch64SpeculationHardening::expandSpeculationSafeValue(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    bool UsesFullSpeculationBarrier) {
  MachineInstr &MI = *MBBI;
  const unsigned Opc = MI.getOpcode();
  const bool IsX = Opc == AArch64::SpeculationSafeValueX;
  const bool IsW = Opc == AArch64::SpeculationSafeValueW;
  if (!IsX && !IsW)
    return false;

  const bool BarriersAlreadyBlock =
      UseControlFlowSpeculationBarrier || UsesFullSpeculationBarrier;
  if (!BarriersAlreadyBlock) {
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = MI.getOperand(1).getReg();

    for (const MachineOperand &Def : MI.defs())
      for (MCRegAliasIterator AI(Def.getReg(), TRI, /*IncludeSelf=*/true);
           AI.isValid(); ++AI)
        RegsNeedingCSDBBeforeUse.set(*AI);

    const MCInstrDesc &AndDesc =
        TII->get(IsX ? AArch64::ANDXrs : AArch64::ANDWrs);
    const unsigned TaintReg =
        IsX ? MisspeculatingTaintReg : MisspeculatingTaintReg32Bit;
    BuildMI(MBB, MBBI, MI.getDebugLoc(), AndDesc)
        .addDef(DstReg)
        .addUse(SrcReg, RegState::Kill)
        .addUse(TaintReg)
        .addImm(0);
  }

  MI.eraseFromParent();
  return true;
}
/// Insert a CSDB (consumption of speculative data barrier) before \p MBBI and
/// clear the set of registers waiting for one. Must not be called when the
/// function uses full control-flow speculation barriers.
/// \returns true, since the block is always modified.
bool AArch64SpeculationHardening::insertCSDB(MachineBasicBlock &MBB,
                                             MachineBasicBlock::iterator MBBI,
                                             DebugLoc DL) {
  assert(!UseControlFlowSpeculationBarrier &&
         "No need to insert CSDBs when control flow miss-speculation is "
         "already blocked");

  // CSDB is encoded in the hint space as HINT #0x14.
  const unsigned CSDBHintImm = 0x14;
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::HINT)).addImm(CSDBHintImm);

  RegsNeedingCSDBBeforeUse.reset();
  return true;
}
/// Lower every SpeculationSafeValue pseudo in \p MBB and place CSDB barriers.
///
/// While masked registers are pending, a CSDB is emitted as late as
/// possible:
///  - just before the first instruction that reads one of them;
///  - before any call or terminator;
///  - at the end of the block if any are still pending.
/// No CSDBs are emitted when \p UsesFullSpeculationBarrier is set.
/// \returns true if the block was changed.
bool AArch64SpeculationHardening::lowerSpeculationSafeValuePseudos(
    MachineBasicBlock &MBB, bool UsesFullSpeculationBarrier) {
  bool Changed = false;

  RegsNeedingCSDBBeforeUse.reset();

  DebugLoc LastDL;
  const MachineBasicBlock::iterator End = MBB.end();
  MachineBasicBlock::iterator Cur = MBB.begin();
  while (Cur != End) {
    MachineInstr &MI = *Cur;
    // Capture these first: the expansion below may erase MI.
    LastDL = MI.getDebugLoc();
    const MachineBasicBlock::iterator Next = std::next(Cur);

    // Pending masked registers must be protected before control may leave
    // the straight-line path...
    bool BarrierNeeded = RegsNeedingCSDBBeforeUse.any() &&
                         (MI.isCall() || MI.isTerminator());
    // ...and before any of them is read.
    auto Uses = MI.uses();
    for (auto OpIt = Uses.begin(), OpEnd = Uses.end();
         !BarrierNeeded && OpIt != OpEnd; ++OpIt)
      BarrierNeeded =
          OpIt->isReg() && RegsNeedingCSDBBeforeUse[OpIt->getReg()];

    if (BarrierNeeded && !UsesFullSpeculationBarrier)
      Changed |= insertCSDB(MBB, Cur, LastDL);

    Changed |=
        expandSpeculationSafeValue(MBB, Cur, UsesFullSpeculationBarrier);

    Cur = Next;
  }

  // Registers masked in this block but not yet followed by a CSDB.
  if (RegsNeedingCSDBBeforeUse.any() && !UsesFullSpeculationBarrier)
    Changed |= insertCSDB(MBB, End, LastDL);

  return Changed;
}
/// Pass entry point. Only functions with the SpeculativeLoadHardening
/// attribute are processed, in three steps:
///  1. (if -aarch64-slh-loads) insert SpeculationSafeValue pseudos for loads;
///  2. recreate the taint register from SP at the function entry and at each
///     landing pad;
///  3. per block, instrument conditional branches, calls and returns, then
///     lower the SpeculationSafeValue pseudos.
/// \returns true if the function was changed.
bool AArch64SpeculationHardening::runOnMachineFunction(MachineFunction &MF) {
  if (!MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening))
    return false;

  MisspeculatingTaintReg = AArch64::X16;
  MisspeculatingTaintReg32Bit = AArch64::W16;
  TII = MF.getSubtarget().getInstrInfo();
  TRI = MF.getSubtarget().getRegisterInfo();
  RegsNeedingCSDBBeforeUse.resize(TRI->getNumRegs());
  RegsAlreadyMasked.resize(TRI->getNumRegs());
  // If the function's own (non-call) instructions touch the taint register,
  // it cannot carry the taint; use full speculation barriers instead.
  UseControlFlowSpeculationBarrier = functionUsesHardeningRegister(MF);

  bool Modified = false;

  // Step 1: automatic insertion of SpeculationSafeValue pseudos.
  if (HardenLoads) {
    LLVM_DEBUG(
        dbgs() << "***** AArch64SpeculationHardening - automatic insertion of "
                  "SpeculationSafeValue intrinsics *****\n");
    for (auto &MBB : MF)
      Modified |= slhLoads(MBB);
  }

  // Step 2: instrument the function entry and every landing pad.
  LLVM_DEBUG(
      dbgs()
      << "***** AArch64SpeculationHardening - track control flow *****\n");

  SmallVector<MachineBasicBlock *, 2> EntryBlocks;
  EntryBlocks.push_back(&MF.front());
  for (const LandingPadInfo &LPI : MF.getLandingPads())
    EntryBlocks.push_back(LPI.LandingPadBlock);
  // Insert after any PHIs, labels and debug instructions at the block start.
  for (auto Entry : EntryBlocks)
    insertSPToRegTaintPropagation(
        *Entry, Entry->SkipPHIsLabelsAndDebug(Entry->begin()));
  // insertSPToRegTaintPropagation always emits instructions (either
  // CMP/CSETM or a full barrier), and EntryBlocks is never empty. The
  // function is therefore modified even if no other step changes it, e.g.
  // a single-block infinite loop with no loads, calls or returns.
  Modified = true;

  // Step 3: per-block control-flow tracking and pseudo lowering.
  for (auto &MBB : MF) {
    bool UsesFullSpeculationBarrier = false;
    Modified |= instrumentControlFlow(MBB, UsesFullSpeculationBarrier);
    Modified |=
        lowerSpeculationSafeValuePseudos(MBB, UsesFullSpeculationBarrier);
  }

  return Modified;
}
/// Create a new instance of the AArch64 speculation hardening pass.
/// Ownership of the returned object passes to the caller (normally the pass
/// manager).
FunctionPass *llvm::createAArch64SpeculationHardeningPass() {
  AArch64SpeculationHardening *P = new AArch64SpeculationHardening();
  return P;
}