use std::{
    cell::RefCell,
    cmp::Ordering,
    collections::HashSet,
    fmt::{self, Display},
    hash::Hash,
    ops::{DerefMut, SubAssign},
};

use itertools::Itertools;
use owo_colors::OwoColorize;
use rand::{seq::SliceRandom, Rng};
use rand_pcg::Pcg64Mcg;

/// Seed of the per-thread RNG. It is a fixed constant, so every thread starts from the same
/// state and replays the same sequence of random values on every run.
const RNG_SEED: u128 = 0xcafef00dd15ea5e5;

thread_local! {
    /// Per-thread random number generator, used to shuffle decks and to pick random cards.
    static RNG: RefCell<Pcg64Mcg> = RefCell::new(Pcg64Mcg::new(RNG_SEED));
}

/// One card of a Set game.
///
/// This struct represents one card of a Set game.
/// To model set games with a different number of
/// attributes than 4, the const generic DIM can be changed.
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash, Ord)]
pub struct Card<const DIM: usize = 4>([u8; DIM]);

impl<const DIM: usize> Card<DIM> {
    /// Constructs a new card.
    ///
    /// # Panics
    ///
    /// Panics if `DIM` is 0 or if any attribute value is greater than 2.
    pub fn new(value: [u8; DIM]) -> Self {
        if DIM == 0 {
            panic!("the game needs one dimension");
        }
        if value.iter().any(|v| *v > 2) {
            panic!("card values have to be 0, 1 or 2");
        }
        Self(value)
    }

    /// Returns an iterator over all possible cards (every combination of attribute values
    /// `0..3`), each yielded once.
    pub fn all() -> impl Iterator<Item = Self> {
        (0..DIM)
            .map(|_| 0..3)
            .multi_cartesian_product()
            // Every product has exactly DIM entries (one per range), so the conversion
            // into a fixed-size array cannot fail.
            .map(TryInto::<[u8; DIM]>::try_into)
            .map(Result::unwrap)
            .map(Self::new)
    }

    /// Given two [Cards](Card), returns the third [Card] needed to form a [Set].
    ///
    /// Per attribute: if `a` and `b` agree, the third card shares that value; otherwise it
    /// takes the remaining value of 0, 1, 2 (the three distinct values sum to 3).
    pub fn missing_for_set(a: Self, b: Self) -> Self {
        let mut result = [0; DIM];
        for (slot, (&x, &y)) in result.iter_mut().zip(a.0.iter().zip(b.0.iter())) {
            *slot = if x == y { x } else { 3 - (x + y) };
        }
        Self(result)
    }

    /// Given a [CardSelection], finds all [Sets](Set) that can be built from `self` and two
    /// *other* cards from the selection. Each set is yielded once.
    ///
    /// If `self` is itself contained in the selection, it is not used as a partner.
    pub fn sets_with(self, cards: &CardSelection<DIM>) -> impl Iterator<Item = Set<DIM>> + '_ {
        cards
            .iter()
            .copied()
            // Pairing `self` with itself would "complete" the degenerate triple
            // [self, self, self], which is not a set.
            .filter(move |&card| card != self)
            .map(move |card| [self, card, Self::missing_for_set(self, card)])
            .filter(move |maybe_set| cards.contains(&maybe_set[2]))
            // Each set is reached twice (once per partner card), hence the dedup.
            .map(Set::new)
            .unique()
    }

    /// Checks whether a set can be constructed from `self` and two other cards from the
    /// provided [CardSelection].
    pub fn has_sets_with(self, cards: &CardSelection<DIM>) -> bool {
        self.sets_with(cards).next().is_some()
    }
}

impl<const DIM: usize> PartialOrd for Card<DIM> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        for i in (0..DIM).rev() {
            let res = self.0[i].cmp(&other.0[i]);
            if res != Ordering::Equal {
                return Some(res);
            }
        }
        Some(Ordering::Equal)
    }
}

