#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
#include "clang/Analysis/SelectorExtras.h"
#include "clang/StaticAnalyzer/Core/Checker.h"
#include "clang/StaticAnalyzer/Core/CheckerManager.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
using namespace clang;
using namespace ento;
// Implication "key is non-null => value is non-null". Populated for
// NSDictionary lookups as (return value -> key argument) and consumed by the
// checker's evalAssume once the key symbol is constrained non-null.
REGISTER_MAP_WITH_PROGRAMSTATE(NonNullImplicationMap, SymbolRef, SymbolRef)
// Implication "key is null => value is null". Populated for NSDictionary
// lookups as (key argument -> return value) and consumed by the checker's
// evalAssume once the key symbol is constrained null.
REGISTER_MAP_WITH_PROGRAMSTATE(NullImplicationMap, SymbolRef, SymbolRef)
namespace {

/// Trusts `_Nonnull` return-type annotations on declarations in system headers
/// and models a few Foundation dictionary invariants:
///
/// * After a call that isNonNullPtr() trusts, the returned pointer is assumed
///   non-null.
/// * After `setObject:forKey:` / `setObject:forKeyedSubscript:` is sent to an
///   `NSMutableDictionary` (or a subclass), the key argument is assumed
///   non-null.
/// * After `objectForKey:` / `objectForKeyedSubscript:` is sent to an
///   `NSDictionary` (or a subclass), two implications are recorded in the
///   program state: a non-null result implies a non-null key, and a null key
///   implies a null result. They are applied from evalAssume() once the
///   antecedent symbol becomes constrained. Entries are dropped from
///   checkDeadSymbols() when either symbol dies.
class TrustNonnullChecker : public Checker<check::PostCall,
                                           check::PostObjCMessage,
                                           check::DeadSymbols,
                                           eval::Assume> {
  /// Conditions whose symbol complexity exceeds this value are ignored by
  /// evalAssume().
  static unsigned constexpr ComplexityThreshold = 10;

  Selector ObjectForKeyedSubscriptSel;
  Selector ObjectForKeySel;
  Selector SetObjectForKeyedSubscriptSel;
  Selector SetObjectForKeySel;

public:
  explicit TrustNonnullChecker(ASTContext &Ctx)
      : ObjectForKeyedSubscriptSel(
            getKeywordSelector(Ctx, "objectForKeyedSubscript")),
        ObjectForKeySel(getKeywordSelector(Ctx, "objectForKey")),
        SetObjectForKeyedSubscriptSel(
            getKeywordSelector(Ctx, "setObject", "forKeyedSubscript")),
        SetObjectForKeySel(getKeywordSelector(Ctx, "setObject", "forKey")) {}

  /// Applies recorded nullability implications whose antecedent is one of the
  /// symbols making up \p Cond.
  ///
  /// \p Assumption is not consulted. Whether an antecedent is null or non-null
  /// is checked against \p State itself (see addImplication()).
  ///
  /// \returns the updated state, or nullptr if applying an implication makes
  /// the state infeasible.
  ProgramStateRef evalAssume(ProgramStateRef State,
                             SVal Cond,
                             bool Assumption) const {
    const SymbolRef CondS = Cond.getAsSymbol();
    if (!CondS || CondS->computeComplexity() > ComplexityThreshold)
      return State;

    for (auto B = CondS->symbol_begin(), E = CondS->symbol_end(); B != E;
         ++B) {
      const SymbolRef Antecedent = *B;
      State = addImplication(Antecedent, State, /*Negated=*/true);
      State = addImplication(Antecedent, State, /*Negated=*/false);
      // An infeasible state stays infeasible; no need to visit more symbols.
      if (!State)
        return nullptr;
    }

    return State;
  }

  /// Assumes that a call's returned pointer is non-null when the callee is
  /// declared in a system header and isNonNullPtr() trusts its annotation.
  void checkPostCall(const CallEvent &Call, CheckerContext &C) const {
    // Annotations are only trusted for declarations in system headers.
    if (!Call.isInSystemHeader())
      return;

    ProgramStateRef State = C.getState();

    if (isNonNullPtr(Call, C))
      if (auto L = Call.getReturnValue().getAs<Loc>())
        State = State->assume(*L, /*Assumption=*/true);

    // NOTE(review): if the assumption above was infeasible, State is null
    // here and is passed to addTransition() as-is. Confirm that
    // addTransition(nullptr) falling back to the current state is intended.
    C.addTransition(State);
  }

  /// Models the Foundation dictionary selectors described in the class
  /// comment.
  void checkPostObjCMessage(const ObjCMethodCall &Msg,
                            CheckerContext &C) const {
    const ObjCInterfaceDecl *ID = Msg.getReceiverInterface();
    if (!ID)
      return;

    ProgramStateRef State = C.getState();
    const Selector Sel = Msg.getSelector();

    // The two selector groups below are disjoint, so at most one branch runs.

    // Storing into a mutable dictionary: the key (argument 1) is non-null.
    if (interfaceHasSuperclass(ID, "NSMutableDictionary") &&
        (Sel == SetObjectForKeyedSubscriptSel || Sel == SetObjectForKeySel)) {
      if (auto L = Msg.getArgSVal(1).getAs<Loc>())
        State = State->assume(*L, /*Assumption=*/true);
    }

    // Dictionary lookup: relate the key (argument 0) and the result.
    if (interfaceHasSuperclass(ID, "NSDictionary") &&
        (Sel == ObjectForKeyedSubscriptSel || Sel == ObjectForKeySel)) {
      SymbolRef ArgS = Msg.getArgSVal(0).getAsSymbol();
      SymbolRef RetS = Msg.getReturnValue().getAsSymbol();
      if (ArgS && RetS) {
        // A non-null result implies a non-null key.
        State = State->set<NonNullImplicationMap>(RetS, ArgS);
        // A null key implies a null result.
        State = State->set<NullImplicationMap>(ArgS, RetS);
      }
    }

    C.addTransition(State);
  }

  /// Drops implication entries whose key or value symbol is no longer live.
  void checkDeadSymbols(SymbolReaper &SymReaper, CheckerContext &C) const {
    ProgramStateRef State = C.getState();
    State = dropDeadFromGDM<NullImplicationMap>(SymReaper, State);
    State = dropDeadFromGDM<NonNullImplicationMap>(SymReaper, State);
    C.addTransition(State);
  }

private:
  /// Removes every entry of the SymbolRef -> SymbolRef map \p MapName whose
  /// key or value is dead according to \p SymReaper.
  template <typename MapName>
  ProgramStateRef dropDeadFromGDM(SymbolReaper &SymReaper,
                                  ProgramStateRef State) const {
    // The map being iterated is a snapshot taken from the original State, so
    // reassigning State inside the loop does not affect the iteration.
    for (const std::pair<SymbolRef, SymbolRef> &P : State->get<MapName>())
      if (!SymReaper.isLive(P.first) || !SymReaper.isLive(P.second))
        State = State->remove<MapName>(P.first);
    return State;
  }

  /// \returns true if the result of \p Call is trusted to be a non-null
  /// pointer. That holds when the call's result type is a `_Nonnull` pointer.
  /// It also holds when the call is an Objective-C message meeting all of
  /// these conditions:
  /// * its method is not declared in a protocol;
  /// * its method declaration returns `_Nonnull`;
  /// * it is either a class message or has a receiver known to be non-null.
  bool isNonNullPtr(const CallEvent &Call, CheckerContext &C) const {
    QualType ExprRetType = Call.getResultType();
    if (!ExprRetType->isAnyPointerType())
      return false;

    if (getNullabilityAnnotation(ExprRetType) == Nullability::Nonnull)
      return true;

    // Beyond this point only Objective-C messages can qualify.
    const auto *MCall = dyn_cast<ObjCMethodCall>(&Call);
    if (!MCall)
      return false;

    // Defensive: never dereference a missing method declaration.
    const ObjCMethodDecl *MD = MCall->getDecl();
    if (!MD)
      return false;

    // Annotations on protocol methods are not trusted.
    if (isa<ObjCProtocolDecl>(MD->getDeclContext()))
      return false;

    if (getNullabilityAnnotation(MD->getReturnType()) != Nullability::Nonnull)
      return false;

    // For class messages, the _Nonnull declaration alone is sufficient.
    if (!MCall->isInstanceMessage())
      return true;

    // For instance messages, also require a receiver known to be non-null
    // (a message sent to nil returns nil).
    SVal Receiver = MCall->getReceiverSVal();
    return C.getState()->isNonNull(Receiver).isConstrainedTrue();
  }

  /// \returns true if \p ID or any of its superclasses is named \p ClassName.
  bool interfaceHasSuperclass(const ObjCInterfaceDecl *ID,
                              StringRef ClassName) const {
    for (const ObjCInterfaceDecl *Cur = ID; Cur; Cur = Cur->getSuperClass())
      if (Cur->getIdentifier()->getName() == ClassName)
        return true;
    return false;
  }

  /// Applies at most one recorded implication whose antecedent is
  /// \p Antecedent.
  ///
  /// With \p Negated == true, NonNullImplicationMap is consulted. If
  /// \p Antecedent is constrained non-null, its consequent is assumed
  /// non-null.
  ///
  /// With \p Negated == false, NullImplicationMap is consulted. If
  /// \p Antecedent is constrained null, its consequent is assumed null.
  ///
  /// When the implication fires, the applied entry and the paired entry in
  /// the other map are removed.
  ///
  /// \returns the updated state, \p InputState unchanged when nothing
  /// applies, or nullptr if \p InputState is null or the assumption is
  /// infeasible.
  ProgramStateRef addImplication(SymbolRef Antecedent,
                                 ProgramStateRef InputState,
                                 bool Negated) const {
    if (!InputState)
      return nullptr;

    const SymbolRef *ConsequentPtr =
        Negated ? InputState->get<NonNullImplicationMap>(Antecedent)
                : InputState->get<NullImplicationMap>(Antecedent);
    if (!ConsequentPtr)
      return InputState;
    // Copy the symbol out of InputState's map before deriving new states.
    const SymbolRef Consequent = *ConsequentPtr;

    SValBuilder &SVB = InputState->getStateManager().getSValBuilder();
    SVal AntecedentV = SVB.makeSymbolVal(Antecedent);

    const bool AntecedentHolds =
        Negated ? InputState->isNonNull(AntecedentV).isConstrainedTrue()
                : InputState->isNull(AntecedentV).isConstrainedTrue();
    if (!AntecedentHolds)
      return InputState;

    // Negated == true assumes the consequent non-null; false assumes it null.
    SVal ConsequentS = SVB.makeSymbolVal(Consequent);
    ProgramStateRef State =
        InputState->assume(ConsequentS.castAs<DefinedSVal>(), Negated);
    if (!State)
      return nullptr;

    // Drop the applied implication. Also drop the paired reverse implication:
    // its antecedent (Consequent) is now constrained the opposite way, so it
    // can never fire.
    if (Negated) {
      State = State->remove<NonNullImplicationMap>(Antecedent);
      State = State->remove<NullImplicationMap>(Consequent);
    } else {
      State = State->remove<NullImplicationMap>(Antecedent);
      State = State->remove<NonNullImplicationMap>(Consequent);
    }
    return State;
  }
};

} // end anonymous namespace
/// Registers TrustNonnullChecker with \p Mgr. The checker is constructed with
/// the manager's ASTContext, which it uses to build its Objective-C selectors.
void ento::registerTrustNonnullChecker(CheckerManager &Mgr) {
  ASTContext &Ctx = Mgr.getASTContext();
  Mgr.registerChecker<TrustNonnullChecker>(Ctx);
}
/// TrustNonnullChecker has no registration preconditions: it is always
/// registered, regardless of the manager's configuration.
bool ento::shouldRegisterTrustNonnullChecker(const CheckerManager & /*mgr*/) {
  return true;
}