//! Chess engine in Zig.
const std = @import("std");

const build_options = @import("build_options");
const EVAL_FUNCTION = if (build_options.nnue) nnue.evaluate else eval.evaluate;

const Str = []const u8;

const Chess = @import("Chess.zig");
const BoardType = @import("Board.zig").BoardType;
const BitMove = @import("Board.zig").BitMove;
const MovePrio = @import("Board.zig").MovePrio;
const BitMoveType = @import("Board.zig").BitMoveType;
const GameState = @import("Board.zig").GameState;
const MoveList = @import("Board.zig").MoveList;
const SquareType = @import("Board.zig").SquareType;
const Transposition = @import("Transposition.zig");

const eval = @import("eval.zig");
const nnue = @import("nnue.zig");
const score = @import("score.zig");
const uci = @import("uci.zig");
const zobrist = @import("zobrist.zig");

/// Iterative-deepening depth used when no explicit depth is given; it also
/// sizes the ply-indexed search tables (`Search.MAX_PLY`).
const DEFAULT_DEPTH = 64;
/// Safety margin, in milliseconds, subtracted from the computed time budget.
const TIME_SAFETY = 50;

/// Limits for a single search, mirroring the UCI `go` parameters.
/// Fields left at `maxInt(usize)` impose no limit.
pub const SearchOptions = struct {
    /// Sentinel meaning "not limited".
    const unlimited: usize = std.math.maxInt(usize);

    /// Maximum iterative-deepening depth in plies.
    depth: usize = DEFAULT_DEPTH,
    /// Fixed time for this move in milliseconds.
    movetime: usize = unlimited,
    /// Maximum number of nodes to visit.
    nodes: usize = unlimited,
    /// Moves remaining until the next time control.
    movestogo: usize = 20,
    /// White's remaining clock time in milliseconds.
    wtime: usize = unlimited,
    /// White's increment per move in milliseconds.
    winc: usize = 0,
    /// Black's remaining clock time in milliseconds.
    btime: usize = unlimited,
    /// Black's increment per move in milliseconds.
    binc: usize = 0,
};