impl<const DIM: usize> Display for Card<DIM> {
    /// Renders the card as one to three glyphs.
    ///
    /// Attribute 0 selects how many glyphs are shown, attribute 1 the shape and attribute 2
    /// the fill (missing attributes count as 0). With four or more attributes, attribute 3
    /// selects the color (red, green, blue); attributes beyond index 3 are not shown.
    ///
    /// # Panics
    ///
    /// Panics if `DIM` is 0 or if an attribute holds a value greater than 2.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Indexed as SYMBOLS[amount][fill][shape]; amount value 0 -> one glyph,
        // 1 -> two glyphs, 2 -> three glyphs.
        const SYMBOLS: [[[&str; 3]; 3]; 3] = [
            [["○", "◇", "□"], ["◍", "◈", "▥"], ["●", "◆", "■"]],
            [["○○", "◇◇", "□□"], ["◍◍", "◈◈", "▥▥"], ["●●", "◆◆", "■■"]],
            [
                ["○○○", "◇◇◇", "□□□"],
                ["◍◍◍", "◈◈◈", "▥▥▥"],
                ["●●●", "◆◆◆", "■■■"],
            ],
        ];

        if DIM == 0 {
            panic!("the game needs one dimension")
        }
        let amount = self.0[0] as usize;
        let shape = if DIM >= 2 { self.0[1] } else { 0 } as usize;
        let fill = if DIM >= 3 { self.0[2] } else { 0 } as usize;
        let symbol = SYMBOLS[amount][fill][shape];
        if DIM < 4 {
            return f.write_str(symbol);
        }
        // Write the colored glyphs straight into the formatter instead of allocating an
        // intermediate String.
        match self.0[3] {
            0 => write!(f, "{}", symbol.red()),
            1 => write!(f, "{}", symbol.green()),
            2 => write!(f, "{}", symbol.blue()),
            _ => panic!("card values have to be 0, 1 or 2"),
        }
    }
}

/// A selection of [Cards](Card).
///
/// Stores any set of distinct [Cards](Card); it can also be empty (as returned by
/// [CardSelection::new] or [Default::default]).
/// Not named `Set` to avoid confusion with [Sets](Set).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CardSelection<const DIM: usize>(HashSet<Card<DIM>>);

impl<const DIM: usize> CardSelection<DIM> {
    /// Returns an empty [CardSelection].
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Returns `true` iff no new card could be inserted into this [CardSelection] without
    /// forming a new [Set].
    ///
    /// This does not check whether there already are [Sets](Set) within this
    /// [CardSelection]; use [CardSelection::has_sets] to check that if needed.
    pub fn no_sets_is_maximal(&self) -> bool {
        self.no_sets_possible_extensions().is_empty()
    }

    /// Returns all cards not yet in this selection that could be added to it without
    /// forming a [Set] with two of its cards.
    pub fn no_sets_possible_extensions(&self) -> CardSelection<DIM> {
        CardSelection(
            Card::all()
                .filter(|card| !self.contains(card))
                .filter(|card| !card.has_sets_with(self))
                .collect::<HashSet<Card<DIM>>>(),
        )
    }

    /// Returns a random card from the selection, or `None` if it is empty.
    ///
    /// The index is drawn from the thread-local RNG. Which card sits at that index depends
    /// on the `HashSet` iteration order, which with the default `RandomState` hasher may
    /// differ between runs.
    pub fn pick_random(&self) -> Option<Card<DIM>> {
        if self.is_empty() {
            return None;
        }
        let index = RNG.with(|rng| rng.borrow_mut().gen_range(0..self.len()));
        self.iter().nth(index).copied()
    }

    /// Returns an iterator over the contained cards, sorted by [Card]'s ordering.
    pub fn sorted(&self) -> impl Iterator<Item = Card<DIM>> {
        let mut vec = self.0.iter().copied().collect_vec();
        vec.sort_unstable();
        vec.into_iter()
    }

    /// Finds all sets in this selection; each set is yielded once.
    pub fn sets(&self) -> impl Iterator<Item = Set<DIM>> + '_ {
        self.iter()
            .copied()
            .combinations(2)
            .map(|pair| [pair[0], pair[1], Card::missing_for_set(pair[0], pair[1])])
            .filter(move |maybe_set| self.contains(&maybe_set[2]))
            // Every set is found once per pair of its cards, hence the dedup.
            .map(Set::new)
            .unique()
    }

    /// Checks whether this selection contains any set.
    pub fn has_sets(&self) -> bool {
        self.sets().next().is_some()
    }

    /// Checks whether this selection contains the given card.
    pub fn contains(&self, value: &Card<DIM>) -> bool {
        self.0.contains(value)
    }

    /// Returns the number of cards in this selection.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Removes a card from this selection.
    ///
    /// Returns `true` if the card was contained.
    pub fn remove(&mut self, value: &Card<DIM>) -> bool {
        self.0.remove(value)
    }

    /// Inserts a card into this selection.
    ///
    /// Returns `true` if the card was newly inserted, `false` if it was already contained.
    pub fn insert(&mut self, value: Card<DIM>) -> bool {
        self.0.insert(value)
    }

    /// Returns an iterator over the contained cards in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = &Card<DIM>> {
        self.0.iter()
    }

    /// Checks whether the selection is empty.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

