#ifndef LLVM_ADT_COALESCINGBITVECTOR_H
#define LLVM_ADT_COALESCINGBITVECTOR_H
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <initializer_list>
namespace llvm {
template <typename IndexT> class CoalescingBitVector {
/// Reject index types that are not unsigned at compile time.
static_assert(std::is_unsigned<IndexT>::value,
"Index must be an unsigned integer.");
/// Shorthand for this instantiation of the class template.
using ThisT = CoalescingBitVector<IndexT>;
/// Backing container. Each entry is an interval of set bits whose bounds are
/// both inclusive (count() adds `1 + stop - start` per entry). The mapped
/// `char` is always written as 0 by insert() and is not read in this class.
using MapT = IntervalMap<IndexT, char>;
/// Iterator over the intervals stored in the backing container.
using UnderlyingIterator = typename MapT::const_iterator;
/// A (first, last) pair of indices describing one inclusive interval.
using IntervalT = std::pair<IndexT, IndexT>;
public:
/// Allocator type of the backing interval map.
using Allocator = typename MapT::Allocator;
/// Construct an empty bitvector whose interval map is created from \p Alloc.
///
/// The address of \p Alloc is stored and later dereferenced by the copy
/// constructor, so \p Alloc must stay alive while this object can be copied.
CoalescingBitVector(Allocator &Alloc)
: Alloc(&Alloc), Intervals(Alloc) {}
/// Copy constructor: reuse the allocator \p Other was given and then insert
/// every interval of \p Other into the freshly created (empty) map.
CoalescingBitVector(const ThisT &Other)
: Alloc(Other.Alloc), Intervals(*Other.Alloc) {
set(Other);
}
/// Replace the contents of this bitvector with a copy of the bits set in
/// \p Other.
///
/// Only the intervals are copied; the allocator this object was constructed
/// with is left untouched.
///
/// \returns a reference to this object.
ThisT &operator=(const ThisT &Other) {
  // Self-assignment has to be a no-op. Without this guard, clear() would
  // also empty \p Other (it is the same object) and the copy below would
  // have nothing left to restore, silently dropping every bit.
  if (this == &Other)
    return *this;
  clear();
  set(Other);
  return *this;
}
/// Move construction and move assignment are explicitly disabled.
CoalescingBitVector(ThisT &&Other) = delete;
ThisT &operator=(ThisT &&Other) = delete;
/// Unset every bit by clearing the backing interval map.
void clear() { Intervals.clear(); }
/// \returns true if the backing interval map holds no intervals, i.e. no bit
/// is set.
bool empty() const { return Intervals.empty(); }
/// Count the set bits.
///
/// Walks every stored interval and adds its length; both interval bounds are
/// inclusive, hence the extra 1 per interval. The running total is an
/// `unsigned`.
///
/// \returns the accumulated number of set bits.
unsigned count() const {
  unsigned NumSetBits = 0;
  auto Cur = Intervals.begin();
  const auto Last = Intervals.end();
  while (Cur != Last) {
    NumSetBits += 1 + Cur.stop() - Cur.start();
    ++Cur;
  }
  return NumSetBits;
}
/// Set the bit at \p Index.
///
/// The bit must not already be set; assert-enabled builds check this with
/// test(). Use test_and_set() when the bit may already be set.
void set(IndexT Index) {
assert(!test(Index) && "Setting already-set bits not supported/efficient, "
"IntervalMap will assert");
// A single bit is stored as the one-element interval [Index, Index].
insert(Index, Index);
}
/// Set every bit that is set in \p Other.
///
/// Each interval of \p Other is inserted as-is, without checking whether any
/// of its bits is already set here. The visible callers (copy constructor and
/// copy assignment) only call this on an empty bitvector.
void set(const ThisT &Other) {
  auto Src = Other.Intervals.begin();
  const auto SrcEnd = Other.Intervals.end();
  while (Src != SrcEnd) {
    insert(Src.start(), Src.stop());
    ++Src;
  }
}
void set(std::initializer_list<IndexT> Indices) {
for (IndexT Index : Indices)
set(Index);
}
/// Check whether the bit at \p Index is set.
///
/// The map lookup may hand back an interval that begins after \p Index (the
/// assertion only guarantees that it does not end before \p Index), so the
/// start bound has to be checked explicitly.
///
/// \returns true if \p Index lies inside a stored interval.
bool test(IndexT Index) const {
  const auto Pos = Intervals.find(Index);
  if (Pos != Intervals.end()) {
    assert(Pos.stop() >= Index && "Interval must end after Index");
    return !(Index < Pos.start());
  }
  return false;
}
/// Set the bit at \p Index; unlike set(IndexT), this tolerates a bit that is
/// already set (it is then left alone).
void test_and_set(IndexT Index) {
  if (test(Index))
    return;
  set(Index);
}
/// Clear the bit at \p Index. Does nothing if the bit is not set.
///
/// The interval containing \p Index is erased and replaced by up to two
/// pieces: [Start, Index - 1] and [Index + 1, Stop]. A piece is only
/// re-inserted when it is non-empty, so clearing a bit at either end of an
/// interval yields one piece and clearing a single-bit interval yields none.
void reset(IndexT Index) {
  auto Pos = Intervals.find(Index);
  if (Pos == Intervals.end())
    return;

  const IndexT First = Pos.start();
  // The interval found begins after Index: the bit was never set.
  if (Index < First)
    return;

  const IndexT Last = Pos.stop();
  assert(Index <= Last && "Wrong interval for index");

  // Remove the whole interval first, then put back what surrounds Index.
  Pos.erase();
  const bool KeepLeft = First < Index;
  const bool KeepRight = Index < Last;
  if (KeepLeft)
    insert(First, Index - 1);
  if (KeepRight)
    insert(Index + 1, Last);
}
/// Set union: after the call, every bit set in \p RHS is also set here.
///
/// Bits present in both operands are computed up front. For each interval of
/// \p RHS, only the sub-ranges not covered by those shared bits are inserted,
/// so no already-set bit is inserted a second time.
void operator|=(const ThisT &RHS) {
  SmallVector<IntervalT, 8> Shared;
  getOverlaps(RHS, Shared);

  auto RhsIt = RHS.Intervals.begin();
  const auto RhsEnd = RHS.Intervals.end();
  while (RhsIt != RhsEnd) {
    // Portions of the current RHS interval that are missing from *this.
    SmallVector<IntervalT, 8> Missing;
    getNonOverlappingParts(RhsIt.start(), RhsIt.stop(), Shared, Missing);
    for (const IntervalT &Part : Missing)
      insert(Part.first, Part.second);
    ++RhsIt;
  }
}
/// Set intersection: keep only the bits that are set in both operands.
///
/// The shared intervals are collected into a local list first; the bitvector
/// is then emptied and rebuilt from that list.
void operator&=(const ThisT &RHS) {
  SmallVector<IntervalT, 8> Common;
  getOverlaps(RHS, Common);

  clear();
  for (auto It = Common.begin(), ItEnd = Common.end(); It != ItEnd; ++It)
    insert(It->first, It->second);
}
void intersectWithComplement(const ThisT &Other) {
SmallVector<IntervalT, 8> Overlaps;
if (!getOverlaps(Other, Overlaps)) {
return;
}
for (IntervalT Overlap : Overlaps) {
IndexT OlapStart, OlapStop;
std::tie(OlapStart, OlapStop) = Overlap;
auto It = Intervals.find(OlapStart);
IndexT CurrStart = It.start();
IndexT CurrStop = It.stop();
assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
"Expected some intersection!");
It.erase();
if (CurrStart < OlapStart)
insert(CurrStart, OlapStart - 1);
if (OlapStop < CurrStop)
insert(OlapStop + 1, CurrStop);
}
}
/// Compare two bitvectors for equality.
///
/// Both interval maps are walked in lockstep and only the interval bounds
/// are compared; the values stored in the maps are ignored.
///
/// \returns true if both maps contain the same sequence of intervals.
bool operator==(const ThisT &RHS) const {
  auto Lhs = Intervals.begin();
  auto Rhs = RHS.Intervals.begin();
  const auto LhsEnd = Intervals.end();
  const auto RhsEnd = RHS.Intervals.end();

  while (Lhs != LhsEnd && Rhs != RhsEnd) {
    const bool SameInterval =
        Lhs.start() == Rhs.start() && Lhs.stop() == Rhs.stop();
    if (!SameInterval)
      return false;
    ++Lhs;
    ++Rhs;
  }
  // Equal only if both walks ran out of intervals at the same time.
  return Lhs == LhsEnd && Rhs == RhsEnd;
}
/// Negation of operator==.
bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
/// Forward iterator that yields the index of each set bit, walking the
/// stored intervals one after another.
///
/// Positioned iterators are created by the enclosing bitvector through the
/// private constructor; a default-constructed iterator is the end iterator.
class const_iterator {
  friend class CoalescingBitVector;

public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = IndexT;
  using difference_type = std::ptrdiff_t;
  using pointer = value_type *;
  using reference = value_type &;

private:
  /// Sentinel offset that setToEnd() stores to mark the end iterator.
  static constexpr unsigned kIteratorAtTheEndOffset = ~0u;

