//! Chess engine in Zig: position evaluation with a Stockfish NNUE network through the C `nnue.h` bindings.
const std = @import("std");

const Chess = @import("Chess.zig");
const Board = @import("Board.zig");

const nnue = @cImport({
    @cInclude("nnue.h");
});

// Stockfish NNUE network file handed to `nnue.nnue_init` by `init` and the tests.
// Networks are published at https://tests.stockfishchess.org/nns
const NNUE_FILE = "nn-62ef826d1a6d.nnue";

// convert our piece codes to Stockfish NNUE codes
// Indexed by `@intFromEnum(Chess.PE)` (see `evaluate`); the value is the piece
// code the NNUE C API expects. Code 0 is never produced here — `evaluate`
// writes it separately as the end-of-list terminator.
// NOTE(review): presumably Stockfish's encoding (white king=1 ... white pawn=6,
// black king=7 ... black pawn=12), implying a Chess.PE order of
// P, N, B, R, Q, K, p, n, b, r, q, k — confirm against nnue.h and Chess.zig.
const NNUE_PIECES: [12]c_int = .{ 6, 5, 4, 3, 2, 1, 12, 11, 10, 9, 8, 7 };

// convert our square indices to Stockfish NNUE indices
// Entry i is the `Board.Square` whose integer value is the NNUE index for our
// square i; `evaluate` sends `@intFromEnum(NNUE_SQUARES[square])` to the C API.
// The table mirrors ranks: our index 0 is a8 (the test below checks that our
// a8 maps to .a1 and our f3 maps to .f6).
// NOTE(review): presumably Stockfish numbers squares from a1 = 0 up to h8 = 63,
// which is why a rank mirror is the right conversion — confirm against nnue.h.
const NNUE_SQUARES: [64]Board.Square = .{
    // zig fmt: off
    .a1, .b1, .c1, .d1, .e1, .f1, .g1, .h1,
    .a2, .b2, .c2, .d2, .e2, .f2, .g2, .h2,
    .a3, .b3, .c3, .d3, .e3, .f3, .g3, .h3,
    .a4, .b4, .c4, .d4, .e4, .f4, .g4, .h4,
    .a5, .b5, .c5, .d5, .e5, .f5, .g5, .h5,
    .a6, .b6, .c6, .d6, .e6, .f6, .g6, .h6,
    .a7, .b7, .c7, .d7, .e7, .f7, .g7, .h7,
    .a8, .b8, .c8, .d8, .e8, .f8, .g8, .h8,
};
// zig fmt: on

test "NNUE_SQUARES" {
    // Spot checks: the table mirrors ranks, so our a8 maps to the .a1 entry
    // and our f3 maps to the .f6 entry.
    try std.testing.expectEqual(Board.Square.a1, NNUE_SQUARES[@intFromEnum(Board.Square.a8)]);
    try std.testing.expectEqual(Board.Square.f6, NNUE_SQUARES[@intFromEnum(Board.Square.f3)]);

    // A rank mirror applied twice must be the identity for every square.
    // Using std.testing (not std.debug.assert) reports a failure instead of
    // panicking the test runner.
    for (NNUE_SQUARES, 0..) |mirrored, i| {
        const back = NNUE_SQUARES[@intFromEnum(mirrored)];
        try std.testing.expectEqual(i, @as(usize, @intFromEnum(back)));
    }
}

// Capacity of the piece/square lists handed to `nnue_evaluate`:
// 2*(pawns(8), knights(2), bishops(2), rooks(2), queen, king) = 32 pieces,
// +1 for the ending empty piece that Stockfish nnue expects.
// Promotions only replace pawns, so a legal position never exceeds 32 pieces.
const PIECE_COUNT = 2 * 8 + 2 * 2 + 2 * 2 + 2 * 2 + 2 + 2 + 1;

// NOTE(review): presumably meant to record whether the NNUE network has been
// loaded; check every reader/writer of this flag before relying on it.
var nnue_initialized = false;

// var data: [3]nnue.NNUEdata = .{nnue.NNUEdata{
//     .accumulator = nnue.Accumulator{
//         .accumulation = [_][256]i16{[_]i16{0} ** 256} ** 2,
//         .computedAccumulation = 0,
//     },
//     .dirtyPiece = nnue.DirtyPiece{
//         .dirtyNum = 0,
//         .pc = [_]c_int{0} ** 3,
//         .from = [_]c_int{0} ** 3,
//         .to = [_]c_int{0} ** 3,
//     },
// }} ** 3;

