//! Chess engine in Zig: perft (move-generation node counting) driver and tests.
const std = @import("std");

const Board = @import("Board.zig");

const zobrist = @import("zobrist.zig");

/// Perft ("performance test") harness: walks the move tree of `gs` and
/// counts the positions reached at a given depth.
///
/// `gs` is held by value and mutated in place during the walk. Every move
/// made is undone with `backup`/`restore`, so after a normal return `gs` is
/// back in its starting state. NOTE(review): if an error is returned
/// mid-walk, the current move is not restored.
pub const Perft = struct {
    /// Running total of counted positions; accumulates across calls.
    nodes: usize = 0,
    /// Position being searched; mutated and restored during the walk.
    gs: Board.GameState,

    /// Debug aid (testing incremental Zobrist updates only): when `true`,
    /// recompute the hash from scratch after every move and fail with
    /// `error.HashMismatch` if it differs from the incrementally updated one.
    /// Being comptime-known `false`, the check is compiled out entirely.
    const verify_incremental_hash = false;

    /// Adds to `nodes` the number of positions reachable from `gs` in
    /// exactly `depth` moves (depth 0 counts the current position).
    /// Moves for which `makeMove` returns false are skipped.
    pub fn perftDriver(self: *@This(), depth: usize) !void {
        if (depth == 0) {
            // count nodes
            self.nodes += 1;
            return;
        }

        var ml = try Board.MoveList.init(0);
        try self.gs.generateMoves(&ml);

        for (ml.slice()) |mp| {
            var bk: Board.GameState = undefined;
            self.gs.backup(&bk);

            // makeMove reports rejected moves by returning false
            // (presumably moves leaving the own king in check — confirm).
            if (!self.gs.makeMove(mp.move, .all)) {
                self.gs.restore(&bk);
                continue;
            }

            if (verify_incremental_hash) {
                const orig = self.gs.hash;
                zobrist.updateHash(&self.gs);
                if (orig != self.gs.hash) {
                    std.log.err("{any}", .{mp.move});
                    bk.show();
                    self.gs.show();

                    return error.HashMismatch;
                }
            }

            try self.perftDriver(depth - 1);

            self.gs.restore(&bk);
        }
    }

    /// Perft "divide": for each root move accepted by `makeMove`, prints the
    /// move and the number of positions reached below it, adding those counts
    /// to `nodes`. A `depth` of 0 counts the root position itself, exactly as
    /// `perftDriver(0)` does.
    fn perftTest(self: *@This(), depth: usize) !void {
        // `depth - 1` below would underflow usize (panic in safe builds, UB in
        // ReleaseFast), so the root-only case is handled up front.
        if (depth == 0) return self.perftDriver(0);

        var ml = try Board.MoveList.init(0);
        try self.gs.generateMoves(&ml);

        for (ml.slice()) |mp| {
            var bk: Board.GameState = undefined;
            self.gs.backup(&bk);

            if (!self.gs.makeMove(mp.move, .all)) {
                self.gs.restore(&bk);
                continue;
            }

            // The subtree count is the growth of the running total.
            const nodes_before = self.nodes;

            try self.perftDriver(depth - 1);

            const move_nodes = self.nodes - nodes_before;

            self.gs.restore(&bk);

            mp.move.show();
            std.debug.print("nodes: {d}\n", .{move_nodes});
        }
    }
};

// Perft node counts for "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - -" at
// increasing depths; each depth starts from a fresh counter and a fresh
// copy of the parsed position.
test "Perft" {
    var gs = try Board.GameState.init(
        std.testing.allocator,
        "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - -",
    );
    defer gs.deinit();

    const expectations = [_]struct { depth: usize, expected: usize }{
        .{ .depth = 3, .expected = 2812 },
        .{ .depth = 4, .expected = 43238 },
    };

    for (expectations) |entry| {
        var perft = Perft{ .gs = gs };
        try perft.perftDriver(entry.depth);
        try std.testing.expectEqual(entry.expected, perft.nodes);
    }
}

// Depth-5 perft "divide" on the "Kiwipete" position. It prints per-move
// subtree counts and the elapsed time, then checks the total against the
// published depth-5 count so a move-generation regression fails the test.
test "PerfTest" {
    var gs = try Board.GameState.init(
        std.testing.allocator,
        "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - ",
    );
    defer gs.deinit();

    gs.show();
    var p = Perft{ .gs = gs };

    var timer = try std.time.Timer.start();

    try p.perftTest(5);

    std.debug.print("Checked {d} nodes in {d} ms\n", .{ p.nodes, timer.read() / std.time.ns_per_ms });

    // Published perft(5) for Kiwipete. Checked after printing so the timing
    // and per-move breakdown are still visible when the count is wrong.
    try std.testing.expectEqual(@as(usize, 193690690), p.nodes);
}