use std::ops::Index;

use crate::element::{
  Element::{*, self}, ElementSet, LEGAL_ABSTRACT_DICE_ELEMENTS, LEGAL_ACTUAL_DICE_ELEMENTS, PURE_ELEMENTS,
};
use enum_map::EnumMap;
use rand::{seq::IteratorRandom, thread_rng};

/// Common interface for dice collections, stored as a signed count per [`Element`].
///
/// `Rhs` is the type of the right-hand operand of the arithmetic methods and
/// defaults to the implementing type itself.
pub trait Dices<Rhs = Self> {
  /// Type returned by the non-mutating arithmetic and filtering methods.
  type Output;

  /// Returns a collection holding no dice.
  fn empty() -> Self;

  /// Builds a collection from `(element, count)` pairs.
  fn new_from_counter_iter<T: Iterator<Item = (Element, i8)>>(dices: T) -> Self;

  /// Builds a collection holding one die for every yielded element.
  fn new_from_element_iter<T: Iterator<Item = Element>>(dices: T) -> Self;

  /// Adds the counts of `other` into `self`.
  fn add_mut(&mut self, other: &Rhs);

  /// Returns `self + other` without modifying either operand.
  fn add(&self, other: &Rhs) -> Self::Output;

  /// Subtracts the counts of `other` from `self`; counts may become negative.
  fn sub_mut(&mut self, other: &Rhs);

  /// Returns `self - other` without modifying either operand.
  fn sub(&self, other: &Rhs) -> Self::Output;

  /// Adds one die to `self` for every yielded element.
  fn add_elements_mut<T: Iterator<Item = Element>>(&mut self, other: T);

  /// Returns a copy of `self` with one extra die per yielded element.
  fn add_elements<T: Iterator<Item = Element>>(&self, other: T) -> Self::Output;

  /// Removes one die from `self` for every yielded element.
  fn sub_elements_mut<T: Iterator<Item = Element>>(&mut self, other: T);

  /// Returns a copy of `self` with one die fewer per yielded element.
  fn sub_elements<T: Iterator<Item = Element>>(&self, other: T) -> Self::Output;

  /// Total number of dice in the collection.
  fn num_dices(&self) -> usize;

  /// Whether `num_dices` is even.
  fn is_even(&self) -> bool;

  /// Whether `num_dices` is odd.
  fn is_odd(&self) -> bool;

  /// Whether `num_dices` is zero.
  fn is_empty(&self) -> bool;

  /// Whether the counts are valid for this collection type (see `legal_elems`).
  fn is_legal(&self) -> bool;

  /// The set of elements that have a positive count.
  fn elem_set(&self) -> ElementSet;

  /// One item per die, i.e. each element repeated as many times as its count.
  fn elems(&self) -> impl Iterator<Item = Element>;

  /// The count of every element, zero counts included.
  fn values(&self) -> impl Iterator<Item = i8>;

  /// Every `(element, count)` pair, zero counts included.
  fn elem_values(&self) -> impl Iterator<Item = (Element, i8)>;

  /// Elements that may carry a positive count in a legal collection.
  fn legal_elems() -> ElementSet;

  /// Returns a copy that keeps only the counts of elements in `eset`.
  fn extract_with(&self, eset: ElementSet) -> Self::Output;

  /// Chooses up to `x` dice at random and returns `(rest, picked)`;
  /// `self` itself is left unchanged.
  fn pick_random(&self, x: usize) -> (Self, Self)
  where
    Self: Sized;

  /// Whether at least one die of `elem` is present.
  fn contains(&self, elem: Element) -> bool;
}

/// Shared legality rule for dice collections.
///
/// A zero count is always accepted. A positive count is accepted only for
/// elements in `legal_elements`. A negative count is never accepted.
fn check_is_legal(dices: &impl Dices, legal_elements: ElementSet) -> bool {
  dices.elem_values().all(|(elem, count)| match count {
    0 => true,
    positive if positive > 0 => legal_elements.contains(elem),
    _ => false,
  })
}

impl ActualDices {
  fn satisfy(&self, requirement: &AbstractDices) -> bool {
    assert!(self.is_legal() && requirement.is_legal());
    let upgraded: AbstractDices = self.clone().into();
    let pure_deducted = upgraded.sub(requirement).extract_with(PURE_ELEMENTS);
    let omni_required = pure_deducted.values().filter(|x| *x < 0).sum::<i8>();

    if self[OMNI] < omni_required {
      return false;
    }

    let omni_remained = self[OMNI] - omni_required;
    let most_pure = pure_deducted.values().max().unwrap();
    if omni_remained + most_pure < requirement[OMNI] {
      return false;
    }

    true
  }

