library dartmcts;

import 'dart:math';
import 'dart:developer' as d;

/// Exception type for reporting a move that cannot be played.
///
/// Nothing in this library throws it; presumably [GameState] implementations
/// raise it when asked to apply an illegal move — confirm against callers.
class InvalidMove implements Exception {
  /// Optional human-readable reason the move was rejected.
  final String? message;

  /// Creates the exception, optionally carrying a [message].
  const InvalidMove([this.message]);

  /// Returns `InvalidMove`, followed by the [message] when one was given,
  /// so logs show more than the default `Instance of 'InvalidMove'`.
  @override
  String toString() =>
      message == null ? 'InvalidMove' : 'InvalidMove: $message';
}

/// Supplies a numeric reward for each player.
///
/// In this library, [MCTS] checks whether its game state implements this
/// interface when reward-based back-propagation is enabled.
abstract class RewardProvider<PlayerType> {
  /// Returns a map from each player to that player's reward.
  Map<PlayerType, double> rewards();
}

/// The game-specific contract the search operates on.
///
/// [MoveType] identifies a move and [PlayerType] identifies a player.
abstract class GameState<MoveType, PlayerType> {
  /// Returns the state reached by playing [move].
  ///
  /// [Node.children] calls this on a node's initial state and passes the
  /// node's `root` field as [root]. The name suggests the receiver is copied
  /// rather than mutated — confirm in implementations.
  GameState<MoveType, PlayerType> cloneAndApplyMove(
      MoveType move, Node<MoveType, PlayerType>? root);

  /// Returns the moves available from this state.
  ///
  /// [Node.children] creates one child node per returned move.
  List<MoveType> getMoves();

  /// Returns the state the search should use for the next iteration.
  ///
  /// [MCTS.getSimulationResult] calls this (through [Node.determine]) on the
  /// root at the start of every iteration, passing the state the root node
  /// was created with as [initialState]; the result replaces the node's
  /// current state. Presumably used to sample hidden information — confirm
  /// against implementations.
  GameState<MoveType, PlayerType>? determine(
      GameState<MoveType, PlayerType>? initialState);

  /// The player who has won in this state, or `null` if there is none.
  ///
  /// The search stops descending once this is non-null.
  PlayerType? winner;

  /// The player associated with this state as its current player; the search
  /// scores a node's children from this player's perspective.
  PlayerType? currentPlayer;

  /// Returns a JSON-style map for this state (not used within this library).
  Map<String, dynamic> toJson();
}

/// The output of one [NeuralNetworkPolicyAndValue] evaluation.
class NNPVResult<MoveType> {
  /// Per-move prior score; [Node.getBestChild] uses it to order unvisited
  /// children and as the prior term when scoring visited ones.
  Map<MoveType, double> probabilities;

  /// Per-move Q values. Not read anywhere in this library.
  Map<MoveType, double> qs;

  /// Scalar evaluation of the state; [MCTS.getShortcutWinner] compares it
  /// with the configured value threshold.
  double value;

  /// Creates a result; all three parts are required.
  NNPVResult({
    required this.probabilities,
    required this.qs,
    required this.value,
  });
}

/// Evaluator that produces move priors and a value estimate for a game state.
///
/// When one is configured, [Node] queries it lazily (once per node) through
/// its `nnpvResult` getter.
abstract class NeuralNetworkPolicyAndValue<MoveType, PlayerType> {
  /// Returns the policy/value evaluation of [game].
  NNPVResult<MoveType> getResult(GameState<MoveType, PlayerType?> game);
}

/// Search settings shared by [MCTS] and the [Node]s it creates.
class Config<MoveType, PlayerType> {
  /// Exploration constant used by [Node.ucb1].
  double c;

  /// Optional policy/value evaluator; when null the plain UCB1 path is used.
  NeuralNetworkPolicyAndValue<MoveType, PlayerType>? nnpv;

  /// Value at or above which [MCTS.getShortcutWinner] declares the current
  /// player the winner.
  double? valueThreshold;

  /// Depth from which [MCTS.getShortcutWinner] starts applying shortcuts.
  int? useValueAfterDepth;

  /// Random source used for tie-breaking in [Node.getBestChild].
  Random random;

  /// Maps the current player to the player credited with the win when the
  /// evaluated value is below [valueThreshold].
  PlayerType Function(PlayerType)? opponentWinsShortCircuit;

  /// Whether reward-based back-propagation and shortcuts are enabled.
  bool useRewards;

