import 'package:test/test.dart';

import 'package:dartmcts/dartmcts.dart';
import 'package:dartmcts/tictactoe.dart';

/// The two sides that alternate turns in the test games in this file.
enum Player {
  FIRST,
  SECOND,
}

/// The sole move available in [GameWithOneMove].
enum Move {
  WIN,
}

/// A trivial game whose only move, [Move.WIN], immediately ends it with the
/// player who made that move as the winner.
///
/// Used to check that MCTS picks the only available move and that the
/// resulting child node reports the expected winner.
class GameWithOneMove implements GameState<Move, Player> {
  Player? currentPlayer;
  Player? winner;
  Map<Player, int> scores;

  GameWithOneMove(
      {this.winner, required this.scores, this.currentPlayer = Player.FIRST});

  /// Returns the finished game after [move].
  ///
  /// The player who moved becomes [winner] and gets one point. The turn
  /// passes to the other player. [root] is unused.
  @override
  GameWithOneMove cloneAndApplyMove(Move move, Node<Move, Player>? root) {
    // Credit whichever player actually moved, so `scores` always agrees with
    // `winner` (a fixed FIRST: 1 would be wrong whenever SECOND moves).
    final newScores = {
      for (final player in Player.values)
        player: player == currentPlayer ? 1 : 0,
    };
    // Alternate turns; a null/FIRST mover hands the turn to SECOND.
    final nextPlayer =
        currentPlayer == Player.SECOND ? Player.FIRST : Player.SECOND;
    return GameWithOneMove(
        winner: currentPlayer, scores: newScores, currentPlayer: nextPlayer);
  }

  /// Returns `[Move.WIN]` until someone has won, then no moves.
  @override
  List<Move> getMoves() {
    if (winner == null) {
      return [Move.WIN];
    }
    return [];
  }

  /// The game has no hidden information, so this state is returned as-is.
  @override
  GameState<Move, Player> determine(GameState<Move, Player>? initialState) {
    return this;
  }

  /// Not needed by these tests.
  @override
  Map<String, dynamic> toJson() {
    throw UnimplementedError();
  }
}

/// Moves in [GameWithScore]; each adds the named number of points to the
/// mover's score.
enum ScoringMove {
  SCORE_5,
  SCORE_10,
  SCORE_100,
}

/// A race game where each turn adds 5, 10 or 100 points to the current
/// player's score and then passes the turn to the other player.
///
/// [Player.FIRST] wins once their score exceeds 100 and leads
/// [Player.SECOND]. Otherwise [Player.SECOND] wins once their score exceeds
/// 100. Used to check that MCTS favours the high-scoring move.
class GameWithScore implements GameState<ScoringMove, Player?> {
  Player? currentPlayer;
  Map<Player?, int> scores;
  Player? winner;
  int round;

  GameWithScore(
      {this.winner,
      required this.scores,
      this.round = 0,
      this.currentPlayer = Player.FIRST});

  /// Returns a new state with [move] applied; this state is not modified.
  ///
  /// [root] is unused.
  @override
  GameWithScore cloneAndApplyMove(
      ScoringMove move, Node<ScoringMove, Player?>? root) {
    // Work on a copy so the current state's scores stay untouched.
    final newScores = Map<Player?, int>.from(scores);

    // Adds [points] to the mover's score, starting from 0 if absent.
    void award(int points) {
      newScores.update(currentPlayer, (score) => score + points,
          ifAbsent: () => points);
    }

    switch (move) {
      case ScoringMove.SCORE_5:
        award(5);
        break;
      case ScoringMove.SCORE_10:
        award(10);
        break;
      case ScoringMove.SCORE_100:
        award(100);
        break;
    }

    // Turns alternate; anything other than FIRST hands the turn to FIRST.
    final newPlayer =
        currentPlayer == Player.FIRST ? Player.SECOND : Player.FIRST;

    // Missing entries count as 0 instead of failing on a null comparison.
    final firstScore = newScores[Player.FIRST] ?? 0;
    final secondScore = newScores[Player.SECOND] ?? 0;
    Player? newWinner;
    if (firstScore > 100 && firstScore > secondScore) {
      newWinner = Player.FIRST;
    } else if (secondScore > 100) {
      newWinner = Player.SECOND;
    }

    return GameWithScore(
        winner: newWinner,
        round: round + 1,
        scores: newScores,
        currentPlayer: newPlayer);
  }

  /// Returns all three scoring moves until someone has won, then none.
  @override
  List<ScoringMove> getMoves() {
    if (winner != null) {
      return [];
    }
    return [ScoringMove.SCORE_5, ScoringMove.SCORE_10, ScoringMove.SCORE_100];
  }

  /// The game has no hidden information, so this state is returned as-is.
  @override
  GameState<ScoringMove, Player?> determine(
      GameState<ScoringMove, Player?>? initialState) {
    return this;
  }

  /// Not needed by these tests.
  @override
  Map<String, dynamic> toJson() {
    throw UnimplementedError();
  }
}

/// A fake policy/value network for tests.
///
/// When the position has nine legal moves, all prior mass goes to the four
/// corner cells (0.25 each). Every other position gets no priors. The value
/// estimate is always 0.0 and no Q-values are supplied.
class TestNNPV implements NeuralNetworkPolicyAndValue<int?, TicTacToePlayer> {
  /// Board cells the fake network treats as good opening moves.
  static const _corners = {0, 2, 6, 8};