  fn loosely_satify(&self, requirement: &AbstractDices) -> bool {
    self.num_dices() >= requirement.num_dices() && self.satisfy(requirement)
  }

  fn just_satisfy(&self, requirement: &AbstractDices) -> bool {
    self.num_dices() == requirement.num_dices() && self.satisfy(requirement)
  }
}

impl ActualDices {
  pub const fn new_from_raw(x: EnumMap<Element, i8>) -> Self {
    ActualDices(x)
  }
}

impl AbstractDices {
  pub const fn new_from_raw(x: EnumMap<Element, i8>) -> Self {
    AbstractDices(x)
  }
}

impl Dices for EnumMap<Element, i8> {
  type Output = Self;
  fn empty() -> Self {
    EnumMap::default()
  }

  fn new_from_counter_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = (Element, i8)>,
  {
    let mut dices = dices;
    let mut map = EnumMap::default();
    while let Some((elem, count)) = dices.next() {
      map[elem] += count;
    }
    map
  }

  fn new_from_element_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = Element>,
  {
    let mut dices = dices;
    let mut map = EnumMap::default();
    while let Some(elem) = dices.next() {
      map[elem] += 1;
    }
    map
  }

  fn add_mut(&mut self, other: &Self) {
    for (elem, count) in other {
      self[elem] += count;
    }
  }
  fn add(&self, other: &Self) -> Self::Output {
    let mut new = self.clone();
    new.add_mut(other);
    new
  }
  fn sub_mut(&mut self, other: &Self) {
    for (elem, count) in other {
      self[elem] -= count;
    }
  }
  fn sub(&self, other: &Self) -> Self::Output {
    let mut new = self.clone();
    new.sub_mut(other);
    new
  }
  fn add_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    for elem in other {
      self[elem] += 1;
    }
  }
  fn add_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    let mut new = self.clone();
    new.add_elements_mut(other);
    new
  }

  fn sub_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    for elem in other {
      self[elem] -= 1;
    }
  }

  fn sub_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    let mut new = self.clone();
    new.sub_elements_mut(other);
    new
  }

  fn num_dices(&self) -> usize {
    assert!(self.is_legal());
    self.values().sum::<i8>() as usize
  }

  fn is_even(&self) -> bool {
    self.num_dices() % 2 == 0
  }

  fn is_odd(&self) -> bool {
    !self.is_even()
  }

  fn is_empty(&self) -> bool {
    self.num_dices() == 0
  }

  fn is_legal(&self) -> bool {
    unreachable!("EnumMap is always legal")
  }

  fn elem_set(&self) -> ElementSet {
    let mut es = ElementSet::empty();
    for (elem, count) in self {
      if *count > 0 {
        es.insert(elem);
      }
    }
    es
  }

  fn elems(&self) -> impl Iterator<Item = Element> {
    let mut es = Vec::new();
    for (elem, count) in self {
      let mut n = *count as usize;
      while n > 0 {
        es.push(elem);
        n -= 1;
      }
    }
    es.into_iter()
  }

  fn values(&self) -> impl Iterator<Item = i8> {
    let mut vs = Vec::new();
    for (_, count) in self {
      vs.push(*count);
    }
    vs.into_iter()
  }

  fn elem_values(&self) -> impl Iterator<Item = (Element, i8)> {
    self.into_iter().map(|(elem, count)| (elem, *count))
  }

  fn legal_elems() -> ElementSet {
    ElementSet::all()
  }

  fn extract_with(&self, eset: ElementSet) -> Self::Output {
    Self::new_from_counter_iter(self.elem_values().filter(|x| eset.contains(x.0)))
  }

  fn pick_random(&self, n: usize) -> (Self, Self) {
    let m = n.min(self.num_dices());
    if m == 0 {
      (self.clone(), Self::empty())
    } else {
      let mut trng = thread_rng();
      let dices = self.elems();
      let choosen = dices.into_iter().choose_multiple(&mut trng, m);
      let picked = Self::new_from_element_iter(choosen.into_iter());
      let new_dices = self.sub(&picked);
      (new_dices, picked)
    }
  }

  fn contains(&self, elem: Element) -> bool {
    self[elem] > 0
  }
}

