use crate::{
    lexer::{Literal, Token, TokenType},
    location::LocationContainer,
    Lox,
};

use display_tree::DisplayTree;
use thiserror::Error;

use std::{fmt::Display, rc::Rc};

/* GRAMMAR:
expression     → equality ;
equality       → comparison ( ( "!=" | "==" ) comparison )* ;
comparison     → term ( ( ">" | ">=" | "<" | "<=" ) term )* ;
term           → factor ( ( "-" | "+" | "~" ) factor )* ;
factor         → unary ( ( "/" | "*" ) unary )* ;
unary          → ( "!" | "-" ) unary
               | primary ;
primary        → NUMBER | STRING | "true" | "false" | "nil"
               | "(" expression ")" ;
*/

/* Recursive descent parsing:
A recursive descent parser is a literal translation of
the grammar’s rules straight into imperative code.
Each rule becomes a function.

The body of the rule translates to code roughly like:
| Grammar notation | Code representation               |
| ---------------- | --------------------------------- |
| Terminal         | Code to match and consume a token |
| Non-terminal     | Call to that rule’s function      |
| '|'              | if or switch statement            |
| '*' or '+'       | while or for loop                 |
| '?'              | if statement                      |
*/

/// A node of the expression syntax tree built by [`Parser`].
///
/// Child nodes are held behind `Rc`, so a subtree can be referenced from
/// several places without deep-copying it.
#[derive(Debug, DisplayTree)]
pub enum Expression {
    /// An infix operation `left operator right`, e.g. `a + b` or `a == b`.
    Binary {
        #[tree]
        left: Rc<Expression>,
        operator: Token,
        #[tree]
        right: Rc<Expression>,
    },
    /// A parenthesized expression; kept as its own node so the printer can
    /// render it as `(group …)`.
    Grouping {
        #[tree]
        expression: Rc<Expression>,
    },
    /// A literal value paired with the source location it was read from.
    Literal {
        value: LocationContainer<Literal>,
    },
    /// A prefix operator applied to one operand, e.g. `-x` or `!x`.
    Unary {
        operator: Token,
        #[tree]
        right: Rc<Expression>,
    },
}

impl Display for Expression {
    /// Writes the expression in fully parenthesized prefix form, as produced
    /// by `parenthize`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = parenthize(self);
        f.write_str(&rendered)
    }
}

/// Renders `expression` as a Lisp-style, fully parenthesized string.
///
/// - Binary nodes become `(op left right)`.
/// - Unary nodes become `(op operand)`.
/// - Groupings become `(group inner)`.
/// - Literals are rendered through their `Display` implementation.
fn parenthize(expression: &Expression) -> String {
    match expression {
        Expression::Literal { value } => value.to_string(),
        Expression::Grouping { expression: inner } => {
            let inner = parenthize(inner);
            format!("(group {inner})")
        }
        Expression::Unary { operator, right } => {
            let operand = parenthize(right);
            format!("({} {operand})", operator.lexeme)
        }
        Expression::Binary {
            left,
            operator,
            right,
        } => {
            let lhs = parenthize(left);
            let rhs = parenthize(right);
            format!("({} {lhs} {rhs})", operator.lexeme)
        }
    }
}

/// Errors returned by [`Parser`] while parsing an expression.
///
/// The variants carry no source location; their `#[error]` strings are the
/// messages shown for them.
#[derive(Error, Debug)]
pub enum ParseError {
    /// A `(` grouping was not closed by a matching `)`.
    #[error("Expect ')' after expression.")]
    RightParenExpected,
    /// The next token cannot begin an expression.
    #[error("Expect expression.")]
    ExpressionExpected,
    /// A `Literal` token arrived without an attached literal value. The
    /// message labels this an internal error rather than bad user input.
    #[error("internal compiler error: Literal has no value.")]
    EmptyLiteral,
}

/// Recursive descent parser that turns a token stream into an [`Expression`].
pub struct Parser<'a> {
    /// Error sink: some parse errors are reported through methods on this.
    pub lox: &'a mut Lox,
    /// Tokens to parse. Input is considered finished at the end of the vector
    /// or at the first `Eof` token, whichever comes first.
    pub tokens: Vec<Token>,
    /// Index into `tokens` of the next unconsumed token.
    pub current: usize,
}

impl<'a> Parser<'a> {
    pub fn new(lox: &'a mut Lox, tokens: Vec<Token>) -> Self {
        Self {
            lox,
            tokens,
            current: 0,
        }
    }

    pub fn peek(&self) -> Token {
        self.tokens[self.current].clone()
    }