  /// Creates a configuration.
  ///
  /// A missing [random] is replaced by a fresh [Random], and a missing [c]
  /// by an approximation of the square root of 2.
  Config({
    double? c,
    this.nnpv,
    this.valueThreshold,
    this.useValueAfterDepth,
    this.opponentWinsShortCircuit,
    Random? random,
    this.useRewards = false,
  })  : random = random ?? Random(),
        c = c ?? 1.41421356237; // square root of 2
}

/// A node of the search tree built by [MCTS].
///
/// A node records the [move] that leads to it from its [parent], the
/// statistics accumulated by back-propagation ([visits], [winsByPlayer],
/// [draws], [q]) and a [gameState] that is computed lazily by [children].
class Node<MoveType, PlayerType> {
  /// State at this node.
  ///
  /// A freshly created child starts with the state it was constructed with
  /// (its parent's state, see [addNewChildrenForDetermination]); [children]
  /// replaces it with the state reached after applying [move].
  GameState<MoveType?, PlayerType?>? gameState;

  /// Root of the tree this node belongs to. A node constructed without an
  /// explicit root is its own root.
  Node<MoveType, PlayerType>? root;

  /// Node this one was expanded from; null for the root.
  final Node<MoveType, PlayerType>? parent;

  /// Move that leads from [parent] to this node; null for the root.
  final MoveType? move;

  /// Number of back-propagations that passed through this node.
  int visits;

  /// Distance from the root, in moves.
  final int depth;

  /// Win count (or accumulated reward, when reward back-propagation is used)
  /// per player.
  final Map<PlayerType, double> winsByPlayer = {};

  /// Number of back-propagated playouts that had no winner.
  int draws;

  /// The state this node was constructed with; [children] applies [move] to
  /// it to obtain [gameState].
  final GameState? initialState;

  /// Set by [resetState]; makes [children] rebuild [gameState]. This class
  /// never clears it.
  bool needStateReset = false;

  Map<MoveType?, Node<MoveType, PlayerType>> _children = {};
  NNPVResult<MoveType>? _nnpvResult;

  /// Search settings used by this node.
  Config<MoveType, PlayerType> config;

  /// Running mean of the rewards back-propagated through this node.
  double q = 0;

  /// Creates a node.
  ///
  /// [root] is stored as given; when omitted the node becomes its own root.
  /// (Previously the argument was accepted but silently discarded, so every
  /// node was its own root.)
  Node({
    this.gameState,
    this.parent,
    this.move,
    this.visits = 0,
    this.depth = 0,
    this.draws = 0,
    this.root,
    required this.config,
    this.q = 0,
  }) : initialState = gameState {
    this.root ??= this;
  }

  /// Replaces [gameState] with the result of calling `determine` on it,
  /// passing [initialState]. Throws if [gameState] is null.
  determine() {
    gameState = gameState!
        .determine(initialState as GameState<MoveType?, PlayerType?>?);
  }

  /// Flags this node so that [children] rebuilds [gameState] from
  /// [initialState] and [move].
  resetState() {
    needStateReset = true;
  }

  /// Creates a child node for every entry of [moves] that does not have one
  /// yet. New children start with this node's current [gameState], share its
  /// [config] and [root], and sit one level deeper.
  addNewChildrenForDetermination(List<MoveType?> moves) {
    for (var move in moves) {
      if (_children.containsKey(move)) {
        continue;
      }
      _children[move] = Node(
        gameState: gameState,
        config: config,
        move: move,
        parent: this,
        root: root,
        depth: depth + 1,
      );
    }
  }

  /// The child nodes for the moves currently available in [gameState].
  ///
  /// Children that exist for moves that are not currently available are kept
  /// internally but left out of the returned map.
  Map<MoveType?, Node<MoveType, PlayerType?>> get children {
    // This GameState might not be selected during a simulation so we only
    // generate the state (and from it the children) when necessary.
    final mustRebuildState = _children.isEmpty || needStateReset;
    if (mustRebuildState && move != null) {
      gameState = initialState!.cloneAndApplyMove(move, root)
          as GameState<MoveType?, PlayerType?>?;
    }
    final moves = gameState!.getMoves();
    addNewChildrenForDetermination(moves);
    return Map.fromEntries(
        _children.entries.where((x) => moves.contains(x.value.move)));
  }