impl<const DIM: usize> Extend<Card<DIM>> for CardSelection<DIM> {
    /// Inserts every card yielded by `iter`; cards already present are left as they are.
    fn extend<T: IntoIterator<Item = Card<DIM>>>(&mut self, iter: T) {
        iter.into_iter().for_each(|card| {
            self.insert(card);
        });
    }
}

impl<'a, const DIM: usize> Extend<&'a Card<DIM>> for CardSelection<DIM> {
    /// Inserts a copy of every referenced card; cards already present are left as they are.
    fn extend<T: IntoIterator<Item = &'a Card<DIM>>>(&mut self, iter: T) {
        for &card in iter {
            self.insert(card);
        }
    }
}

impl<const DIM: usize> FromIterator<Card<DIM>> for CardSelection<DIM> {
    fn from_iter<T: IntoIterator<Item = Card<DIM>>>(iter: T) -> Self {
        Self(iter.into_iter().collect())
    }
}

impl<const DIM: usize> SubAssign<Set<DIM>> for CardSelection<DIM> {
    fn sub_assign(&mut self, rhs: Set<DIM>) {
        self.remove(&rhs.0[0]);
        self.remove(&rhs.0[1]);
        self.remove(&rhs.0[2]);
    }
}

impl<const DIM: usize> Display for CardSelection<DIM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            &self
                .0
                .iter()
                .map(Card::to_string)
                .reduce(|a, b| a + " " + &b)
                .unwrap(),
        )
    }
}

/// A shuffled pile of [Cards](Card) to draw from.
///
/// A new deck holds every card produced by [Card::all] in random order. Cards are drawn
/// from the end of the underlying vector.
#[derive(Debug, Clone)]
pub struct Deck<const DIM: usize = 4>(Vec<Card<DIM>>);

impl<const DIM: usize> Deck<DIM> {
    /// Creates a deck containing all cards, shuffled with the thread-local RNG.
    pub fn new() -> Self {
        let mut deck = Self(Card::all().collect());
        RNG.with(|rng| deck.0.shuffle(rng.borrow_mut().deref_mut()));
        deck
    }

    /// Removes up to `amount` cards from the end of the deck and returns them.
    ///
    /// If fewer than `amount` cards remain, all remaining cards are returned. The drawn
    /// cards leave the deck even if the returned iterator is dropped before being
    /// exhausted.
    ///
    /// # Panics
    ///
    /// Panics if `amount` is 0.
    pub fn draw(&mut self, amount: usize) -> impl Iterator<Item = Card<DIM>> + '_ {
        assert!(amount != 0, "cannot draw 0 cards");
        // Clamp at 0 so that asking for more cards than remain drains the whole deck.
        let start = self.0.len().saturating_sub(amount);
        self.0.drain(start..)
    }

    /// Puts the given cards back into the deck and shuffles the whole deck.
    pub fn reshuffle(&mut self, cards: &CardSelection<DIM>) {
        self.0.extend(cards.iter().copied());
        RNG.with(|rng| self.0.shuffle(rng.borrow_mut().deref_mut()));
    }

    /// Checks whether any [Set] can be formed from the cards remaining in the deck.
    pub fn has_sets(&self) -> bool {
        self.0
            .iter()
            .copied()
            .collect::<CardSelection<DIM>>()
            .has_sets()
    }

    /// Returns the number of cards left in the deck.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if no cards are left in the deck.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

impl<const DIM: usize> Default for Deck<DIM> {
    /// Same as [Deck::new]: a full, shuffled deck.
    fn default() -> Self {
        Self::new()
    }
}

impl<const DIM: usize> AsRef<[Card<DIM>]> for Deck<DIM> {
    /// Borrows the cards still in the deck. The last element is the next one drawn.
    fn as_ref(&self) -> &[Card<DIM>] {
        self.0.as_slice()
    }
}

