#include "llvm/ADT/DirectedGraph.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "gtest/gtest.h"
namespace llvm {
// The DirectedGraph/DGNode/DGEdge bases are templated on both concrete test
// types, so both names must be declared before the aliases can be formed.
class DGTestNode;
class DGTestEdge;

// Base-class instantiations used by the test graph, node, and edge below.
using DGTestBase = DirectedGraph<DGTestNode, DGTestEdge>;
using DGTestNodeBase = DGNode<DGTestNode, DGTestEdge>;
using DGTestEdgeBase = DGEdge<DGTestNode, DGTestEdge>;
/// Concrete node type for the DirectedGraph unit tests. It declares nothing
/// beyond a default constructor; edge handling comes from DGTestNodeBase.
class DGTestNode : public DGTestNodeBase {
public:
// Explicitly defaulted; the tests declare nodes with no arguments
// (e.g. `DGTestNode N1, N2, N3;`).
DGTestNode() = default;
};
/// Concrete edge type for the DirectedGraph unit tests: an edge that targets
/// a DGTestNode and carries no extra payload.
class DGTestEdge : public DGTestEdgeBase {
public:
  // Every edge must name its target node, so default construction is
  // disallowed.
  DGTestEdge() = delete;

  /// Create an edge whose target is \p N. Marked explicit so a DGTestNode is
  /// never silently converted into a temporary edge.
  explicit DGTestEdge(DGTestNode &N) : DGTestEdgeBase(N) {}
};
/// Concrete graph type for the DirectedGraph unit tests. It adds no state of
/// its own, so the compiler-generated copy/move/destructor are correct (Rule
/// of Zero); an empty user-provided destructor would only suppress the
/// implicit move operations.
class DGTestGraph : public DGTestBase {
public:
  DGTestGraph() = default;
};

/// Output list used with DirectedGraph::findIncomingEdgesToNode in the tests.
using EdgeListTy = SmallVector<DGTestEdge *, 2>;
/// GraphTraits for a single test node, so generic graph algorithms can walk
/// its successors. Edge iteration goes straight over the node's outgoing
/// edges; child iteration maps each of those edges to its target node.
template <> struct GraphTraits<DGTestNode *> {
  using NodeRef = DGTestNode *;
  using ChildEdgeIteratorType = DGTestNode::iterator;

  /// Translate an outgoing edge into the node it points at.
  static DGTestNode *DGTestGetTargetNode(DGTestEdgeBase *P) {
    return &P->getTargetNode();
  }

  using ChildIteratorType =
      mapped_iterator<ChildEdgeIteratorType, decltype(&DGTestGetTargetNode)>;

  static NodeRef getEntryNode(NodeRef N) { return N; }

  static ChildEdgeIteratorType child_edge_begin(NodeRef N) {
    return N->begin();
  }
  static ChildEdgeIteratorType child_edge_end(NodeRef N) { return N->end(); }

  // Children are the edge range above, viewed through DGTestGetTargetNode.
  static ChildIteratorType child_begin(NodeRef N) {
    return ChildIteratorType(child_edge_begin(N), &DGTestGetTargetNode);
  }
  static ChildIteratorType child_end(NodeRef N) {
    return ChildIteratorType(child_edge_end(N), &DGTestGetTargetNode);
  }
};
/// GraphTraits for the whole test graph. Per-node traits are inherited from
/// GraphTraits<DGTestNode *>; node iteration walks the graph's node list.
template <>
struct GraphTraits<DGTestGraph *> : public GraphTraits<DGTestNode *> {
  using nodes_iterator = DGTestGraph::iterator;

  static nodes_iterator nodes_begin(DGTestGraph *DG) { return DG->begin(); }
  static nodes_iterator nodes_end(DGTestGraph *DG) { return DG->end(); }