  /// Current position in the underlying interval map.
  UnderlyingIterator MapIterator;

  /// Distance of the current bit from the start of the current interval.
  unsigned OffsetIntoMapIterator = 0;

  /// Bounds of the interval MapIterator points at, refreshed by resetCache().
  /// NOTE(review): presumably cached so that dereference, increment and
  /// comparison need not query MapIterator each time -- confirm.
  IndexT CachedStart = IndexT();
  IndexT CachedStop = IndexT();

  /// Put the iterator into the end state: sentinel offset, zeroed bounds.
  void setToEnd() {
    OffsetIntoMapIterator = kIteratorAtTheEndOffset;
    CachedStart = IndexT();
    CachedStop = IndexT();
  }

  /// Re-synchronise the cached state after MapIterator has changed: point at
  /// the first bit of the new interval, or become the end iterator when
  /// MapIterator is no longer valid.
  void resetCache() {
    if (!MapIterator.valid()) {
      setToEnd();
      return;
    }
    OffsetIntoMapIterator = 0;
    CachedStart = MapIterator.start();
    CachedStop = MapIterator.stop();
  }

  /// Move to \p Index inside the current interval. \p Index must not lie
  /// beyond the end of the current interval (asserted). If \p Index lies
  /// before the interval start, the position is left unchanged.
  void advanceTo(IndexT Index) {
    assert(Index <= CachedStop && "Cannot advance to OOB index");
    if (!(Index < CachedStart))
      OffsetIntoMapIterator = Index - CachedStart;
  }