/// A set according to the game's rules: three cards that, for every attribute, either all
/// share the value or all differ (as produced via [Card::missing_for_set]). This invariant
/// is not checked here.
///
/// `Set::new` sorts the cards, so sets built through it from the same three cards compare
/// and hash equal regardless of the order in which the cards were supplied.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct Set<const DIM: usize>([Card<DIM>; 3]);

impl<const DIM: usize> Set<DIM> {
    /// Builds a set from three cards, putting them into canonical (sorted) order.
    fn new(cards: [Card<DIM>; 3]) -> Self {
        let mut sorted = cards;
        sorted.sort_unstable();
        Set(sorted)
    }

    /// Counts the attributes whose value is shared by the cards of this set.
    ///
    /// Only the first two cards are compared: in a valid set, two cards agreeing on an
    /// attribute forces the third to agree as well.
    pub fn same_count(&self) -> u8 {
        let [first, second, _] = &self.0;
        first
            .0
            .iter()
            .zip(&second.0)
            .filter(|(a, b)| a == b)
            .map(|_| 1u8)
            .sum()
    }
}

impl<const DIM: usize> Display for Set<DIM> {
    /// Writes the set as `{ a b c }`, using each card's own [Display] output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [first, second, third] = &self.0;
        write!(f, "{{ {} {} {} }}", first, second, third)
    }
}

