const std = @import("std");
const build_options = @import("build_options");
// Static evaluation used by the search: `nnue.evaluate` when the build enables the
// `nnue` option, otherwise the classical `eval.evaluate`.
const EVAL_FUNCTION = if (build_options.nnue) nnue.evaluate else eval.evaluate;
const Str = []const u8;
const Chess = @import("Chess.zig");
const BoardType = @import("Board.zig").BoardType;
const BitMove = @import("Board.zig").BitMove;
const MovePrio = @import("Board.zig").MovePrio;
const BitMoveType = @import("Board.zig").BitMoveType;
const GameState = @import("Board.zig").GameState;
const MoveList = @import("Board.zig").MoveList;
const SquareType = @import("Board.zig").SquareType;
const Transposition = @import("Transposition.zig");
const eval = @import("eval.zig");
const nnue = @import("nnue.zig");
const score = @import("score.zig");
const uci = @import("uci.zig");
const zobrist = @import("zobrist.zig");
// Default iterative-deepening depth limit; also used as MAX_PLY, the size of the
// per-ply search tables.
const DEFAULT_DEPTH = 64;
// Safety margin, in milliseconds, subtracted from the fixed move time and from the
// increment when computing the search time budget.
const TIME_SAFETY = 50;
/// Limits for a single search request (the UCI `go` parameters).
/// Fields left at `unlimited` are treated as "not given".
pub const SearchOptions = struct {
    /// Maximum iterative-deepening depth.
    depth: usize = DEFAULT_DEPTH,
    /// Fixed time for this move in milliseconds; `unlimited` means derive it from the clock.
    movetime: usize = unlimited,
    /// Maximum number of nodes to search.
    nodes: usize = unlimited,
    /// UCI `movestogo`: moves expected until the next time control.
    movestogo: usize = 20,
    /// White's remaining clock time.
    wtime: usize = unlimited,
    /// White's increment per move.
    winc: usize = 0,
    /// Black's remaining clock time.
    btime: usize = unlimited,
    /// Black's increment per move.
    binc: usize = 0,

    /// Sentinel default meaning "no limit / not specified".
    const unlimited = std.math.maxInt(usize);
};
/// Iterative-deepening negamax search with alpha-beta pruning.
///
/// Uses a transposition table, principal-variation (PV) tracking and following,
/// null-move pruning, principal-variation search with late-move reductions,
/// aspiration windows, check extensions and a quiescence search.
///
/// Search scratch data (`nodes`, the PV tables, the root backup) lives in
/// container-level variables, so it is shared by every `Search` value.
pub const Search = struct {
    /// Position being searched; moves are made and unmade on it in place.
    state: *GameState,
    /// Abort flag, read atomically in `stopSearch` (presumably set by another
    /// thread on UCI `stop` — confirm with the caller).
    stop: bool = false,
    /// Half-moves from the root of the current search.
    ply: usize = 0,

    // UCI time control
    /// Clock read by `stopSearch` and `printInfo`. It is never started in this
    /// struct, so the caller is assumed to start it before `bestMove` — TODO confirm.
    timer: std.time.Timer = undefined,
    /// Time budget for the current search in milliseconds, set by `bestMove`.
    movetime: usize = undefined,
    /// Node budget for the current search, set by `bestMove`.
    node_limit: usize = 0,

    // Transposition Table
    tt: Transposition.Table = undefined,

    /// Nodes visited in the current search.
    var nodes: usize = undefined;

    // Time control
    /// Mask applied to `nodes`: the limits are only polled when the low 11 bits are zero.
    const CHECK_STOP = 0x7ff; // 2047
    /// Root position, saved by `bestMove` so an aborted search can restore it.
    var start_state: GameState = undefined;
    /// Game-history length at the root, restored after an aborted search.
    var start_history: usize = undefined;

    // Maximum ply we can reach while searching
    const MAX_PLY = DEFAULT_DEPTH;

    // LMR
    /// Moves searched at full depth before late-move reductions may apply.
    const FULL_DEPTH_MOVES = 4;
    /// Minimum remaining depth for late-move reductions.
    const REDUCTION_LIMIT = 3;

    // NULL move pruning
    /// Extra depth reduction applied to the null-move search.
    const NULL_REDUCTION = 2;

    // Principal Variation
    /// `pv_length[p]`: exclusive end index of the line stored in `pv_table[p]`.
    var pv_length: [MAX_PLY]usize = undefined;
    /// `pv_table[p][p..pv_length[p]]`: best line found from ply `p`.
    var pv_table: [MAX_PLY][MAX_PLY]BitMove = undefined;
    /// True while the current path still matches the previous iteration's PV.
    var pv_follow: bool = false;

    /// Creates a search over `gs` with a transposition table of
    /// `options.hash_size` MiB; also initializes NNUE when it is enabled.
    pub fn init(gs: *GameState, options: uci.EngineOptions) !@This() {
        if (build_options.nnue) {
            try nnue.init();
        }
        std.log.debug("Initializin Search, options: {any}", .{options});
        return Search{
            .state = gs,
            .tt = try Transposition.Table.init(gs.allocator, options.hash_size * (1 << 20)),
        };
    }

    /// Releases the transposition table.
    pub fn deinit(self: @This()) void {
        self.tt.deinit();
    }

    /// Searches `self.state` by iterative deepening up to `options.depth`,
    /// writing one UCI `info` line per completed depth and a final `bestmove`
    /// line to stdout. The best move is also played on `self.state` and its
    /// position recorded in the game history.
    ///
    /// Time budget: `options.movetime - TIME_SAFETY` when a move time is given;
    /// otherwise the larger of the side's clock divided by the moves to go
    /// (at least 5) and its increment minus `TIME_SAFETY`. When the time or node
    /// limit (or `stop`) fires, the root position is restored and the best root
    /// move found so far is played.
    pub fn bestMove(self: *@This(), options: SearchOptions) !void {
        if (options.movetime != std.math.maxInt(usize)) {
            self.movetime = options.movetime -| TIME_SAFETY;
        } else {
            const moves_to_go = @max(options.movestogo -| self.state.history.count(), 5);
            self.movetime = switch (self.state.side) {
                .white => @max(options.wtime / moves_to_go, options.winc -| TIME_SAFETY),
                .black => @max(options.btime / moves_to_go, options.binc -| TIME_SAFETY),
            };
        }
        self.node_limit = options.nodes;
        std.log.debug(
            "bestMove depth: {d} node_limit: {d} movetime {d}",
            .{ options.depth, self.node_limit, self.movetime },
        );
        std.log.debug("Start searching: ply: {d} fifty: {d} history_length: {d}", .{
            self.ply,
            self.state.fifty,
            self.state.history.count(),
        });

        const ALPHA = -eval.INFINITY;
        const BETA = eval.INFINITY;
        var alpha: isize = ALPHA;
        var beta: isize = BETA;

        // clear helper data for search
        {
            score.killer_moves = [_][MAX_PLY]BitMove{
                [_]BitMove{@as(BitMove, @bitCast(@as(BitMoveType, 0)))} ** MAX_PLY,
            } ** 2;
            score.history_moves = [_][64]usize{[_]usize{0} ** 64} ** 12;
            pv_length = [_]usize{0} ** MAX_PLY;
            pv_table = [_][MAX_PLY]BitMove{
                [_]BitMove{@as(BitMove, @bitCast(@as(BitMoveType, 0)))} ** MAX_PLY,
            } ** MAX_PLY;
            nodes = 0;
            self.state.backup(&start_state);
            start_history = self.state.history.count();
        }

        // iterative deepening
        var current_depth: usize = 1;
        while (current_depth <= options.depth) : (current_depth += 1) {
            pv_follow = true;
            const scr = self.negamax(alpha, beta, current_depth) catch |err| switch (err) {
                error.TimeIsUp => {
                    self.tearDownHard();
                    break;
                },
                else => @panic("undexpected negamax error"),
            };
            // Aspiration window: if a narrowed window was missed, re-search this
            // depth with the full window. A full-window result is always accepted,
            // so a score at +/-INFINITY cannot cause an endless retry.
            const narrowed = alpha != ALPHA or beta != BETA;
            if (narrowed and (scr <= alpha or scr >= beta)) {
                alpha = ALPHA;
                beta = BETA;
                current_depth -= 1; // undone by the loop's continue expression
                continue;
            }
            alpha = scr - 50;
            beta = scr + 50;
            try self.printInfo(scr, current_depth);
        }
        try self.printBestMove();
    }

    /// Writes `m` in UCI move notation: source and target square names followed
    /// by the lowercased promotion tag when the move promotes.
    fn writeUciMove(out: anytype, m: BitMove) !void {
        try out.print("{s}{s}", .{ @tagName(m.source), @tagName(m.target) });
        if (m.prom != .none) {
            for (@tagName(m.prom)) |c| try out.writeByte(std.ascii.toLower(c));
        }
    }

    /// Plays the best root move on `self.state`, records the resulting position
    /// in the game history and prints the UCI `bestmove` line to stdout.
    fn printBestMove(self: *@This()) !void {
        const best = pv_table[0][0];

        // Store our move in move history
        _ = self.state.makeMove(best, .all);
        try self.state.history.put(self.state.hash, {});

        // Build the whole line first so it reaches stdout in a single write.
        var line = std.ArrayList(u8).init(self.state.allocator);
        defer line.deinit();
        const out = line.writer();
        try out.writeAll("bestmove ");
        try writeUciMove(out, best);
        try out.writeByte('\n');
        try std.io.getStdOut().writeAll(line.items);
    }

    /// True when the node limit is reached, the time budget (ms) is exceeded,
    /// or the `stop` flag is set.
    inline fn stopSearch(self: *@This()) bool {
        return nodes >= self.node_limit or
            (self.timer.read() / std.time.ns_per_ms) > self.movetime or
            @atomicLoad(bool, &self.stop, std.builtin.AtomicOrder.Unordered);
    }

    /// Inverse of `adjustTTsetScore`: mate scores (beyond +/-CHECKMATE_SCORE)
    /// read from the TT are moved toward zero by `self.ply`; other scores are
    /// returned unchanged.
    inline fn adjustTTgetScore(self: *const @This(), s: isize) isize {
        if (s < -eval.CHECKMATE_SCORE) {
            return s + @as(isize, @intCast(self.ply));
        } else if (s > eval.CHECKMATE_SCORE) {
            return s - @as(isize, @intCast(self.ply));
        } else {
            return s;
        }
    }

    /// Prepares a score for TT storage: mate scores (beyond +/-CHECKMATE_SCORE)
    /// are moved away from zero by `self.ply`; other scores are unchanged.
    inline fn adjustTTsetScore(self: *const @This(), s: isize) isize {
        if (s < -eval.CHECKMATE_SCORE) {
            return s - @as(isize, @intCast(self.ply));
        } else if (s > eval.CHECKMATE_SCORE) {
            return s + @as(isize, @intCast(self.ply));
        } else {
            return s;
        }
    }

    /// True when the fifty-move counter reached 100 half-moves or the current
    /// position hash is already in the game/search history.
    fn isDraw(self: @This()) bool {
        return self.state.fifty >= 100 or self.state.history.contains(self.state.hash);
    }

    /// Undoes one make: decrements `ply`, restores the board from `bck` and pops
    /// the history entry pushed before the make.
    inline fn tearDown(self: *@This(), bck: *GameState) void {
        self.ply -= 1;
        self.state.restore(bck);
        _ = self.state.history.pop();
    }

    /// Cleanup after `error.TimeIsUp`. The error unwinds through every ply
    /// without calling `tearDown`, so the board is still mid-search: reset
    /// `ply`, restore the root position saved by `bestMove` and truncate the
    /// history back to its root length.
    fn tearDownHard(self: *@This()) void {
        self.ply = 0;
        self.state.restore(&start_state);
        self.state.history.shrinkRetainingCapacity(start_history);
    }

    /// Fail-hard alpha-beta search of the current position to `depth_orig`
    /// plies from the side to move's point of view. Beta cutoffs return exactly
    /// `beta`; fail-lows return `alpha_orig`.
    ///
    /// Returns `error.TimeIsUp` when `stopSearch` fires; the board is then left
    /// mid-search and `tearDownHard` must be called.
    fn negamax(self: *@This(), alpha_orig: isize, beta: isize, depth_orig: usize) !isize {
        // Bail out before touching the per-ply tables: the PV update below reads
        // row `self.ply + 1`, so only plies up to MAX_PLY - 2 may expand moves.
        if (self.ply >= MAX_PLY - 1) {
            if (self.ply < MAX_PLY) pv_length[self.ply] = self.ply;
            return EVAL_FUNCTION(self.state);
        }
        // Start with an empty PV line at this ply. It is set before any early
        // return so the parent never copies a stale line from another branch.
        pv_length[self.ply] = self.ply;

        // TT probe (not at the root, and not in PV nodes, so the PV stays intact).
        var best_move: BitMove = @as(BitMove, @bitCast(@as(BitMoveType, 0)));
        const pv_node = (beta - alpha_orig > 1);
        if (self.ply != 0 and !pv_node) {
            if (self.tt.get(self.state.hash)) |entry| {
                if (entry.depth >= depth_orig) {
                    // Undo the ply shift applied on store before comparing bounds.
                    const tt_score = self.adjustTTgetScore(entry.score);
                    switch (entry.flag) {
                        .empty => @panic("negamax: found .empty entry in tt"),
                        .exact => return tt_score,
                        // Stored on fail-low: the score is an upper bound.
                        .lower => {
                            if (tt_score <= alpha_orig) return alpha_orig;
                        },
                        // Stored on fail-high: the score is a lower bound.
                        .upper => {
                            if (tt_score >= beta) return beta;
                        },
                    }
                }
                // Even an entry that is too shallow to cut still orders its move first.
                best_move = entry.best;
            }
        }

        if (nodes & CHECK_STOP == 0 and self.stopSearch()) return error.TimeIsUp;
        if (depth_orig == 0) return try self.quiescenceSearch(alpha_orig, beta);

        var alpha = alpha_orig;
        // Check extension: search one ply deeper when the side to move is in check.
        const in_check = self.state.inCheck();
        const depth = if (in_check) depth_orig + 1 else depth_orig;

        if (self.ply > 0 and self.isDraw()) return 0;

        nodes += 1;

        // Null-move pruning: hand the opponent a free move. If a reduced search
        // still fails high, assume this node fails high too.
        if (depth >= 3 and !in_check and self.ply != 0) {
            var bck: GameState = undefined;
            self.state.backup(&bck);
            self.ply += 1;
            try self.state.history.put(self.state.hash, {});
            if (self.state.enpassant) |ep| {
                self.state.hash ^= zobrist.enpassant_hashes[@intFromEnum(ep)];
            }
            self.state.side = self.state.side.enemy();
            self.state.enpassant = null;
            self.state.hash ^= zobrist.side_hash;
            const s = -try self.negamax(-beta, -beta + 1, depth - 1 - NULL_REDUCTION);
            self.tearDown(&bck);
            if (s >= beta) return beta;
        }

        var ml = try MoveList.init(0);
        try self.state.generateMoves(&ml);
        score.moves(self, &ml, best_move);

        // While still on the previous iteration's PV, give its move at this ply
        // the highest ordering score.
        if (pv_follow) {
            pv_follow = false;
            const pv_move = @as(BitMoveType, @bitCast(pv_table[0][self.ply]));
            for (ml.slice()) |*mp| {
                if (@as(BitMoveType, @bitCast(mp.move)) == pv_move) {
                    pv_follow = true;
                    mp.score = score.PV_SCORE;
                    break;
                }
            }
        }
        std.mem.sort(MovePrio, ml.slice(), {}, comptime MovePrio.moreThan);

        var tt_flag: Transposition.Flag = .lower;
        // Number of legal moves searched so far.
        var moves_searched: usize = 0;
        for (ml.slice()) |mp| {
            var bck: GameState = undefined;
            self.state.backup(&bck);
            self.ply += 1;
            try self.state.history.put(self.state.hash, {});
            if (!self.state.makeMove(mp.move, .all)) {
                // makeMove rejected the move: undo and skip it.
                self.tearDown(&bck);
                continue;
            }

            var scr: isize = undefined;
            if (moves_searched == 0) {
                // First move: plain full-window search.
                scr = -try self.negamax(-beta, -alpha, depth - 1);
            } else {
                // LMR: quiet, non-checking late moves get a reduced null-window
                // search first; every other move goes straight to PVS below.
                const reduce = moves_searched >= FULL_DEPTH_MOVES and
                    depth >= REDUCTION_LIMIT and
                    !mp.move.capture and mp.move.prom == .none and
                    !self.state.inCheck();
                scr = if (reduce) -try self.negamax(-alpha - 1, -alpha, depth -| 3) else alpha + 1;

                // PVS: try to prove the move is no better than alpha with a cheap
                // null-window search; only on failure, re-search with the full window.
                if (scr > alpha) {
                    scr = -try self.negamax(-alpha - 1, -alpha, depth - 1);
                    if (scr > alpha and scr < beta) {
                        scr = -try self.negamax(-beta, -alpha, depth - 1);
                    }
                }
            }
            self.tearDown(&bck);
            moves_searched += 1;

            if (scr <= alpha) continue;

            // Found a better move.
            tt_flag = .exact;
            best_move = mp.move;
            if (!mp.move.capture) {
                // History heuristic: reward quiet moves that raise alpha.
                score.history_moves[@intFromEnum(mp.move.piece)][@intFromEnum(mp.move.target)] += depth;
            }
            alpha = scr;

            // The PV at this ply is this move followed by the child's line.
            pv_table[self.ply][self.ply] = mp.move;
            var next: usize = self.ply + 1;
            while (next < pv_length[self.ply + 1]) : (next += 1) {
                pv_table[self.ply][next] = pv_table[self.ply + 1][next];
            }
            pv_length[self.ply] = pv_length[self.ply + 1];

            // Fail-hard beta cutoff.
            if (scr >= beta) {
                self.tt.set(.{
                    .key = self.state.hash,
                    .score = self.adjustTTsetScore(beta),
                    .best = best_move,
                    .depth = @as(u8, @intCast(depth)),
                    .flag = .upper,
                });
                if (!mp.move.capture) {
                    // Killer heuristic: keep the last two quiet cutoff moves at this ply.
                    score.killer_moves[1][self.ply] = score.killer_moves[0][self.ply];
                    score.killer_moves[0][self.ply] = mp.move;
                }
                return beta;
            }
        }

        // No legal move: checkmate or stalemate.
        if (moves_searched == 0) {
            // Adding self.ply makes faster mates score higher.
            if (in_check) return -eval.MATE_VALUE + @as(isize, @intCast(self.ply));
            return eval.STALEMATE_SCORE;
        }

        self.tt.set(.{
            .key = self.state.hash,
            .score = self.adjustTTsetScore(alpha),
            .best = best_move,
            .depth = @as(u8, @intCast(depth)),
            .flag = tt_flag,
        });
        // node (move) fails low
        return alpha;
    }

    /// Quiescence search: keeps searching moves accepted by
    /// `makeMove(.., .captures)` until the position is quiet, so the static
    /// evaluation is not taken in the middle of an exchange (horizon effect).
    /// Fail-hard, with the static evaluation as the stand-pat score.
    fn quiescenceSearch(self: *@This(), alpha_orig: isize, beta: isize) !isize {
        if (nodes & CHECK_STOP == 0 and self.stopSearch()) return error.TimeIsUp;
        nodes += 1;
        // avoid overflow of self.ply arrays
        if (self.ply > MAX_PLY - 1) return EVAL_FUNCTION(self.state);

        var alpha = alpha_orig;
        // Stand pat: the side to move may decline every capture.
        const evaluation = EVAL_FUNCTION(self.state);
        if (evaluation >= beta) return beta;
        if (evaluation > alpha) alpha = evaluation;

        if (self.tt.get(self.state.hash)) |entry| {
            // Undo the ply shift applied on store before comparing bounds.
            const tt_score = self.adjustTTgetScore(entry.score);
            switch (entry.flag) {
                .empty => @panic("quiescenceSearch: found .empty entry in tt"),
                .exact => return tt_score,
                .lower => {
                    if (tt_score <= alpha_orig) return alpha_orig;
                },
                .upper => {
                    if (tt_score >= beta) return beta;
                },
            }
        }

        var ml = try MoveList.init(0);
        try self.state.generateMoves(&ml);
        score.moves(self, &ml, @as(BitMove, @bitCast(@as(BitMoveType, 0))));
        std.mem.sort(MovePrio, ml.slice(), {}, comptime MovePrio.moreThan);

        for (ml.slice()) |mp| {
            var bck: GameState = undefined;
            self.state.backup(&bck);
            self.ply += 1;
            // Presumably `.captures` makes makeMove reject quiet moves as well as illegal ones.
            if (!self.state.makeMove(mp.move, .captures)) {
                self.ply -= 1;
                self.state.restore(&bck);
                continue;
            }
            const scr = -try self.quiescenceSearch(-beta, -alpha);
            self.ply -= 1;
            self.state.restore(&bck);

            if (scr > alpha) {
                alpha = scr;
                // fail-hard beta cutoff
                if (scr >= beta) return beta;
            }
        }
        // node (move) fails low
        return alpha;
    }

    /// Prints one UCI `info` line for a completed iteration: score (centipawns,
    /// or mate distance in moves), depth, nodes, elapsed time in ms, nodes per
    /// second and the principal variation.
    fn printInfo(self: *@This(), scr: isize, depth: usize) !void {
        var line = std.ArrayList(u8).init(self.state.allocator);
        defer line.deinit();
        const out = line.writer();

        const time_ms: u64 = self.timer.read() / std.time.ns_per_ms;
        // Clamp to 1 ms so a very fast iteration does not divide by zero.
        const nps: u64 = @as(u64, nodes) * std.time.ms_per_s / @max(time_ms, 1);

        if (scr > eval.CHECKMATE_SCORE or scr < -eval.CHECKMATE_SCORE) {
            // A mate found at ply p scores MATE_VALUE - p (or -MATE_VALUE + p when
            // we are mated); UCI wants full moves, negative when being mated.
            const mate_moves: isize = if (scr > 0)
                @divTrunc(eval.MATE_VALUE - scr + 1, 2)
            else
                @divTrunc(-eval.MATE_VALUE - scr, 2);
            try out.print(
                "info score mate {d} depth {d} nodes {d} time {d} nps {d} pv ",
                .{ mate_moves, depth, nodes, time_ms, nps },
            );
        } else {
            try out.print(
                "info score cp {d} depth {d} nodes {d} time {d} nps {d} pv ",
                .{ scr, depth, nodes, time_ms, nps },
            );
        }

        for (pv_table[0][0..pv_length[0]]) |m| {
            try writeUciMove(out, m);
            try out.writeByte(' ');
        }
        try out.writeByte('\n');
        try std.io.getStdOut().writeAll(line.items);
    }
};