  /// Selection score of this node from [player]'s point of view.
  ///
  /// Returns 0 for the root and for unvisited nodes. When [priorScore] is
  /// exactly 1.0 the classic UCB1 formula is used, counting each draw as half
  /// a win; otherwise the prior-weighted formula shown in the body is used.
  double ucb1(PlayerType player, double priorScore) {
    final parentNode = parent;
    if (parentNode == null || visits == 0) {
      return 0;
    }
    if (priorScore == 1.0) {
      // `??` binds more loosely than `+`, so the fallback must be
      // parenthesised; otherwise draws are ignored whenever a win entry
      // exists for the player.
      final wins = winsByPlayer[player] ?? 0;
      return ((wins + (draws * 0.5)) / visits) +
          (config.c * sqrt(log(parentNode.visits.toDouble()) / visits));
    } else {
      // Q[s][a] + c_puct*P[s][a]*sqrt(sum(N[s]))/(1+N[s][a])
      return q +
          config.c *
              priorScore *
              sqrt(parentNode.visits.toDouble()) /
              (1.0 + visits.toDouble());
    }
  }

  /// The winner recorded in [gameState]. Throws if [gameState] is null.
  PlayerType? getWinner() {
    return gameState!.winner;
  }

  /// The current player recorded in [gameState]. Throws if [gameState] is
  /// null.
  PlayerType? currentPlayer() {
    return gameState!.currentPlayer;
  }

  /// The evaluator's result for this node's [gameState], computed on first
  /// access and cached. Throws if no evaluator is configured.
  NNPVResult<MoveType> get nnpvResult {
    if (_nnpvResult == null && config.nnpv != null) {
      _nnpvResult =
          config.nnpv!.getResult(gameState as GameState<MoveType, PlayerType?>);
    }
    return _nnpvResult!;
  }

  /// Picks the child to descend into.
  ///
  /// Children are ordered with unvisited ones first (by evaluator prior when
  /// one is configured, otherwise randomly) and visited ones by descending
  /// [ucb1]. Every child whose visit count equals that of the top-ranked
  /// child is then collected, the collection is shuffled, and its first
  /// element is returned.
  ///
  /// NOTE(review): the final shuffle ignores the score order among children
  /// with equal visit counts — confirm this is intended.
  ///
  /// Throws if there are no children.
  Node<MoveType, PlayerType?> getBestChild() {
    var player = currentPlayer();
    var sortedChildren = children.entries.toList();
    sortedChildren.sort((a, b) {
      var aVisits = a.value.visits;
      var bVisits = b.value.visits;
      if (config.nnpv != null && (aVisits == 0 && bVisits == 0)) {
        return (nnpvResult.probabilities[b.key] ?? 0)
            .compareTo(nnpvResult.probabilities[a.key] ?? 0);
      }
      if (aVisits == 0 && bVisits == 0) {
        return config.random.nextInt(100).compareTo(config.random.nextInt(100));
      }
      if (aVisits == 0) {
        return -1;
      }
      if (bVisits == 0) {
        return 1;
      }
      double bScore = b.value.ucb1(player,
          config.nnpv != null ? (nnpvResult.probabilities[b.key] ?? 1.0) : 1.0);
      double aScore = a.value.ucb1(player,
          config.nnpv != null ? (nnpvResult.probabilities[a.key] ?? 1.0) : 1.0);
      return bScore.compareTo(aScore);
    });
    List<MapEntry<MoveType?, Node<MoveType, PlayerType?>>> tiedChildren = [];
    for (var x in sortedChildren) {
      if (x.value.visits == sortedChildren.first.value.visits) {
        tiedChildren.add(x);
      }
    }
    tiedChildren.shuffle(config.random);
    return tiedChildren.first.value;
  }

  /// Propagates a playout result from this node up to the root.
  ///
  /// A null [winner] counts as a draw (reward 0.5). Otherwise the winner's
  /// entry in [winsByPlayer] is incremented and the reward is 1 when the
  /// winner is the parent's current player (the player who chose this node)
  /// and 0 otherwise. [q] and [visits] are updated on every node of the path.
  backProp(PlayerType? winner) {
    Node<MoveType, PlayerType?>? currentNode = this;

    while (currentNode != null) {
      double reward = 0.0;
      if (winner == null) {
        currentNode.draws += 1;
        reward = 0.5;
      } else {
        currentNode.winsByPlayer
            .update(winner, (value) => value + 1, ifAbsent: () => 1);
        reward = currentNode.parent?.currentPlayer() == winner ? 1 : 0;
      }
      // Q[s][a] = (N[s][a]*Q[s][a] + v)/(N[s][a]+1)
      currentNode.q = (currentNode.visits * currentNode.q + reward) /
          (currentNode.visits + 1.0);
      currentNode.visits += 1;
      currentNode = currentNode.parent;
    }
  }