// Unit tests for cards, card selections and sets. Fixtures construct `Card`/`Set` through
// their tuple constructors directly, bypassing `Card::new` validation and `Set::new`
// sorting.
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use itertools::Itertools;

    use crate::{Card, CardSelection, Set};

    // Attributes 0 and 1 are equal on all three cards; attributes 2 and 3 all differ.
    #[test]
    fn same_count() {
        let set = Set::new([Card([0, 0, 0, 0]), Card([0, 0, 1, 2]), Card([0, 0, 2, 1])]);
        assert_eq!(set.same_count(), 2);
    }

    // `sets` contains two duplicated entries; `unique` must drop them while keeping the
    // first occurrence of each set in its original position (this exercises `Set`'s
    // `Hash`/`Eq` implementations).
    #[test]
    fn sets() {
        let sets = [
            Set([Card([0, 1, 0, 2]), Card([1, 1, 0, 2]), Card([2, 1, 0, 2])]),
            Set([Card([0, 1, 0, 0]), Card([1, 1, 2, 1]), Card([2, 1, 1, 2])]),
            Set([Card([0, 1, 0, 0]), Card([1, 1, 2, 1]), Card([2, 1, 1, 2])]),
            Set([Card([0, 1, 2, 0]), Card([2, 0, 0, 1]), Card([1, 2, 1, 2])]),
            Set([Card([1, 2, 0, 0]), Card([0, 2, 0, 1]), Card([2, 2, 0, 2])]),
            Set([Card([2, 0, 2, 0]), Card([0, 2, 0, 1]), Card([1, 1, 1, 2])]),
            Set([Card([2, 0, 0, 0]), Card([0, 0, 1, 1]), Card([1, 0, 2, 2])]),
            Set([Card([2, 0, 0, 0]), Card([0, 0, 1, 1]), Card([1, 0, 2, 2])]),
        ];

        let unique = vec![
            Set([Card([0, 1, 0, 2]), Card([1, 1, 0, 2]), Card([2, 1, 0, 2])]),
            Set([Card([0, 1, 0, 0]), Card([1, 1, 2, 1]), Card([2, 1, 1, 2])]),
            Set([Card([0, 1, 2, 0]), Card([2, 0, 0, 1]), Card([1, 2, 1, 2])]),
            Set([Card([1, 2, 0, 0]), Card([0, 2, 0, 1]), Card([2, 2, 0, 2])]),
            Set([Card([2, 0, 2, 0]), Card([0, 2, 0, 1]), Card([1, 1, 1, 2])]),
            Set([Card([2, 0, 0, 0]), Card([0, 0, 1, 1]), Card([1, 0, 2, 2])]),
        ];

        assert_eq!(unique, sets.into_iter().unique().collect_vec());
    }

    // Covers three shapes of input pairs: all attributes different, only the last
    // attribute different, and only the last attribute equal.
    #[test]
    fn test_missing_for_set() {
        let a = Card([0, 0, 0]);
        let b = Card([2, 1, 2]);
        assert_eq!(Card::missing_for_set(a, b), Card([1, 2, 1]));

        let a = Card([0, 0, 0]);
        let b = Card([0, 0, 1]);
        assert_eq!(Card::missing_for_set(a, b), Card([0, 0, 2]));

        let a = Card([2, 1, 1]);
        let b = Card([0, 2, 1]);
        assert_eq!(Card::missing_for_set(a, b), Card([1, 0, 1]));
    }

    // Among these six cards, exactly one set exists: {000, 012, 021}.
    #[test]
    fn find_sets() {
        let selection = CardSelection(HashSet::from([
            Card([0, 0, 0]),
            Card([0, 1, 2]),
            Card([1, 2, 0]),
            Card([1, 1, 1]),
            Card([1, 0, 1]),
            Card([0, 2, 1]),
        ]));

        let sets = vec![Set::new([
            Card([0, 0, 0]),
            Card([0, 1, 2]),
            Card([0, 2, 1]),
        ])];

        assert_eq!(sets, selection.sets().collect::<Vec<Set<3>>>());
    }

    // Same cards as `find_sets`, but with 000 taken out and used as the probe card; the
    // only set it completes is reported exactly once.
    #[test]
    fn find_sets_with() {
        let card = Card([0, 0, 0]);
        let selection = CardSelection(HashSet::from([
            Card([0, 1, 2]),
            Card([1, 2, 0]),
            Card([1, 1, 1]),
            Card([1, 0, 1]),
            Card([0, 2, 1]),
        ]));

        let sets = vec![Set::new([
            Card([0, 0, 0]),
            Card([0, 1, 2]),
            Card([0, 2, 1]),
        ])];

        assert_eq!(sets, card.sets_with(&selection).collect::<Vec<Set<3>>>());
    }

    #[test]
    fn has_sets() {
        let selection = CardSelection(HashSet::from([
            Card([0, 0, 0]),
            Card([0, 1, 2]),
            Card([1, 2, 0]),
            Card([1, 1, 1]),
            Card([1, 0, 1]),
            Card([0, 2, 1]),
        ]));
        assert!(selection.has_sets());
    }

    // Positive case (DIM = 3) followed by a negative case (DIM = 4) where no two cards of
    // the selection complete a set with the probe card.
    #[test]
    fn has_sets_with() {
        let card = Card([0, 0, 0]);
        let selection = CardSelection(HashSet::from([
            Card([0, 1, 2]),
            Card([1, 2, 0]),
            Card([1, 1, 1]),
            Card([1, 0, 1]),
            Card([0, 2, 1]),
        ]));
        assert!(card.has_sets_with(&selection));

        let card = Card([0, 1, 2, 1]);
        let selection = CardSelection(HashSet::from([
            Card([0, 0, 1, 2]),
            Card([1, 1, 2, 2]),
            Card([0, 0, 1, 1]),
            Card([2, 2, 2, 2]),
            Card([1, 2, 2, 2]),
            Card([0, 0, 0, 0]),
            Card([0, 0, 0, 1]),
            Card([0, 1, 1, 2]),
            Card([0, 1, 1, 1]),
            Card([0, 1, 2, 2]),
        ]));

        assert!(!card.has_sets_with(&selection));
    }

    // At least one card can still be added to this selection without forming a set, so
    // `no_sets_is_maximal` must report `false`.
    #[test]
    fn is_maximal() {
        let selection = CardSelection(HashSet::from([
            Card([0, 0, 1, 2]),
            Card([1, 1, 2, 2]),
            Card([0, 0, 1, 1]),
            Card([2, 2, 2, 2]),
            Card([1, 2, 2, 2]),
            Card([0, 0, 0, 0]),
            Card([0, 0, 0, 2]),
            Card([0, 1, 1, 2]),
            Card([0, 1, 1, 1]),
            Card([0, 1, 2, 2]),
        ]));
        assert!(!selection.no_sets_is_maximal());
    }

    // For DIM = 2, `Card::all` must yield exactly the 3^2 = 9 possible cards.
    #[test]
    fn all() {
        let all_cards = CardSelection(HashSet::from([
            Card([0, 0]),
            Card([0, 1]),
            Card([0, 2]),
            Card([1, 0]),
            Card([1, 1]),
            Card([1, 2]),
            Card([2, 0]),
            Card([2, 1]),
            Card([2, 2]),
        ]));

        assert_eq!(all_cards, Card::<2>::all().collect());
    }
}