#ifndef LLVM_ADT_DIRECTEDGRAPH_H
#define LLVM_ADT_DIRECTEDGRAPH_H
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
namespace llvm {
/// Represent an edge in the directed graph.
/// The edge records the target node it points to; the source node is the
/// node whose edge list contains this edge.
template <class NodeType, class EdgeType> class DGEdge {
public:
  DGEdge() = delete;
  /// Create an edge pointing to the given node \p N.
  explicit DGEdge(NodeType &N) : TargetNode(&N) {}
  /// Create an edge pointing to the same target node as \p E.
  explicit DGEdge(const DGEdge<NodeType, EdgeType> &E)
      : TargetNode(E.TargetNode) {}
  /// Make this edge point at the same target node as \p E.
  /// Only this edge is rebound; no node object is modified. (Storing the
  /// target as a reference made this assignment copy *E's target node into
  /// this edge's old target node*, corrupting the graph.)
  DGEdge<NodeType, EdgeType> &operator=(const DGEdge<NodeType, EdgeType> &E) {
    TargetNode = E.TargetNode;
    return *this;
  }

  /// Static polymorphism: equality is delegated to the derived edge type's
  /// isEqualTo (by default, address identity).
  bool operator==(const DGEdge &E) const {
    return getDerived().isEqualTo(E.getDerived());
  }
  bool operator!=(const DGEdge &E) const { return !operator==(E); }

  /// Retrieve the target node this edge points to.
  const NodeType &getTargetNode() const { return *TargetNode; }
  NodeType &getTargetNode() { return *TargetNode; }

  /// Redirect this edge so that it points at \p N.
  /// The previously targeted node is left untouched. The parameter is kept
  /// as a const reference for source compatibility; callers must not pass a
  /// node that is truly const if they later mutate it via getTargetNode().
  void setTargetNode(const NodeType &N) {
    TargetNode = const_cast<NodeType *>(&N);
  }

protected:
  // Default equality: two edges are equal only if they are the same object.
  bool isEqualTo(const EdgeType &E) const { return this == &E; }

  // CRTP helpers: view 'this' as the derived edge type.
  EdgeType &getDerived() { return *static_cast<EdgeType *>(this); }
  const EdgeType &getDerived() const {
    return *static_cast<const EdgeType *>(this);
  }

  // The node this edge points to. Never null (every constructor binds it to
  // a node). A pointer rather than a reference, so the edge can be
  // re-targeted without assigning through to the node object.
  NodeType *TargetNode;
};
/// Represent a node in the directed graph.
/// A node holds an insertion-ordered, duplicate-free set of pointers to its
/// outgoing edges. The node does not create or destroy the edge objects.
template <class NodeType, class EdgeType> class DGNode {
public:
  using EdgeListTy = SetVector<EdgeType *>;
  using iterator = typename EdgeListTy::iterator;
  using const_iterator = typename EdgeListTy::const_iterator;

  /// Create a node with a single outgoing edge \p E.
  explicit DGNode(EdgeType &E) : Edges() { Edges.insert(&E); }
  DGNode() = default;

  explicit DGNode(const DGNode<NodeType, EdgeType> &N) : Edges(N.Edges) {}
  DGNode(DGNode<NodeType, EdgeType> &&N) : Edges(std::move(N.Edges)) {}

  DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &N) {
    Edges = N.Edges;
    return *this;
  }
  /// Move-assign the outgoing edge list from \p N.
  /// The parameter must be a non-const rvalue reference: with `const &&`,
  /// std::move(N.Edges) produced a const rvalue that bound to the copy
  /// assignment of EdgeListTy, so the "move" silently copied.
  DGNode<NodeType, EdgeType> &operator=(DGNode<NodeType, EdgeType> &&N) {
    Edges = std::move(N.Edges);
    return *this;
  }

  /// Node equality delegates to NodeType::isEqualTo. The default below is
  /// address identity; a derived node type may declare its own isEqualTo,
  /// which name lookup on NodeType will find instead.
  friend bool operator==(const NodeType &M, const NodeType &N) {
    return M.isEqualTo(N);
  }
  friend bool operator!=(const NodeType &M, const NodeType &N) {
    return !(M == N);
  }

  /// Iterate over the outgoing edges (as EdgeType pointers).
  const_iterator begin() const { return Edges.begin(); }
  const_iterator end() const { return Edges.end(); }
  iterator begin() { return Edges.begin(); }
  iterator end() { return Edges.end(); }

  /// First / last outgoing edge. Requires the node to have at least one edge.
  const EdgeType &front() const { return *Edges.front(); }
  EdgeType &front() { return *Edges.front(); }
  const EdgeType &back() const { return *Edges.back(); }
  EdgeType &back() { return *Edges.back(); }

  /// Collect into \p EL every outgoing edge whose target node compares equal
  /// to \p N. \p EL must be empty on entry.
  /// \returns true if at least one such edge was found.
  bool findEdgesTo(const NodeType &N, SmallVectorImpl<EdgeType *> &EL) const {
    assert(EL.empty() && "Expected the list of edges to be empty.");
    for (auto *E : Edges)
      if (E->getTargetNode() == N)
        EL.push_back(E);
    return !EL.empty();
  }

  /// Add \p E to the outgoing edges. \returns true if it was not already
  /// present.
  bool addEdge(EdgeType &E) { return Edges.insert(&E); }

  /// Remove \p E from the outgoing edges (the edge object is not destroyed).
  void removeEdge(EdgeType &E) { Edges.remove(&E); }

  /// \returns true if some outgoing edge targets a node equal to \p N.
  bool hasEdgeTo(const NodeType &N) const {
    return (findEdgeTo(N) != Edges.end());
  }