  /// Propagates per-player [rewards] from this node up to the root.
  ///
  /// Every reward is added to the matching [winsByPlayer] entry of each node
  /// on the path. [q] is updated with the reward of the parent's current
  /// player, the same perspective [backProp] uses; a missing entry (including
  /// at the root, which has no parent) contributes 0 instead of throwing.
  rewardBackProp(Map<PlayerType, double> rewards) {
    Node<MoveType, PlayerType?>? currentNode = this;

    while (currentNode != null) {
      for (final entry in rewards.entries) {
        currentNode.winsByPlayer.update(
            entry.key, (value) => value + entry.value,
            ifAbsent: () => entry.value);
      }
      // Score the node for the player who chose it, not for whoever is to
      // move at the leaf where back-propagation started.
      final moverReward = rewards[currentNode.parent?.currentPlayer()] ?? 0.0;
      // Q[s][a] = (N[s][a]*Q[s][a] + v)/(N[s][a]+1)
      currentNode.q = (currentNode.visits * currentNode.q + moverReward) /
          (currentNode.visits + 1.0);
      currentNode.visits += 1;
      currentNode = currentNode.parent;
    }
  }

  /// Returns the child with the highest visit count.
  ///
  /// When [actualMoves] is given, child nodes are created for any of those
  /// moves that do not have one yet and only children for those moves are
  /// considered. Throws if there is no candidate child.
  Node<MoveType, PlayerType?> getMostVisitedChild(
      [List<MoveType>? actualMoves]) {
    var currentChildren = children;
    if (actualMoves != null) {
      addNewChildrenForDetermination(actualMoves);
      currentChildren = Map.fromEntries(
          _children.entries.where((x) => actualMoves.contains(x.value.move)));
    }
    var sortedChildren = currentChildren.entries.toList();
    sortedChildren.sort((b, a) => a.value.visits.compareTo(b.value.visits));
    return sortedChildren.first.value;
  }
}

/// Outcome of one [MCTS.getSimulationResult] call.
class MCTSResult<MoveType, PlayerType> {
  /// Root of the searched tree.
  final Node<MoveType, PlayerType>? root;

  /// Move of the root's most visited child.
  final MoveType? move;

  /// Leaf nodes of the search. [MCTS.getSimulationResult] does not set this.
  final List<Node<MoveType, PlayerType>>? leafNodes;

  /// Greatest node depth at which an iteration ended.
  final int? maxDepth;

  /// Number of iterations that were run.
  final int? plays;

  /// Creates a result; every field is optional.
  MCTSResult({
    this.root,
    this.move,
    this.leafNodes,
    this.maxDepth,
    this.plays,
  });
}

/// Monte Carlo tree search driver for a [GameState].
class MCTS<MoveType, PlayerType> {
  /// State the search starts from when no existing root node is supplied.
  GameState<MoveType, PlayerType>? gameState;

  /// Creates a searcher for [gameState].
  MCTS({this.gameState});