  /// The entry node is the first node in the graph's node list. The first
  /// element is dereferenced unconditionally, so DG must not be empty.
  static NodeRef getEntryNode(DGTestGraph *DG) {
    nodes_iterator First = nodes_begin(DG);
    return *First;
  }
};
TEST(DirectedGraphTest, AddAndConnectNodes) {
  DGTestGraph DG;
  DGTestNode N1, N2, N3;
  DGTestEdge E1(N1), E2(N2), E3(N3);

  // Each node can be added exactly once.
  EXPECT_TRUE(DG.addNode(N1));
  EXPECT_TRUE(DG.addNode(N2));
  EXPECT_TRUE(DG.addNode(N3));
  EXPECT_FALSE(DG.addNode(N1));

  // Build the cycle N1 -> N2 -> N3 -> N1; connecting again with an edge that
  // is already present reports failure.
  EXPECT_TRUE(DG.connect(N1, N2, E2));
  EXPECT_TRUE(DG.connect(N2, N3, E3));
  EXPECT_TRUE(DG.connect(N3, N1, E1));
  EXPECT_FALSE(DG.connect(N3, N1, E1));

  EXPECT_EQ(DG.size(), 3u);
  EXPECT_NE(DG.findNode(N3), DG.end());
  DGTestNode N4;
  EXPECT_EQ(DG.findNode(N4), DG.end());

  // N1's only incoming edge is E1 (from N3). ASSERT (not EXPECT) on the
  // count so a failure stops the test instead of indexing an empty list.
  EdgeListTy EL;
  EXPECT_TRUE(DG.findIncomingEdgesToNode(N1, EL));
  ASSERT_EQ(EL.size(), 1u);
  EXPECT_EQ(*EL[0], E1);
}
TEST(DirectedGraphTest, AddRemoveEdge) {
  DGTestGraph DG;
  DGTestNode N1, N2, N3;
  DGTestEdge E1(N1), E2(N2), E3(N3);
  DG.addNode(N1);
  DG.addNode(N2);
  DG.addNode(N3);

  // Cycle N1 -> N2 -> N3 -> N1; each edge targets the node it was built with.
  DG.connect(N1, N2, E2);
  DG.connect(N2, N3, E3);
  DG.connect(N3, N1, E1);
  EXPECT_EQ(DG.size(), 3u);
  EXPECT_EQ(E1.getTargetNode(), N1);
  EXPECT_EQ(E2.getTargetNode(), N2);
  EXPECT_EQ(E3.getTargetNode(), N3);

  // Removing N1's only outgoing edge leaves N2 with no incoming edges.
  N1.removeEdge(E2);
  EdgeListTy EL;
  EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, EL));
  EXPECT_TRUE(EL.empty());

  // Re-adding it restores exactly one incoming edge. The list is cleared
  // first because it is reused, and the size is ASSERTed so a failure cannot
  // lead to indexing an empty list.
  N1.addEdge(E2);
  EL.clear();
  EXPECT_TRUE(DG.findIncomingEdgesToNode(N2, EL));
  ASSERT_EQ(EL.size(), 1u);
  EXPECT_EQ(*EL[0], E2);
}
TEST(DirectedGraphTest, hasEdgeTo) {
  DGTestGraph DG;
  DGTestNode N1, N2, N3;
  DGTestEdge E1(N1), E2(N2), E3(N3), E4(N1);
  DG.addNode(N1);
  DG.addNode(N2);
  DG.addNode(N3);

  // Edges: N1 -> N2, N2 -> N3, N3 -> N1, N2 -> N1.
  DG.connect(N1, N2, E2);
  DG.connect(N2, N3, E3);
  DG.connect(N3, N1, E1);
  DG.connect(N2, N1, E4);

  // Every connected pair reports an edge, including both of N2's edges.
  EXPECT_TRUE(N1.hasEdgeTo(N2));
  EXPECT_TRUE(N2.hasEdgeTo(N3));
  EXPECT_TRUE(N2.hasEdgeTo(N1));
  EXPECT_TRUE(N3.hasEdgeTo(N1));

  // Pairs with no edge (including a self edge) must report false, so an
  // implementation that always answers true is caught.
  EXPECT_FALSE(N1.hasEdgeTo(N3));
  EXPECT_FALSE(N3.hasEdgeTo(N2));
  EXPECT_FALSE(N1.hasEdgeTo(N1));
}
TEST(DirectedGraphTest, AddRemoveNode) {
  // Cycle N1 -> N2 -> N3 -> N1.
  DGTestGraph DG;
  DGTestNode N1, N2, N3;
  DGTestEdge E1(N1), E2(N2), E3(N3);
  for (DGTestNode *Node : {&N1, &N2, &N3})
    DG.addNode(*Node);
  DG.connect(N1, N2, E2);
  DG.connect(N2, N3, E3);
  DG.connect(N3, N1, E1);
  EXPECT_EQ(DG.size(), 3u);

  // N1 can be removed once; afterwards it is absent and a second removal
  // fails.
  EXPECT_TRUE(DG.removeNode(N1));
  EXPECT_EQ(DG.findNode(N1), DG.end());
  EXPECT_FALSE(DG.removeNode(N1));
  EXPECT_EQ(DG.size(), 2u);

  // N3's edge to the removed node is gone, and N2 no longer has any incoming
  // edge from a node still in the graph.
  EXPECT_TRUE(N3.getEdges().empty());
  EdgeListTy Incoming;
  EXPECT_FALSE(DG.findIncomingEdgesToNode(N2, Incoming));
  EXPECT_TRUE(Incoming.empty());
}
TEST(DirectedGraphTest, SCC) {
  // A 3-node cycle N1 -> N2 -> N3 -> N1 with a tail N3 -> N4.
  DGTestGraph DG;
  DGTestNode N1, N2, N3, N4;
  DGTestEdge E1(N1), E2(N2), E3(N3), E4(N4);
  for (DGTestNode *Node : {&N1, &N2, &N3, &N4})
    DG.addNode(*Node);
  DG.connect(N1, N2, E2);
  DG.connect(N2, N3, E3);
  DG.connect(N3, N1, E1);
  DG.connect(N3, N4, E4);

  // Store each SCC as a pointer set so membership checks do not depend on
  // the order of nodes within a component.
  using NodeListTy = SmallPtrSet<DGTestNode *, 3>;
  SmallVector<NodeListTy, 4> ListOfSCCs;
  for (auto It = scc_begin(&DG); !It.isAtEnd(); ++It) {
    const auto &Component = *It;
    ListOfSCCs.emplace_back(Component.begin(), Component.end());
  }

  // Expect the tail node {N4} as a singleton and the cycle {N1, N2, N3}.
  EXPECT_EQ(ListOfSCCs.size(), 2u);
  for (const NodeListTy &Component : ListOfSCCs) {
    if (Component.size() <= 1) {
      EXPECT_EQ(Component.size(), 1u);
      EXPECT_EQ(Component.count(&N4), 1u);
    } else {
      EXPECT_EQ(Component.size(), 3u);
      EXPECT_EQ(Component.count(&N1), 1u);
      EXPECT_EQ(Component.count(&N2), 1u);
      EXPECT_EQ(Component.count(&N3), 1u);
      EXPECT_EQ(Component.count(&N4), 0u);
    }
  }
}
}