/// Concrete dice, stored as a count per [`Element`].
///
/// A legal value has no negative counts, and positive counts only for elements
/// in `LEGAL_ACTUAL_DICE_ELEMENTS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ActualDices(EnumMap<Element, i8>);

impl From<EnumMap<Element, i8>> for ActualDices {
  fn from(map: EnumMap<Element, i8>) -> Self {
    Self(map)
  }
}

/// [`Dices`] for [`ActualDices`]; legal elements are `LEGAL_ACTUAL_DICE_ELEMENTS`.
///
/// Arithmetic, iteration and filtering delegate to the wrapped `EnumMap`.
/// Counting (`num_dices` and the parity/emptiness checks built on it) is done here,
/// guarded by this type's own legality check.
impl Dices for ActualDices {
  type Output = Self;

  fn empty() -> Self {
    Self(EnumMap::empty())
  }

  fn new_from_counter_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = (Element, i8)>,
  {
    EnumMap::new_from_counter_iter(dices).into()
  }

  fn new_from_element_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = Element>,
  {
    EnumMap::new_from_element_iter(dices).into()
  }

  fn add_mut(&mut self, other: &Self) {
    self.0.add_mut(&other.0)
  }

  fn add(&self, other: &Self) -> Self::Output {
    self.0.add(&other.0).into()
  }

  fn sub_mut(&mut self, other: &Self) {
    self.0.sub_mut(&other.0)
  }

  fn sub(&self, other: &Self) -> Self::Output {
    self.0.sub(&other.0).into()
  }

  fn add_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    self.0.add_elements_mut(other);
  }

  fn add_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    self.0.add_elements(other).into()
  }

  fn sub_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    self.0.sub_elements_mut(other);
  }

  fn sub_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    self.0.sub_elements(other).into()
  }

  /// Total number of dice held.
  ///
  /// # Panics
  ///
  /// Panics if `self` is not legal: a negative count, or a positive count of an
  /// element outside `LEGAL_ACTUAL_DICE_ELEMENTS`.
  fn num_dices(&self) -> usize {
    assert!(self.is_legal());
    // Every count is >= 0 past the assert, so `unsigned_abs` is the count itself.
    // Summing in `usize` cannot overflow the way an `i8` sum can. Counting here,
    // rather than through the raw map's `num_dices`, keeps the result independent
    // of the raw map's own legality rules.
    self
      .values()
      .map(|count| usize::from(count.unsigned_abs()))
      .sum()
  }

  fn is_even(&self) -> bool {
    self.num_dices() % 2 == 0
  }

  fn is_odd(&self) -> bool {
    !self.is_even()
  }

  fn is_empty(&self) -> bool {
    self.num_dices() == 0
  }

  fn is_legal(&self) -> bool {
    check_is_legal(self, Self::legal_elems())
  }

  fn elem_set(&self) -> ElementSet {
    self.0.elem_set()
  }

  fn elems(&self) -> impl Iterator<Item = Element> {
    Dices::elems(&self.0)
  }

  fn values(&self) -> impl Iterator<Item = i8> {
    Dices::values(&self.0)
  }

  fn elem_values(&self) -> impl Iterator<Item = (Element, i8)> {
    Dices::elem_values(&self.0)
  }

  fn legal_elems() -> ElementSet {
    LEGAL_ACTUAL_DICE_ELEMENTS
  }

  fn extract_with(&self, eset: ElementSet) -> Self::Output {
    self.0.extract_with(eset).into()
  }

  fn pick_random(&self, x: usize) -> (Self, Self)
  where
    Self: Sized,
  {
    let (rest, picked) = self.0.pick_random(x);
    (rest.into(), picked.into())
  }

  fn contains(&self, elem: Element) -> bool {
    self.0.contains(elem)
  }
}

/// Abstract dice (used e.g. as the `requirement` in `ActualDices::satisfy`),
/// stored as a count per [`Element`].
///
/// A legal value has no negative counts, and positive counts only for elements
/// in `LEGAL_ABSTRACT_DICE_ELEMENTS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AbstractDices(EnumMap<Element, i8>);

impl From<EnumMap<Element, i8>> for AbstractDices {
  fn from(map: EnumMap<Element, i8>) -> Self {
    Self(map)
  }
}