pub const Search = struct {
    /// Position being searched; moves are made and unmade on it in place.
    state: *GameState,
    /// Read with an atomic load in `stopSearch`; presumably set by another
    /// thread to abort the search (TODO confirm who resets it).
    stop: bool = false,
    /// Distance in plies from the root of the current search.
    ply: usize = 0,

    // UCI time control
    /// Elapsed-time source for `stopSearch` and `print_info`.
    /// NOTE(review): not started in this file; presumably the caller starts it
    /// before calling `bestMove` — confirm.
    timer: std.time.Timer = undefined,
    /// Time budget for the current search in milliseconds.
    movetime: usize = undefined,
    /// Node budget for the current search.
    node_limit: usize = 0,

    // Transposition Table
    tt: Transposition.Table = undefined,

    /// Nodes visited in the current search (container-level, shared by all instances).
    var nodes: usize = undefined;

    // Time control
    const CHECK_STOP = 0x7ff; // poll the stop conditions every 2048 nodes
    /// Root position, saved so an aborted search can be unwound.
    var start_state: GameState = undefined;
    /// Length of the game history at the root of the search.
    var start_history: usize = undefined;

    // Maximum ply we can reach while searching
    const MAX_PLY = DEFAULT_DEPTH;

    // LMR
    const FULL_DEPTH_MOVES = 4;
    const REDUCTION_LIMIT = 3;

    // NULL move pruning
    const NULL_REDUCTION = 2;

    // Principal Variation (triangular PV table)
    var pv_length: [MAX_PLY]usize = undefined;
    var pv_table: [MAX_PLY][MAX_PLY]BitMove = undefined;
    var pv_follow: bool = false;

    /// Create a searcher for `gs` with a transposition table of
    /// `options.hash_size` MiB. Initializes the NNUE module first when built
    /// with the `nnue` option.
    pub fn init(gs: *GameState, options: uci.EngineOptions) !@This() {
        if (build_options.nnue) {
            try nnue.init();
        }
        std.log.debug("Initializin Search, options: {any}", .{options});
        return Search{
            .state = gs,
            .tt = try Transposition.Table.init(gs.allocator, options.hash_size * (1 << 20)),
        };
    }

    /// Release the transposition table.
    pub fn deinit(self: @This()) void {
        self.tt.deinit();
    }

    /// Run an iterative-deepening search within `options` and print UCI
    /// `info` lines followed by the `bestmove` line. The best move is then
    /// played on `self.state` and recorded in its history.
    pub fn bestMove(self: *@This(), options: SearchOptions) !void {
        if (options.movetime != std.math.maxInt(usize)) {
            self.movetime = options.movetime -| TIME_SAFETY;
        } else {
            switch (self.state.side) {
                .white => self.movetime = @max(
                    options.wtime /
                        @max(options.movestogo -| self.state.history.count(), 5),
                    options.winc -| TIME_SAFETY,
                ),
                .black => self.movetime = @max(
                    options.btime /
                        @max(options.movestogo -| self.state.history.count(), 5),
                    options.binc -| TIME_SAFETY,
                ),
            }
        }
        self.node_limit = options.nodes;

        std.log.debug(
            "bestMove depth: {d} node_limit: {d} movetime {d}",
            .{ options.depth, self.node_limit, self.movetime },
        );

        std.log.debug("Start searching: ply: {d} fifty: {d} history_length: {d}", .{
            self.ply,
            self.state.fifty,
            self.state.history.count(),
        });

        const ALPHA = -eval.INFINITY;
        const BETA = eval.INFINITY;

        var alpha: isize = ALPHA;
        var beta: isize = BETA;

        // clear helper data for search
        {
            score.killer_moves = [_][MAX_PLY]BitMove{
                [_]BitMove{@as(BitMove, @bitCast(@as(BitMoveType, 0)))} ** MAX_PLY,
            } ** 2;
            score.history_moves = [_][64]usize{[_]usize{0} ** 64} ** 12;
            pv_length = [_]usize{0} ** MAX_PLY;
            pv_table = [_][MAX_PLY]BitMove{
                [_]BitMove{@as(BitMove, @bitCast(@as(BitMoveType, 0)))} ** MAX_PLY,
            } ** MAX_PLY;
            nodes = 0;

            self.state.backup(&start_state);
            start_history = self.state.history.count();
        }

        // Depths beyond MAX_PLY cannot search any deeper (negamax stops at the
        // ply guard), and they would overflow the u8 depth stored in the TT.
        const max_depth: usize = @min(options.depth, MAX_PLY);

        // iterative deepening
        var current_depth: usize = 1;
        while (current_depth <= max_depth) : (current_depth += 1) {
            pv_follow = true;

            const scr = self.negamax(alpha, beta, current_depth) catch |err| {
                switch (err) {
                    error.TimeIsUp => {
                        self.tearDownHard();
                        break;
                    },
                    else => @panic("undexpected negamax error"),
                }
            };

            // ASPIRATION window: we fell out, so reset values and try again
            // at the same depth (the continue expression re-increments it).
            if (scr <= alpha or scr >= beta) {
                alpha = ALPHA;
                beta = BETA;
                current_depth -= 1;
                continue;
            }

            alpha = scr - 50;
            beta = scr + 50;

            try self.print_info(scr, current_depth);
        }

        try self.printBestMove();
    }

    /// Play the root PV move on `self.state`, record it in the history and
    /// print the UCI `bestmove` line (lower-cased promotion tag appended).
    fn printBestMove(self: *@This()) !void {
        const best = pv_table[0][0];

        // Store our move in move history
        _ = self.state.makeMove(best, .all);
        try self.state.history.put(self.state.hash, {});

        var line = std.ArrayList(u8).init(self.state.allocator);
        defer line.deinit();
        const w = line.writer();

        try w.print("bestmove {s}{s}", .{ @tagName(best.source), @tagName(best.target) });
        if (best.prom != .none) {
            for (@tagName(best.prom)) |c| try w.writeByte(std.ascii.toLower(c));
        }
        try w.writeByte('\n');

        try std.io.getStdOut().writeAll(line.items);
    }

    /// True once the node budget or time budget is exhausted, or `stop` is set.
    inline fn stopSearch(self: *@This()) bool {
        return nodes >= self.node_limit or
            (self.timer.read() / std.time.ns_per_ms) > self.movetime or
            @atomicLoad(bool, &self.stop, std.builtin.AtomicOrder.Unordered);
    }

    /// Convert a mate score read from the TT (stored relative to its node)
    /// back to a score relative to the root, using the current ply.
    inline fn adjustTTgetScore(self: *const @This(), s: isize) isize {
        const p = @as(isize, @intCast(self.ply));
        if (s < -eval.CHECKMATE_SCORE) return s + p;
        if (s > eval.CHECKMATE_SCORE) return s - p;
        return s;
    }

    /// Convert a root-relative mate score to a node-relative one before
    /// storing it in the TT. This is the inverse of `adjustTTgetScore`.
    inline fn adjustTTsetScore(self: *const @This(), s: isize) isize {
        const p = @as(isize, @intCast(self.ply));
        if (s < -eval.CHECKMATE_SCORE) return s - p;
        if (s > eval.CHECKMATE_SCORE) return s + p;
        return s;
    }

    /// Fifty-move rule, or the current position already occurs in the history.
    fn isDraw(self: @This()) bool {
        return self.state.fifty >= 100 or self.state.history.contains(self.state.hash);
    }

    /// Undo one `ply += 1; history.put; makeMove` step.
    inline fn tearDown(self: *@This(), bck: *GameState) void {
        self.ply -= 1;
        self.state.restore(bck);
        _ = self.state.history.pop();
    }

    /// Unwind after an aborted search (`error.TimeIsUp`). The error propagates
    /// past every `tearDown`, so the board, ply and history must all be reset
    /// to the root here.
    fn tearDownHard(self: *@This()) void {
        self.ply = 0;
        self.state.restore(&start_state);
        self.state.history.shrinkRetainingCapacity(start_history);
    }

    /// Fail-hard negamax alpha-beta search with PVS, LMR, null-move pruning,
    /// check extensions and a transposition table.
    ///
    /// Returns a score from the side to move's point of view, clamped to
    /// `[alpha_orig, beta]`. Returns `error.TimeIsUp` when a budget runs out;
    /// the board is then NOT unwound (the caller uses `tearDownHard`).
    fn negamax(self: *@This(), alpha_orig: isize, beta: isize, depth_orig: usize) !isize {
        // Start with an empty PV at this ply so that early returns (TT cutoffs,
        // leaf evaluations) never leave a stale line for the parent to copy.
        // `self.ply < MAX_PLY` holds here because of the ply guard below.
        pv_length[self.ply] = self.ply;

        var best_move: BitMove = @as(BitMove, @bitCast(@as(BitMoveType, 0)));
        const pv_node = (beta - alpha_orig > 1);

        // TT probe: skipped at the root and in PV nodes.
        if (self.ply != 0 and !pv_node) {
            if (self.tt.get(self.state.hash)) |entry| {
                if (entry.depth >= depth_orig) {
                    // Compare the ply-adjusted score against the bounds; the
                    // bounds themselves are returned unchanged (fail-hard).
                    const tt_score = self.adjustTTgetScore(entry.score);
                    switch (entry.flag) {
                        .empty => @panic("negamax: found .empty entry in tt"),
                        .exact => return tt_score,
                        // Stored when no move raised alpha: an upper bound.
                        .lower => {
                            if (tt_score <= alpha_orig) return alpha_orig;
                        },
                        // Stored on a beta cutoff: a lower bound.
                        .upper => {
                            if (tt_score >= beta) return beta;
                        },
                    }
                }
                best_move = entry.best;
            }
        }

        if (nodes & CHECK_STOP == 0 and self.stopSearch()) return error.TimeIsUp;

        if (depth_orig == 0) return try self.quiescenceSearch(alpha_orig, beta);

        // Keep one slot of headroom: the move loop below reads and writes
        // `pv_length`/`pv_table` at index `self.ply + 1`.
        if (self.ply >= MAX_PLY - 1) {
            return EVAL_FUNCTION(self.state);
        }

        var alpha = alpha_orig;
        var depth = depth_orig;

        // increase search depth if king exposed to check
        const in_check = self.state.inCheck();
        if (in_check) depth += 1;

        if (self.ply > 0 and self.isDraw()) {
            return 0;
        }

        nodes += 1;
        var legal_moves: usize = 0;

        // NULL move pruning
        {
            if (depth >= 3 and !in_check and self.ply != 0) {
                // backup board state
                var bck: GameState = undefined;
                self.state.backup(&bck);

                self.ply += 1;
                try self.state.history.put(self.state.hash, {});

                if (self.state.enpassant != null) {
                    self.state.hash ^= zobrist.enpassant_hashes[@intFromEnum(self.state.enpassant.?)];
                }

                // prepare GameState: pass the move to the opponent
                self.state.side = self.state.side.enemy();
                self.state.enpassant = null;

                self.state.hash ^= zobrist.side_hash;

                const s = -try self.negamax(-beta, -beta + 1, depth - 1 - NULL_REDUCTION);

                self.tearDown(&bck);

                if (s >= beta) {
                    return beta;
                }
            }
        }

        var ml = try MoveList.init(0);
        try self.state.generateMoves(&ml);
        score.moves(self, &ml, best_move);

        // Honor following the PV line
        if (pv_follow) {
            pv_follow = false;
            for (ml.slice()) |*mp| {
                if (@as(BitMoveType, @bitCast(pv_table[0][self.ply])) == @as(BitMoveType, @bitCast(mp.move))) {
                    pv_follow = true;
                    mp.score = score.PV_SCORE;
                    break;
                }
            }
        }

        std.mem.sort(MovePrio, ml.slice(), {}, comptime MovePrio.moreThan);

        var scr: isize = undefined;
        var tt_flag: Transposition.Flag = .lower;

        var moves_searched: usize = 0;
        for (ml.slice()) |mp| {
            var bck: GameState = undefined;
            self.state.backup(&bck);

            self.ply += 1;
            try self.state.history.put(self.state.hash, {});

            if (!self.state.makeMove(mp.move, .all)) {
                self.tearDown(&bck);
                continue;
            }

            legal_moves += 1;

            // Normal negamax without PVS or LMR
            if (moves_searched == 0) {
                scr = -try self.negamax(-beta, -alpha, depth - 1);
            } else {
                // Conditions to consider LMR (quiet, late, non-checking moves)
                if (moves_searched >= FULL_DEPTH_MOVES and
                    depth >= REDUCTION_LIMIT and
                    !mp.move.capture and mp.move.prom == .none and
                    !self.state.inCheck())
                {
                    // reduced-depth, null-window search for late quiet moves
                    scr = -try self.negamax(-alpha - 1, -alpha, depth -| 3);
                } else {
                    // force the PVS search below
                    scr = alpha + 1;
                }

                // PVS: try to prove with a null window that the move is not
                // better than the current best; re-search with the full
                // window only when that proof fails.
                if (scr > alpha) {
                    scr = -try self.negamax(-alpha - 1, -alpha, depth - 1);

                    if (scr > alpha and scr < beta) {
                        scr = -try self.negamax(-beta, -alpha, depth - 1);
                    }
                }
            }

            self.tearDown(&bck);
            moves_searched += 1;

            // found a better move
            if (scr > alpha) {
                tt_flag = .exact;

                best_move = mp.move;

                // store history moves
                if (!mp.move.capture) {
                    score.history_moves[@intFromEnum(mp.move.piece)][@intFromEnum(mp.move.target)] += depth;
                }

                // PV (principal variation) node
                alpha = scr;

                {
                    // write PV move
                    pv_table[self.ply][self.ply] = mp.move;

                    // copy the child's line into this ply's line
                    var next: usize = self.ply + 1;
                    while (next < pv_length[self.ply + 1]) : (next += 1) {
                        pv_table[self.ply][next] = pv_table[self.ply + 1][next];
                    }

                    // adjust pv_length
                    pv_length[self.ply] = pv_length[self.ply + 1];
                }

                // fail-hard beta cutoff
                if (scr >= beta) {
                    self.tt.set(.{
                        .key = self.state.hash,
                        .score = self.adjustTTsetScore(beta),
                        .best = best_move,
                        .depth = @as(u8, @intCast(depth)),
                        .flag = .upper,
                    });

                    // store killer moves
                    if (!mp.move.capture) {
                        score.killer_moves[1][self.ply] = score.killer_moves[0][self.ply];
                        score.killer_moves[0][self.ply] = mp.move;
                    }

                    // node (move) fails high
                    return beta;
                }
            }
        }

        // Detect check- and stalemate
        if (legal_moves == 0) {
            if (in_check) {
                // Adding self.ply makes shorter mates score higher.
                return -eval.MATE_VALUE + @as(isize, @intCast(self.ply));
            } else {
                return eval.STALEMATE_SCORE;
            }
        }

        self.tt.set(.{
            .key = self.state.hash,
            .score = self.adjustTTsetScore(alpha),
            .best = best_move,
            .depth = @as(u8, @intCast(depth)),
            .flag = tt_flag,
        });

        // node (move) fails low
        return alpha;
    }

    /// Quiescence search: extend leaf nodes with capture moves only
    /// (`makeMove(.., .captures)`) so the static evaluation is not taken in
    /// the middle of an exchange (horizon effect). Fail-hard, like `negamax`.
    fn quiescenceSearch(self: *@This(), alpha_orig: isize, beta: isize) !isize {
        if (nodes & CHECK_STOP == 0 and self.stopSearch()) return error.TimeIsUp;

        nodes += 1;

        var alpha = alpha_orig;

        // avoid overflow of self.ply arrays
        if (self.ply > MAX_PLY - 1) {
            return EVAL_FUNCTION(self.state);
        }

        const evaluation = EVAL_FUNCTION(self.state);

        // stand-pat: fail-hard beta cutoff
        if (evaluation >= beta) {
            return beta;
        }

        // stand-pat raises alpha
        if (evaluation > alpha) {
            alpha = evaluation;
        }

        if (self.tt.get(self.state.hash)) |entry| {
            // Compare the ply-adjusted score; return bounds unchanged.
            const tt_score = self.adjustTTgetScore(entry.score);
            switch (entry.flag) {
                .empty => @panic("quiescenceSearch: found .empty entry in tt"),
                .exact => return tt_score,
                .lower => {
                    if (tt_score <= alpha_orig) return alpha_orig;
                },
                .upper => {
                    if (tt_score >= beta) return beta;
                },
            }
        }

        var ml = try MoveList.init(0);
        try self.state.generateMoves(&ml);
        score.moves(self, &ml, @as(BitMove, @bitCast(@as(BitMoveType, 0))));
        std.mem.sort(MovePrio, ml.slice(), {}, comptime MovePrio.moreThan);

        for (ml.slice()) |mp| {
            var bck: GameState = undefined;
            self.state.backup(&bck);

            self.ply += 1;
            if (!self.state.makeMove(mp.move, .captures)) {
                self.ply -= 1;
                self.state.restore(&bck);
                continue;
            }

            const scr = -try self.quiescenceSearch(-beta, -alpha);

            self.ply -= 1;
            self.state.restore(&bck);

            // found a better move
            if (scr > alpha) {
                alpha = scr;

                // fail-hard beta cutoff
                if (scr >= beta) {
                    return beta;
                }
            }
        }

        // node (move) fails low
        return alpha;
    }

    /// Print one UCI `info` line for a finished iteration: score (`cp`, or
    /// `mate` in moves, negative when being mated), depth, nodes, time in ms,
    /// nodes per second and the principal variation.
    fn print_info(self: *@This(), scr: isize, depth: usize) !void {
        var line = std.ArrayList(u8).init(self.state.allocator);
        defer line.deinit();
        const w = line.writer();

        const elapsed_ns = self.timer.read();
        const time_ms = elapsed_ns / std.time.ns_per_ms;
        // Avoid dividing by zero on very fast, shallow iterations.
        const nps: u64 = if (elapsed_ns == 0)
            0
        else
            @as(u64, @intCast(@as(u128, nodes) * std.time.ns_per_s / elapsed_ns));

        const abs_scr = try std.math.absInt(scr);
        if (abs_scr > eval.CHECKMATE_SCORE) {
            // Mate scores are MATE_VALUE minus the mating ply (see negamax),
            // and UCI wants full moves: +N when we mate, -N when we are mated.
            const plies = eval.MATE_VALUE - abs_scr;
            const mate_moves: isize = if (scr > 0) @divTrunc(plies + 1, 2) else -@divTrunc(plies, 2);
            try w.print(
                "info score mate {d} depth {d} nodes {d} time {d} nps {d} pv ",
                .{ mate_moves, depth, nodes, time_ms, nps },
            );
        } else {
            try w.print(
                "info score cp {d} depth {d} nodes {d} time {d} nps {d} pv ",
                .{ scr, depth, nodes, time_ms, nps },
            );
        }

        // print the PV line
        for (pv_table[0][0..pv_length[0]]) |m| {
            try w.print("{s}{s}", .{ @tagName(m.source), @tagName(m.target) });
            if (m.prom != .none) {
                for (@tagName(m.prom)) |c| try w.writeByte(std.ascii.toLower(c));
            }
            try w.writeByte(' ');
        }
        try w.writeByte('\n');

        try std.io.getStdOut().writeAll(line.items);
    }
};