  /// Runs the search and returns the move of the root's most visited child.
  ///
  /// * [initialRootNode] — existing tree to continue from; when null a new
  ///   root is created from [gameState].
  /// * [iterations] — number of iterations to run when [maxSeconds] is null.
  /// * [maxSeconds] — time budget in (possibly fractional) seconds. When set,
  ///   [iterations] is ignored and the loop stops once the budget is spent
  ///   (or after a very large fixed number of iterations).
  /// * [actualMoves] — forwarded to [Node.getMostVisitedChild] to restrict
  ///   the final choice.
  /// * [resetDepth] — whether the depth counter passed to
  ///   [getShortcutWinner] restarts at 0 on every iteration.
  /// * The remaining parameters are copied into the [Config] that is
  ///   installed on the root node.
  ///
  /// Each iteration descends from the root via [Node.getBestChild] until a
  /// state has no children or has a winner, or [getShortcutWinner] names
  /// one, and then back-propagates the result from the node reached.
  MCTSResult<MoveType, PlayerType> getSimulationResult({
    Node<MoveType, PlayerType>? initialRootNode,
    int iterations = 100,
    double? maxSeconds,
    List<MoveType>? actualMoves,
    NeuralNetworkPolicyAndValue<MoveType, PlayerType>? nnpv,
    double? c,
    int? useValueAfterDepth,
    double? valueThreshold,
    Random? random,
    PlayerType Function(PlayerType)? opponentWinsShortCircuit,
    bool useRewards = false,
    bool resetDepth = true,
  }) {
    var rootNode = initialRootNode;
    Config<MoveType, PlayerType> config = Config(
        c: c,
        nnpv: nnpv,
        useValueAfterDepth: useValueAfterDepth,
        valueThreshold: valueThreshold,
        random: random,
        opponentWinsShortCircuit: opponentWinsShortCircuit,
        useRewards: useRewards);
    if (rootNode == null) {
      rootNode = Node(
        gameState: gameState,
        parent: null,
        move: null,
        config: config,
      );
    } else {
      rootNode.resetState();
    }
    rootNode.config = config;
    var plays = 0;
    var maxDepth = 0;
    var currentDepth = 0;

    // Measure the budget with a stopwatch at microsecond resolution.
    // Comparing whole seconds (`inSeconds > maxSeconds.toInt()`) truncated
    // both sides and could overrun the budget by up to about a second.
    final stopwatch = Stopwatch()..start();
    final timeBudgetMicros = maxSeconds == null
        ? null
        : maxSeconds * Duration.microsecondsPerSecond;

    // With a time budget the iteration count is only a safety cap.
    final iterationsToRun = maxSeconds != null ? 9223372 : iterations;

    while (plays < iterationsToRun) {
      rootNode.determine();
      if (resetDepth) currentDepth = 0;
      if (timeBudgetMicros != null &&
          stopwatch.elapsedMicroseconds >= timeBudgetMicros) {
        break;
      }
      plays += 1;
      Node<MoveType, PlayerType?> currentNode = rootNode;

      PlayerType? winner;

      while (currentNode.children.isNotEmpty &&
          currentNode.gameState?.winner == null) {
        currentNode = currentNode.getBestChild();
        currentNode.resetState();
        if (currentNode.gameState?.winner != null) {
          winner = currentNode.gameState?.winner;
          break;
        }
        winner = getShortcutWinner(currentDepth, config, currentNode);
        if (winner != null) {
          break;
        }
        currentDepth += 1;
      }

      // The loop condition is what computes a newly selected node's state
      // (via `children`). When it ends the descent because that state already
      // has a winner, nothing above has read it yet; without this the playout
      // would be back-propagated as a draw.
      winner ??= currentNode.gameState?.winner;

      // NOTE(review): the rewards below come from this searcher's own
      // `gameState`, not from `currentNode.gameState` — confirm that is
      // intended.
      if (gameState is RewardProvider && config.useRewards) {
        currentNode.rewardBackProp(
            (gameState as RewardProvider<PlayerType>).rewards());
      } else {
        currentNode.backProp(winner);
      }
      maxDepth = max(maxDepth, currentNode.depth);
    }

    var selectedMove = rootNode.getMostVisitedChild(actualMoves).move;

    return MCTSResult(
        root: rootNode, move: selectedMove, maxDepth: maxDepth, plays: plays);
  }

  /// Returns a winner to assume without descending further, or null.
  ///
  /// Nothing is short-circuited unless `config.useValueAfterDepth` is set and
  /// [currentDepth] has reached it. Beyond that depth:
  ///
  /// * without an evaluator, when rewards are enabled and [gameState]
  ///   implements [RewardProvider], the first player holding the highest
  ///   reward is returned;
  /// * with an evaluator and a value threshold, the node's current player is
  ///   returned when the evaluated value reaches the threshold; otherwise the
  ///   result of `config.opponentWinsShortCircuit`, if one is configured.
  PlayerType? getShortcutWinner(int currentDepth, Config config,
      Node<MoveType, PlayerType?> currentNode) {
    // Both shortcuts are depth-gated. The rewards branch used to force-unwrap
    // the limit and crashed when it was left at its null default.
    final depthLimit = config.useValueAfterDepth;
    if (depthLimit == null || currentDepth < depthLimit) {
      return null;
    }

    if (config.nnpv == null) {
      if (config.useRewards && gameState is RewardProvider) {
        // NOTE(review): rewards are read from this searcher's `gameState`,
        // not from `currentNode` — confirm that is intended.
        final rewards = (gameState as RewardProvider).rewards();
        if (rewards.isEmpty) {
          return null;
        }
        final highestReward = rewards.values.reduce(max);
        for (final entry in rewards.entries) {
          if (entry.value == highestReward) {
            return entry.key;
          }
        }
      }
      return null;
    }

    final threshold = config.valueThreshold;
    if (threshold == null) {
      return null;
    }
    d.log('currentDepth: $currentDepth');
    final currentValue = currentNode.nnpvResult.value;
    final player = currentNode.gameState!.currentPlayer;
    if (currentValue >= threshold) {
      return player;
    }
    // Read the callback through the precisely typed config. Through the raw
    // `Config` type its static type is `dynamic Function(dynamic)`, which a
    // `PlayerType Function(PlayerType)` value does not satisfy at runtime.
    final shortCircuit = config is Config<MoveType, PlayerType>
        ? config.opponentWinsShortCircuit
        : null;
    // The callback takes a non-null player; skip it when there is none.
    if (shortCircuit != null && player is PlayerType) {
      return shortCircuit(player);
    }
    return null;
  }
}