/// Score `gs` with the Stockfish NNUE network.
///
/// Builds the piece and square lists consumed by `nnue.nnue_evaluate`:
/// slot 0 holds the white king (`.K`), slot 1 the black king (`.k`), every
/// other piece follows from slot 2, and a 0 piece code terminates the list.
/// The network's score is then scaled by (100 - halfmoves) / 100, where
/// halfmoves is `gs.fifty` clamped to 0..100, so it shrinks linearly toward 0
/// as the fifty-move counter grows.
///
/// Assumes the network is already loaded (via `init` or `nnue.nnue_init`) and
/// that `gs` has one king per side; otherwise slots 0/1 stay undefined.
/// NOTE(review): `gs.side` is passed as the player argument, so the score is
/// presumably from the side to move's perspective — confirm against nnue.h.
pub fn evaluate(gs: *const Board.GameState) isize {
    comptime std.debug.assert(PIECE_COUNT == 33);

    var sf_pieces: [PIECE_COUNT]c_int = undefined;
    var sf_squares: [PIECE_COUNT]c_int = undefined;
    var index: usize = 2; // 0 and 1 are reserved for the two kings

    for (gs.bitboards, 0..) |bb, idx| {
        var bitboard: Board.BoardType = @bitCast(bb);
        // Only convert `idx` to a piece once this board is known to hold one.
        if (bitboard == 0) continue;

        const piece: Chess.PE = @enumFromInt(idx);
        const nnue_piece = NNUE_PIECES[@intFromEnum(piece)];

        // Visit every set bit, clearing the lowest one after each iteration.
        while (bitboard != 0) : (bitboard &= bitboard - 1) {
            const nnue_square: c_int = @intFromEnum(NNUE_SQUARES[@ctz(bitboard)]);
            switch (piece) {
                .K => {
                    sf_pieces[0] = nnue_piece;
                    sf_squares[0] = nnue_square;
                },
                .k => {
                    sf_pieces[1] = nnue_piece;
                    sf_squares[1] = nnue_square;
                },
                else => {
                    // Keep room for the terminating entry written below.
                    std.debug.assert(index < PIECE_COUNT - 1);
                    sf_pieces[index] = nnue_piece;
                    sf_squares[index] = nnue_square;
                    index += 1;
                },
            }
        }
    }

    // zero shows Stockfish nnue that there are no more pieces.
    sf_pieces[index] = 0;
    sf_squares[index] = 0;

    // TODO: switch to nnue.nnue_evaluate_incremental once per-ply NNUEdata is
    // maintained (see the commented-out `data` array above).
    const raw: isize = @intCast(nnue.nnue_evaluate(
        @intFromEnum(gs.side),
        &sf_pieces,
        &sf_squares,
    ));

    // Multiply before dividing: `raw * ((100 - fifty) / 100)` truncates the
    // factor to 0 for any non-zero counter. Clamping keeps the factor in
    // 0..100 even if the counter exceeds 100 half-moves.
    const halfmoves: isize = std.math.clamp(@as(isize, gs.fifty), 0, 100);
    return @divTrunc(raw * (100 - halfmoves), 100);
}

/// Load the NNUE network from `NNUE_FILE` and sanity-check it.
///
/// After the first successful call, later calls return immediately without
/// reloading the file. Returns `error.FailedToInitNNUE` when the loaded
/// network does not score the standard start position as 57.
/// NOTE(review): a mismatch presumably means the file was missing or a
/// different network was loaded — confirm nnue_init's failure behaviour.
pub fn init() !void {
    if (nnue_initialized) return;

    nnue.nnue_init(NNUE_FILE);

    // Reference score of the start position for the network in NNUE_FILE.
    const start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
    const expected_start_score: c_int = 57;
    if (nnue.nnue_evaluate_fen(start_fen) != expected_start_score) {
        return error.FailedToInitNNUE;
    }

    // Only mark success after the check, so a failed load can be retried.
    nnue_initialized = true;
}

test "nnue.evaluate" {
    nnue.nnue_init(NNUE_FILE);

    // For each position, the piece-list evaluation must agree with Stockfish
    // evaluating the same FEN directly.
    const fens = .{
        "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
        "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq -",
        "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - -",
        "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",
    };
    inline for (fens) |fen| {
        var gs = try Board.GameState.init(std.testing.allocator, fen);
        defer gs.deinit();

        const expected: isize = nnue.nnue_evaluate_fen(fen);
        try std.testing.expectEqual(expected, evaluate(&gs));
    }
}