  /// Access the outgoing edge list directly.
  const EdgeListTy &getEdges() const { return Edges; }
  EdgeListTy &getEdges() { return Edges; }

  /// Drop all outgoing edges (the edge objects are not destroyed).
  void clear() { Edges.clear(); }

protected:
  // Default equality: two nodes are equal only if they are the same object.
  bool isEqualTo(const NodeType &N) const { return this == &N; }

  // CRTP helpers: view 'this' as the derived node type.
  NodeType &getDerived() { return *static_cast<NodeType *>(this); }
  const NodeType &getDerived() const {
    return *static_cast<const NodeType *>(this);
  }

  // Find the first outgoing edge whose target equals N, or Edges.end().
  const_iterator findEdgeTo(const NodeType &N) const {
    return llvm::find_if(
        Edges, [&N](const EdgeType *E) { return E->getTargetNode() == N; });
  }

  // Outgoing edges, in insertion order, without duplicates.
  EdgeListTy Edges;
};
/// Directed graph.
///
/// The graph is a list of pointers to nodes; each node holds its own
/// outgoing edges (see DGNode). This class never allocates or frees nodes or
/// edges: the caller owns them and must keep them alive while in the graph.
template <class NodeType, class EdgeType> class DirectedGraph {
protected:
  using NodeListTy = SmallVector<NodeType *, 10>;
  using EdgeListTy = SmallVector<EdgeType *, 10>;

public:
  using iterator = typename NodeListTy::iterator;
  using const_iterator = typename NodeListTy::const_iterator;
  using DGraphType = DirectedGraph<NodeType, EdgeType>;

  DirectedGraph() = default;
  /// Create a graph containing the single node \p N.
  explicit DirectedGraph(NodeType &N) : Nodes() { addNode(N); }
  DirectedGraph(const DGraphType &G) : Nodes(G.Nodes) {}
  DirectedGraph(DGraphType &&RHS) : Nodes(std::move(RHS.Nodes)) {}

  DGraphType &operator=(const DGraphType &G) {
    Nodes = G.Nodes;
    return *this;
  }
  /// Move-assign the node list from \p G.
  /// The parameter must be a non-const rvalue reference: with `const &&`,
  /// std::move(G.Nodes) produced a const rvalue that bound to the copy
  /// assignment of NodeListTy, so the "move" silently copied.
  DGraphType &operator=(DGraphType &&G) {
    Nodes = std::move(G.Nodes);
    return *this;
  }

  /// Iterate over the nodes (as NodeType pointers) in insertion order.
  const_iterator begin() const { return Nodes.begin(); }
  const_iterator end() const { return Nodes.end(); }
  iterator begin() { return Nodes.begin(); }
  iterator end() { return Nodes.end(); }

  /// First / last node. Requires the graph to be non-empty.
  const NodeType &front() const { return *Nodes.front(); }
  NodeType &front() { return *Nodes.front(); }
  const NodeType &back() const { return *Nodes.back(); }
  NodeType &back() { return *Nodes.back(); }

  /// Number of nodes in the graph.
  size_t size() const { return Nodes.size(); }

  /// Find the first node that compares equal to \p N, or end() if absent.
  const_iterator findNode(const NodeType &N) const {
    return llvm::find_if(Nodes,
                         [&N](const NodeType *Node) { return *Node == N; });
  }
  iterator findNode(const NodeType &N) {
    return const_cast<iterator>(
        static_cast<const DGraphType &>(*this).findNode(N));
  }

  /// Add \p N to the graph unless an equal node is already present.
  /// \returns true if the node was added.
  bool addNode(NodeType &N) {
    if (findNode(N) != Nodes.end())
      return false;
    Nodes.push_back(&N);
    return true;
  }

  /// Collect into \p EL all edges from *other* nodes that target \p N.
  /// Edges from \p N to itself are not included because \p N is skipped.
  /// \p EL must be empty on entry. \returns true if any edge was found.
  bool findIncomingEdgesToNode(const NodeType &N,
                               SmallVectorImpl<EdgeType *> &EL) const {
    assert(EL.empty() && "Expected the list of edges to be empty.");
    EdgeListTy TempList;
    for (auto *Node : Nodes) {
      if (*Node == N)
        continue;
      Node->findEdgesTo(N, TempList);
      llvm::append_range(EL, TempList);
      // findEdgesTo requires an empty output list on every call.
      TempList.clear();
    }
    return !EL.empty();
  }

  /// Remove \p N from the graph. Every edge in another node that targets
  /// \p N is detached from that node, and \p N's own outgoing edges are
  /// cleared. No node or edge object is destroyed.
  /// \returns false if \p N was not in the graph.
  bool removeNode(NodeType &N) {
    iterator IT = findNode(N);
    if (IT == Nodes.end())
      return false;
    EdgeListTy EL;
    for (auto *Node : Nodes) {
      if (*Node == N)
        continue;
      Node->findEdgesTo(N, EL);
      for (auto *E : EL)
        Node->removeEdge(*E);
      // findEdgesTo requires an empty output list on every call.
      EL.clear();
    }
    N.clear();
    Nodes.erase(IT);
    return true;
  }

  /// Add edge \p E to \p Src's outgoing edges. Both nodes must already be in
  /// the graph and \p E must target \p Dst (checked by assertions only).
  /// \returns true if the edge was not already present on \p Src.
  bool connect(NodeType &Src, NodeType &Dst, EdgeType &E) {
    assert(findNode(Src) != Nodes.end() && "Src node should be present.");
    assert(findNode(Dst) != Nodes.end() && "Dst node should be present.");
    assert((E.getTargetNode() == Dst) &&
           "Target of the given edge does not match Dst.");
    return Src.addEdge(E);
  }

protected:
  // Non-owning pointers to the nodes of the graph, in insertion order.
  NodeListTy Nodes;
};
}
#endif