  /// Create an iterator at the first bit of the interval \p MapIt points at
  /// (or an end iterator if \p MapIt is not valid).
  const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
    resetCache();
  }

public:
  /// Create an end iterator.
  const_iterator() { setToEnd(); }

  /// Two iterators are equal when offset and cached bounds all match.
  /// MapIterator itself is deliberately left out of the comparison.
  bool operator==(const const_iterator &RHS) const {
    return OffsetIntoMapIterator == RHS.OffsetIntoMapIterator &&
           CachedStart == RHS.CachedStart && CachedStop == RHS.CachedStop;
  }

  bool operator!=(const const_iterator &RHS) const {
    return !operator==(RHS);
  }

  /// \returns the index of the bit the iterator currently points at.
  IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }

  /// Pre-increment: step to the next bit of the current interval, or to the
  /// first bit of the next interval once the current one is exhausted.
  const_iterator &operator++() {
    const bool MoreInInterval =
        CachedStart + OffsetIntoMapIterator < CachedStop;
    if (MoreInInterval) {
      ++OffsetIntoMapIterator;
      return *this;
    }
    ++MapIterator;
    resetCache();
    return *this;
  }

  /// Post-increment: advance and return the previous position.
  const_iterator operator++(int) {
    const_iterator Previous = *this;
    operator++();
    return Previous;
  }

  /// Advance to the first set bit at or after \p Index, or to the end state
  /// if there is none. Intervals are skipped one at a time, so the cost grows
  /// with the number of intervals between the current position and \p Index.
  /// An iterator that is already past \p Index is not moved backwards.
  void advanceToLowerBound(IndexT Index) {
    while (OffsetIntoMapIterator != kIteratorAtTheEndOffset) {
      if (!(Index > CachedStop)) {
        // The current interval ends at or after Index: settle inside it.
        advanceTo(Index);
        return;
      }
      ++MapIterator;
      resetCache();
    }
  }
};
/// \returns an iterator at the first bit of the first stored interval (an
/// end iterator if the map's begin iterator is not valid).
const_iterator begin() const { return const_iterator(Intervals.begin()); }
/// \returns the end iterator (a default-constructed const_iterator).
const_iterator end() const { return const_iterator(); }
/// Look up \p Index and return an iterator positioned according to the
/// interval the underlying map finds for it.
///
/// If that interval contains \p Index, the iterator points exactly at
/// \p Index. If the interval begins after \p Index, the iterator points at
/// the interval's first bit. If the map finds no interval, end() is returned.
/// NOTE(review): this relies on IntervalMap::find returning the first
/// interval that ends at or after the key -- confirm against IntervalMap.h.
const_iterator find(IndexT Index) const {
  const auto MapPos = Intervals.find(Index);
  if (MapPos != Intervals.end()) {
    const_iterator Result(MapPos);
    Result.advanceTo(Index);
    return Result;
  }
  return end();
}
/// Return the set bits that fall inside the half-open range
/// [\p Start, \p End).
///
/// \param Start First index of the range (inclusive).
/// \param End One past the last index of the range; must be greater than
///        \p Start (asserted).
/// \returns an iterator range over the matching bits; the range is
///          {end(), end()} when no set bit lies inside [Start, End).
iterator_range<const_iterator> half_open_range(IndexT Start,
                                               IndexT End) const {
  assert(Start < End && "Not a valid range");

  const_iterator First = find(Start);
  const bool NothingInRange = First == end() || *First >= End;
  if (NothingInRange)
    return {end(), end()};

  // Start from the first bit and move forward to the first bit at or after
  // End (or to the end state), which becomes the range's end.
  const_iterator PastLast = First;
  PastLast.advanceToLowerBound(End);
  return {First, PastLast};
}
/// Write a textual form of the bitvector to \p OS.
///
/// The output is `{`, followed by one `[start, stop]` group per stored
/// interval (just `[start]` for a single-bit interval), followed by `}`.
/// Nothing is written between consecutive groups.
void print(raw_ostream &OS) const {
  OS << "{";
  auto Cur = Intervals.begin();
  const auto Last = Intervals.end();
  while (Cur != Last) {
    const IndexT Lo = Cur.start();
    const IndexT Hi = Cur.stop();
    OS << "[" << Lo;
    if (Lo != Hi)
      OS << ", " << Hi;
    OS << "]";
    ++Cur;
  }
  OS << "}";
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the bitvector to the debug stream, with a newline before and after.
LLVM_DUMP_METHOD void dump() const {
  raw_ostream &OS = dbgs();
  OS << "\n";
  print(OS);
  OS << "\n";
}
#endif
private:
/// Record the interval [\p Start, \p End] (both bounds inclusive) in the
/// backing map. The mapped value is a placeholder and is always 0.
void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
bool getOverlaps(const ThisT &Other,
SmallVectorImpl<IntervalT> &Overlaps) const {
for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
I.valid(); ++I)
Overlaps.emplace_back(I.start(), I.stop());
assert(llvm::is_sorted(Overlaps,
[](IntervalT LHS, IntervalT RHS) {
return LHS.second < RHS.first;
}) &&
"Overlaps must be sorted");
return !Overlaps.empty();
}
/// Compute the parts of the inclusive range [\p Start, \p Stop] that are not
/// covered by any interval in \p Overlaps.
///
/// \param Start First index of the range (inclusive).
/// \param Stop Last index of the range (inclusive).
/// \param Overlaps Intervals to subtract from the range. They are scanned
///        in list order and are expected to be sorted and disjoint, as
///        getOverlaps() asserts for the lists it builds.
/// \param NonOverlappingParts Output list; each uncovered piece is appended
///        as a (start, stop) pair. Existing entries are not removed.
void getNonOverlappingParts(IndexT Start, IndexT Stop,
                            const SmallVectorImpl<IntervalT> &Overlaps,
                            SmallVectorImpl<IntervalT> &NonOverlappingParts) {
  IndexT NextUncoveredBit = Start;
  for (IntervalT Overlap : Overlaps) {
    IndexT OlapStart, OlapStop;
    std::tie(OlapStart, OlapStop) = Overlap;

    // [Start, Stop] and [OlapStart, OlapStop] intersect iff each of them
    // begins no later than the other one ends.
    bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
    if (!DoesOverlap)
      continue;

    // Emit the gap [NextUncoveredBit, OlapStart - 1] in front of this
    // overlap, if there is one.
    if (NextUncoveredBit < OlapStart)
      NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);

    // This overlap reaches the end of the range, so nothing after it can be
    // uncovered. Stop *before* computing OlapStop + 1: when OlapStop is the
    // largest IndexT value the increment wraps around to 0, which would make
    // the whole range look uncovered again.
    if (OlapStop >= Stop)
      return;
    NextUncoveredBit = OlapStop + 1;
  }
  // Whatever remains after the last overlap is uncovered.
  if (NextUncoveredBit <= Stop)
    NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
}
/// Address of the allocator passed at construction; the copy constructor
/// dereferences it to build the copy's interval map. Not owned: nothing in
/// this class frees it.
Allocator *Alloc;
/// The set bits, stored as intervals with inclusive bounds.
MapT Intervals;
};
}
#endif