  @override
  NNPVResult<int?> getResult(GameState<int?, TicTacToePlayer?> game) {
    // Pretend the network likes corners as first moves. Nine legal moves is
    // used as a proxy for "empty board" (assumes getMoves lists every empty
    // cell; TODO confirm against TicTacToeGame).
    final probabilities = game.getMoves().length == 9
        ? <int?, double>{
            for (var cell = 0; cell < 9; cell++)
              cell: _corners.contains(cell) ? 0.25 : 0.0,
          }
        : <int?, double>{};

    return NNPVResult(probabilities: probabilities, qs: {}, value: 0.0);
  }
}

void main() {
  test('game with one move works', () {
    var game = GameWithOneMove(scores: {});
    expect(game.getMoves(), equals([Move.WIN]));
    var result = MCTS(gameState: game).getSimulationResult(iterations: 10);
    expect(result.move, equals(Move.WIN));
    expect(
        result.root!.children.values.first.getWinner(), equals(Player.FIRST));
  });
  test('selects winning tic tac toe move (scenario 1)', () {
    var o = TicTacToePlayer.O;
    var x = TicTacToePlayer.X;
    // An empty cell. Typed rather than `var e;` (which would be `dynamic`).
    TicTacToePlayer? e;
    var oneMoveFromWinning = TicTacToeGame(
        board: [o, o, e, x, e, x, e, x, e], currentPlayer: o, scores: {});
    MCTSResult<int?, TicTacToePlayer> result =
        MCTS(gameState: oneMoveFromWinning)
            .getSimulationResult(iterations: 100);
    expect(result.root!.children.length, equals(4));
    expect(result.move, equals(2));
    expect(result.maxDepth, equals(4));
  });
  test('selects winning tic tac toe move (scenario 2)', () {
    var o = TicTacToePlayer.O;
    var x = TicTacToePlayer.X;
    TicTacToePlayer? e;
    var oneMoveFromWinning = TicTacToeGame(
        board: [o, e, e, o, x, x, e, x, e], currentPlayer: o, scores: {});
    MCTSResult<int?, TicTacToePlayer> result =
        MCTS(gameState: oneMoveFromWinning)
            .getSimulationResult(iterations: 100);
    expect(result.root!.children.length, equals(4));
    expect(result.maxDepth, equals(4));
    expect(result.move, equals(6));
  });
  test('plays out a game from start to finish', () {
    // Named `i`, not `_`: from Dart 3.7 a local `_` is a non-binding wildcard
    // and could not be read in the loop condition.
    for (var i = 0; i < 100; i++) {
      TicTacToeGame gameState = TicTacToeGame.newGame() as TicTacToeGame;
      while (gameState.getMoves().isNotEmpty) {
        MCTSResult<int?, TicTacToePlayer> result =
            MCTS(gameState: gameState).getSimulationResult(iterations: 100);
        gameState = gameState.cloneAndApplyMove(result.move, result.root!);
      }
    }
  });
  test(
      'game with a score selects high scoring moves more frequently than low scoring moves',
      () {
    var gameState = GameWithScore(scores: {Player.FIRST: 0, Player.SECOND: 0});
    MCTSResult<ScoringMove, Player?> result =
        MCTS(gameState: gameState).getSimulationResult(iterations: 100);
    expect(result.root!.children.length, equals(3));
    expect(result.move, equals(ScoringMove.SCORE_100));
    expect(result.root!.children[ScoringMove.SCORE_100]?.visits ?? 0,
        greaterThan(result.root!.children[ScoringMove.SCORE_5]?.visits ?? 0));
    expect(result.root!.children[ScoringMove.SCORE_100]?.visits ?? 0,
        greaterThan(result.root!.children[ScoringMove.SCORE_10]?.visits ?? 0));
  });
  test('visits neural net prescribed nodes more frequently', () {
    var o = TicTacToePlayer.O;
    TicTacToePlayer? e;

    // Search guided by the fake network's corner priors.
    var ttgg = TicTacToeGame(
        board: [e, e, e, e, e, e, e, e, e], currentPlayer: o, scores: {});
    MCTSResult<int?, TicTacToePlayer> guided = MCTS(gameState: ttgg)
        .getSimulationResult(iterations: 100, nnpv: TestNNPV());
    expect(guided.root!.children.length, equals(9));
    expect(guided.maxDepth, equals(9));
    var guidedVisits = <int?, int>{
      for (var child in guided.root!.children.values) child.move: child.visits,
    };

    // Baseline: the same position searched without network priors. Kept in a
    // separate map so it does not overwrite the guided counts.
    ttgg = TicTacToeGame(
        board: [e, e, e, e, e, e, e, e, e], currentPlayer: o, scores: {});
    MCTSResult<int?, TicTacToePlayer> plain =
        MCTS(gameState: ttgg).getSimulationResult(iterations: 100);
    var plainVisits = <int?, int>{
      for (var child in plain.root!.children.values) child.move: child.visits,
    };

    // NOTE(review): nothing here asserts what the test name promises; the
    // two visit distributions are only printed. Consider asserting that the
    // corner cells (0, 2, 6, 8) get more visits in guidedVisits than in
    // plainVisits once the expected margin is known.
    print(guidedVisits);
    print(plainVisits);
  });
}