const std = @import("std");
/// Generic skip list mapping keys of type `K` to values of type `V`.
///
/// Keys are ordered by `lessThan(context, lhs, rhs)`. Two keys are treated as
/// equal when neither orders before the other, so `K` does not need to support
/// `==`. Duplicate keys are rejected: inserting a key that is already present
/// leaves the list (and the stored value) unchanged.
fn SkipList(
    comptime K: anytype,
    comptime V: anytype,
    context: anytype,
    comptime lessThan: fn (@TypeOf(context), lhs: K, rhs: K) bool,
) type {
    return struct {
        const Self = @This();

        /// Random interface backed by `prng`. Because the generator state lives
        /// on the heap, this stays valid however often the list value is copied.
        rnd: std.rand.Random,
        /// Heap-owned generator state referenced by `rnd`; released in `deinit`.
        prng: *std.rand.DefaultPrng,
        allocator: std.mem.Allocator,
        /// Probability that a node is promoted one more level.
        P: f32,
        /// Highest level index any node (including `head`) may occupy.
        max_level: usize,
        /// Highest level index currently holding at least one node.
        level: usize = 0,
        /// Key-less sentinel with `max_level + 1` forward pointers.
        head: *Node,

        const Node = struct {
            /// `null` only for the head sentinel.
            key: ?K,
            /// Left undefined for the head sentinel; set by `insert` otherwise.
            value: V,
            /// `forward[i]` is the next node on level `i`.
            forward: []?*Node,

            /// Allocate a node with `level + 1` forward pointers, all null.
            /// `value` is left undefined; the caller must set it before use.
            fn init(allocator: std.mem.Allocator, key: ?K, level: usize) !*Node {
                const n = try allocator.create(Node);
                // Do not leak the node if the forward array cannot be allocated.
                errdefer allocator.destroy(n);
                n.key = key;
                n.forward = try allocator.alloc(?*Node, level + 1);
                std.mem.set(?*Node, n.forward, null);
                return n;
            }

            /// Free the forward array and then the node itself.
            fn deinit(self: *Node, allocator: std.mem.Allocator) void {
                allocator.free(self.forward);
                allocator.destroy(self);
            }
        };

        /// Create an empty skip list.
        ///
        /// `max_level` is the highest level index a node can reach, and `P` is
        /// the per-level promotion probability. The PRNG is seeded from the wall
        /// clock. Call `deinit` to release all memory.
        fn init(allocator: std.mem.Allocator, max_level: usize, P: f32) !Self {
            const prng = try allocator.create(std.rand.DefaultPrng);
            errdefer allocator.destroy(prng);
            // Bit-cast instead of int-cast so a pre-epoch clock cannot trap.
            prng.* = std.rand.DefaultPrng.init(@bitCast(u64, std.time.milliTimestamp()));

            // The head is a key-less sentinel that spans every possible level.
            const head = try Node.init(allocator, null, max_level);

            return Self{
                .rnd = prng.random(),
                .prng = prng,
                .allocator = allocator,
                .P = P,
                .max_level = max_level,
                .head = head,
            };
        }

        /// Free every node (including the head) and the PRNG state.
        fn deinit(self: Self) void {
            var current: ?*Node = self.head;
            while (current) |node| {
                // Read the successor before the node is freed.
                current = node.forward[0];
                node.deinit(self.allocator);
            }
            self.allocator.destroy(self.prng);
        }

        /// Keys are equal when neither orders before the other under `lessThan`.
        fn keyEql(a: K, b: K) bool {
            return !lessThan(context, a, b) and !lessThan(context, b, a);
        }

        /// Descend from the highest used level down to level 0. On each level,
        /// stop at the last node whose key orders before `key`. When `update` is
        /// non-null, `update[i]` receives that node for every `i <= self.level`.
        /// Returns the level-0 predecessor, which may be `head`.
        fn findPredecessors(self: Self, key: K, update: ?[]?*Node) *Node {
            var current: *Node = self.head;
            var l: usize = self.level + 1;
            while (l > 0) : (l -= 1) {
                while (current.forward[l - 1]) |next| : (current = next) {
                    if (!lessThan(context, next.key.?, key)) break;
                }
                if (update) |u| u[l - 1] = current;
            }
            return current;
        }

        /// Return the node holding `search_key`, or null if the key is absent.
        fn search(self: Self, search_key: K) ?*Node {
            // The successor may be null when the key exceeds every stored key.
            const candidate = self.findPredecessors(search_key, null).forward[0] orelse return null;
            return if (keyEql(candidate.key.?, search_key)) candidate else null;
        }

        /// Draw a level in `[0, max_level]`. Each additional level is taken with
        /// probability `P`.
        fn randomLevel(self: Self) usize {
            var level: usize = 0;
            while (self.rnd.float(f32) < self.P and level < self.max_level) level += 1;
            return level;
        }

        /// Insert `insert_key` with `insert_value`. If the key is already present,
        /// nothing changes. Only allocation failures are returned as errors.
        fn insert(self: *Self, insert_key: K, insert_value: V) !void {
            // update[i] is the rightmost node on level i that precedes the new key.
            const update = try self.allocator.alloc(?*Node, self.max_level + 1);
            defer self.allocator.free(update);
            std.mem.set(?*Node, update, null);

            const pred = self.findPredecessors(insert_key, update);
            if (pred.forward[0]) |next| {
                if (keyEql(next.key.?, insert_key)) return;
            }

            const rlevel = self.randomLevel();
            // Allocate before touching any links so a failure leaves the list intact.
            const n = try Node.init(self.allocator, insert_key, rlevel);
            n.value = insert_value;

            // On levels above the current top, the only predecessor is the head.
            if (rlevel > self.level) {
                var level: usize = self.level + 1;
                while (level <= rlevel) : (level += 1) update[level] = self.head;
                self.level = rlevel;
            }

            // Splice the new node in after its predecessor on each of its levels.
            var i: usize = 0;
            while (i <= rlevel) : (i += 1) {
                n.forward[i] = update[i].?.forward[i];
                update[i].?.forward[i] = n;
            }
        }

        /// Print every level, top to bottom, to stderr.
        fn display(self: Self) void {
            std.debug.print("SkipList structure:\n", .{});
            var level: usize = self.max_level + 1;
            while (level > 0) : (level -= 1) {
                std.debug.print("Level {d}: ", .{level - 1});
                var node = self.head.forward[level - 1];
                while (node) |n| : (node = n.forward[level - 1]) {
                    std.debug.print("{any} ", .{n.key});
                }
                std.debug.print("\n", .{});
            }
        }

        /// Remove `delete_key`. Returns true if it was present and false if not.
        /// Only allocation failures are returned as errors.
        fn remove(self: *Self, delete_key: K) !bool {
            // update[i] is the rightmost node on level i that precedes the key.
            const update = try self.allocator.alloc(?*Node, self.max_level + 1);
            defer self.allocator.free(update);
            std.mem.set(?*Node, update, null);

            const pred = self.findPredecessors(delete_key, update);
            const target = pred.forward[0] orelse return false;
            if (!keyEql(target.key.?, delete_key)) return false;

            // Unlink the target on every level it belongs to. Once a level's
            // predecessor no longer points at it, no higher level does either.
            var level: usize = 0;
            while (level <= self.level) : (level += 1) {
                const prev = update[level].?;
                if (prev.forward[level] != @as(?*Node, target)) break;
                prev.forward[level] = target.forward[level];
            }

            // Drop empty top levels so traversals start at the highest occupied one.
            while (self.level > 0 and self.head.forward[self.level] == null) self.level -= 1;

            target.deinit(self.allocator);
            return true;
        }
    };
}
test "SkipList" {
    const KeyType = usize;
    const SL = SkipList(KeyType, void, {}, comptime std.sort.asc(KeyType));
    var sl = try SL.init(std.testing.allocator, 3, 1 / std.math.e);
    defer sl.deinit();

    // minInt(usize) is 0 again, and 12 appears three times: repeats must be no-ops.
    const inserted = [_]KeyType{ 3, 6, 7, 9, 0, std.math.minInt(KeyType), 12, 12, 12, 19, 17, 26, 21, 25 };
    for (inserted) |k| try sl.insert(k, {});

    // Present keys are found, and the returned node carries the same key.
    const present = [_]KeyType{ 12, 26 };
    for (present) |k| try std.testing.expectEqual(k, sl.search(k).?.key.?);
    // A key that was never inserted yields null.
    try std.testing.expectEqual(@as(?*SL.Node, null), sl.search(13));

    // The first removal succeeds; the second finds nothing.
    try std.testing.expectEqual(true, try sl.remove(12));
    try std.testing.expectEqual(false, try sl.remove(12));
    _ = try sl.remove(std.math.minInt(KeyType));

    sl.display();
}