const std = @import("std");

/// Score threshold: `countWin` treats a player whose score is at or above this as the winner.
const max_score = 21;
/// 1-based starting position used for the `p1_pos` default (stored 0-based in `GameState`).
const p1_start = 7;
/// 1-based starting position used for the `p2_pos` default (stored 0-based in `GameState`).
const p2_start = 9;

/// Runs `second` once, checks the answer against the known result, and
/// prints the answer together with the elapsed time in microseconds.
pub fn main() !void {
    // Everything allocated during the run is released in one go on exit.
    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena_state.deinit();

    var stopwatch = try std.time.Timer.start();
    const answer = try second(arena_state.allocator());
    // `lap` reports nanoseconds; convert to microseconds for display.
    const elapsed_us = stopwatch.lap() / std.time.ns_per_us;

    // Fail loudly if the computed answer ever drifts from the expected one.
    const expected: usize = 433315766324816;
    try std.testing.expectEqual(expected, answer);

    std.debug.print("Day 21b result: {d} time: {d}us\n", .{ answer, elapsed_us });
}

/// Memoization table used by `countWin`: maps a game state to the win
/// counts already computed for it.
const Cache = std.AutoHashMap(GameState, Wins);

/// Computes the win counts for both players from the default starting
/// `GameState` and returns the larger of the two.
///
/// `allocator` backs the memoization cache, which is freed before returning.
/// Returns `error.MissingAllocator` when `allocator` is null, and propagates
/// any error from `countWin` (allocation failure while filling the cache).
pub fn second(allocator: ?std.mem.Allocator) !usize {
    // A null allocator is a caller mistake; report it as an error rather than
    // relying on `.?`, which is undefined behaviour in unchecked builds.
    const backing = allocator orelse return error.MissingAllocator;

    var cache = Cache.init(backing);
    defer cache.deinit();

    const start: GameState = .{};
    const wins = try countWin(start, &cache);

    // Plain comparison keeps this independent of std.math helper churn.
    return if (wins[0] > wins[1]) wins[0] else wins[1];
}

/// Snapshot of a two-player game, used as the memoization key.
///
/// `countWin` always advances the `p1_*` player and then swaps the two
/// players' fields for the recursive call, so `p1_*` describes whoever rolls
/// next. Positions are stored 0-based (0..9); the score gained on landing is
/// position + 1.
const GameState = struct {
    // Player whose turn it is.
    p1_pos: usize = p1_start - 1,
    p1_score: usize = 0,

    // The other player.
    p2_pos: usize = p2_start - 1,
    p2_score: usize = 0,
};

/// Pair of win counts returned by `countWin`: index 0 is for the state's
/// `p1_*` player, index 1 for its `p2_*` player.
const Wins = [2]usize;

/// Counts, starting from state `g`, how many complete roll sequences end in a
/// win for each player.
///
/// `g.p1_*` is the player who rolls next; `g.p2_*` is the other player. A turn
/// is three rolls, each 1..3, and all 27 ordered combinations are counted.
/// The result is `{ wins for g's p1, wins for g's p2 }`.
///
/// Results are memoized in `cache`. The only error is allocation failure
/// while inserting into it.
fn countWin(g: GameState, cache: *Cache) std.mem.Allocator.Error!Wins {
    if (g.p1_score >= max_score) return Wins{ 1, 0 };
    if (g.p2_score >= max_score) return Wins{ 0, 1 };

    if (cache.get(g)) |known| return known;

    // Three rolls of 1..3 give only seven distinct sums. `ways` is the number
    // of the 27 ordered roll combinations that produce `sum`, so each distinct
    // successor state is explored once and weighted instead of being revisited.
    const RollOutcome = struct { sum: usize, ways: usize };
    const roll_outcomes = [_]RollOutcome{
        .{ .sum = 3, .ways = 1 },
        .{ .sum = 4, .ways = 3 },
        .{ .sum = 5, .ways = 6 },
        .{ .sum = 6, .ways = 7 },
        .{ .sum = 7, .ways = 6 },
        .{ .sum = 8, .ways = 3 },
        .{ .sum = 9, .ways = 1 },
    };

    var ret: Wins = .{ 0, 0 };

    for (roll_outcomes) |outcome| {
        // Board positions are 0..9; landing on position p scores p + 1.
        const next_pos = (g.p1_pos + outcome.sum) % 10;
        const next_score = g.p1_score + next_pos + 1;

        // Swap the players so the one who has not just moved rolls next.
        const other_state = GameState{
            .p1_pos = g.p2_pos,
            .p2_pos = next_pos,
            .p1_score = g.p2_score,
            .p2_score = next_score,
        };
        const swapped = try countWin(other_state, cache);

        // The recursive result is from the swapped point of view, so swap back.
        ret[0] += outcome.ways * swapped[1];
        ret[1] += outcome.ways * swapped[0];
    }

    try cache.put(g, ret);

    return ret;
}

// Regression check: `second` must keep producing the known answer, and the
// testing allocator flags any memory the cache fails to release.
test "day21b" {
    const expected: usize = 433315766324816;
    const actual = try second(std.testing.allocator);
    try std.testing.expectEqual(expected, actual);
}