    pub fn previous(&self) -> Token {
        self.tokens[self.current - 1].clone()
    }

    pub fn is_at_end(&self) -> bool {
        self.current >= self.tokens.len() || self.peek().token_type == TokenType::Eof
    }

    pub fn advance(&mut self) -> Token {
        if !self.is_at_end() {
            self.current += 1;
        }
        self.previous()
    }

    pub fn check(&self, token_type: &TokenType) -> bool {
        if self.is_at_end() {
            false
        } else {
            &self.peek().token_type == token_type
        }
    }

    pub fn match_tokens(&mut self, types: &[TokenType]) -> bool {
        for token_type in types {
            if self.check(token_type) {
                self.advance();
                return true;
            }
        }
        false
    }

    fn expression(&mut self) -> Result<Expression, ParseError> {
        self.equality()
    }

    fn equality(&mut self) -> Result<Expression, ParseError> {
        let mut expression = self.comparison()?;

        while self.match_tokens(&[TokenType::BangEqual, TokenType::EqualEqual]) {
            let operator = self.previous();
            let right = self.comparison()?;
            expression = Expression::Binary {
                left: Rc::new(expression),
                operator,
                right: Rc::new(right),
            };
        }

        Ok(expression)
    }

    fn comparison(&mut self) -> Result<Expression, ParseError> {
        let mut expression = self.term()?;

        while self.match_tokens(&[
            TokenType::Greater,
            TokenType::GreaterEqual,
            TokenType::Less,
            TokenType::LessEqual,
        ]) {
            let operator = self.previous();
            let right = self.term()?;
            expression = Expression::Binary {
                left: Rc::new(expression),
                operator,
                right: Rc::new(right),
            };
        }

        Ok(expression)
    }

    fn term(&mut self) -> Result<Expression, ParseError> {
        let mut expression = self.factor()?;

        while self.match_tokens(&[TokenType::Minus, TokenType::Plus, TokenType::Tilde]) {
            let operator = self.previous();
            let right = self.factor()?;
            expression = Expression::Binary {
                left: Rc::new(expression),
                operator,
                right: Rc::new(right),
            };
        }

        Ok(expression)
    }

    fn factor(&mut self) -> Result<Expression, ParseError> {
        let mut expression = self.unary()?;

        while self.match_tokens(&[TokenType::Slash, TokenType::Star]) {
            let operator = self.previous();
            let right = self.unary()?;
            expression = Expression::Binary {
                left: Rc::new(expression),
                operator,
                right: Rc::new(right),
            };
        }

        Ok(expression)
    }

    fn unary(&mut self) -> Result<Expression, ParseError> {
        if self.match_tokens(&[TokenType::Bang, TokenType::Minus]) {
            let operator = self.previous();
            let right = self.unary()?;
            Ok(Expression::Unary {
                operator,
                right: Rc::new(right),
            })
        } else {
            self.primary()
        }
    }

    fn consume(&mut self, token_type: TokenType, message: &str) {
        if !self.check(&token_type) {
            self.lox.error(self.peek().location.line, message);
            panic!();
        }
        self.advance();
    }

    fn synchronize(&mut self) {
        self.advance();

        while !self.is_at_end() {
            if self.previous().token_type == TokenType::Semicolon {
                return;
            }

            match self.peek().token_type {
                TokenType::Class
                | TokenType::Fun
                | TokenType::Var
                | TokenType::For
                | TokenType::If
                | TokenType::While
                | TokenType::Print
                | TokenType::Return => return,
                _ => _ = self.advance(),
            }
        }
    }

    fn primary(&mut self) -> Result<Expression, ParseError> {
        if self.match_tokens(&[TokenType::Literal]) {
            let previous = self.previous();
            let inner = previous.literal.ok_or(ParseError::EmptyLiteral)?;
            Ok(Expression::Literal {
                value: LocationContainer {
                    location: previous.location.clone(),
                    inner,
                },
            })
        } else if self.match_tokens(&[TokenType::LeftParen]) {
            let expression = self.expression()?;
            if self.match_tokens(&[TokenType::RightParen]) {
                Ok(Expression::Grouping {
                    expression: Rc::new(expression),
                })
            } else {
                Err(ParseError::RightParenExpected)
            }
        } else {
            Err(ParseError::ExpressionExpected)
        }
    }

    pub fn parse(&mut self) -> Result<Expression, ParseError> {
        let expression = self.expression()?;
        if !self.is_at_end() {
            self.lox
                .error_token(self.tokens[self.current].clone(), "Expected EOF.");
        }
        Ok(expression)
    }
}