impl From<ActualDices> for AbstractDices {
  fn from(dices: ActualDices) -> Self {
    AbstractDices::new_from_counter_iter(dices.elem_values())
  }
}

/// [`Dices`] for [`AbstractDices`]; legal elements are `LEGAL_ABSTRACT_DICE_ELEMENTS`.
///
/// Arithmetic, iteration and filtering delegate to the wrapped `EnumMap`.
/// Counting (`num_dices` and the parity/emptiness checks built on it) is done here,
/// guarded by this type's own legality check.
impl Dices for AbstractDices {
  type Output = Self;

  fn empty() -> Self {
    Self(EnumMap::empty())
  }

  fn new_from_counter_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = (Element, i8)>,
  {
    EnumMap::new_from_counter_iter(dices).into()
  }

  fn new_from_element_iter<T>(dices: T) -> Self
  where
    T: Iterator<Item = Element>,
  {
    EnumMap::new_from_element_iter(dices).into()
  }

  fn add_mut(&mut self, other: &Self) {
    self.0.add_mut(&other.0)
  }

  fn add(&self, other: &Self) -> Self::Output {
    self.0.add(&other.0).into()
  }

  fn sub_mut(&mut self, other: &Self) {
    self.0.sub_mut(&other.0)
  }

  fn sub(&self, other: &Self) -> Self::Output {
    self.0.sub(&other.0).into()
  }

  fn add_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    self.0.add_elements_mut(other);
  }

  fn add_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    self.0.add_elements(other).into()
  }

  fn sub_elements_mut<T>(&mut self, other: T)
  where
    T: Iterator<Item = Element>,
  {
    self.0.sub_elements_mut(other);
  }

  fn sub_elements<T>(&self, other: T) -> Self::Output
  where
    T: Iterator<Item = Element>,
  {
    self.0.sub_elements(other).into()
  }

  /// Total number of dice described.
  ///
  /// # Panics
  ///
  /// Panics if `self` is not legal: a negative count, or a positive count of an
  /// element outside `LEGAL_ABSTRACT_DICE_ELEMENTS`.
  fn num_dices(&self) -> usize {
    assert!(self.is_legal());
    // Every count is >= 0 past the assert, so `unsigned_abs` is the count itself.
    // Summing in `usize` cannot overflow the way an `i8` sum can. Counting here,
    // rather than through the raw map's `num_dices`, keeps the result independent
    // of the raw map's own legality rules.
    self
      .values()
      .map(|count| usize::from(count.unsigned_abs()))
      .sum()
  }

  fn is_even(&self) -> bool {
    self.num_dices() % 2 == 0
  }

  fn is_odd(&self) -> bool {
    !self.is_even()
  }

  fn is_empty(&self) -> bool {
    self.num_dices() == 0
  }

  fn is_legal(&self) -> bool {
    check_is_legal(self, Self::legal_elems())
  }

  fn elem_set(&self) -> ElementSet {
    self.0.elem_set()
  }

  fn elems(&self) -> impl Iterator<Item = Element> {
    Dices::elems(&self.0)
  }

  fn values(&self) -> impl Iterator<Item = i8> {
    Dices::values(&self.0)
  }

  fn elem_values(&self) -> impl Iterator<Item = (Element, i8)> {
    Dices::elem_values(&self.0)
  }

  fn legal_elems() -> ElementSet {
    LEGAL_ABSTRACT_DICE_ELEMENTS
  }

  fn extract_with(&self, eset: ElementSet) -> Self::Output {
    self.0.extract_with(eset).into()
  }

  fn pick_random(&self, x: usize) -> (Self, Self)
  where
    Self: Sized,
  {
    let (rest, picked) = self.0.pick_random(x);
    (rest.into(), picked.into())
  }

  fn contains(&self, elem: Element) -> bool {
    self.0.contains(elem)
  }
}

/// Read-only access to a single element's count, e.g. `dices[OMNI]`.
impl Index<Element> for ActualDices {
  type Output = i8;

  fn index(&self, element: Element) -> &i8 {
    let Self(counts) = self;
    &counts[element]
  }
}

/// Read-only access to a single element's count, e.g. `requirement[OMNI]`.
impl Index<Element> for AbstractDices {
  type Output = i8;

  fn index(&self, element: Element) -> &i8 {
    let Self(counts) = self;
    &counts[element]
  }
}