Repository for the KTH Automata and Languages course (DD2373, VT24)
use std::{iter::Peekable, vec::IntoIter};

use tokenizer::Token;

mod tokenizer {
    /// A single lexical unit of a regular expression.
    ///
    /// Every token is produced from exactly one input byte (see [`map`]).
    #[derive(Debug, PartialEq, Eq)]
    pub enum Token {
        /// Any ASCII byte that is not one of the operator bytes listed below.
        // NOTE(review): the previous comment here read "A-Z a-z", but `map`
        // accepts every non-operator ASCII byte - confirm which is intended.
        Char(u8),
        /// The byte `.`
        Dot,
        /// The byte `|`
        Alt,
        /// The byte `*`
        Star,
        /// The byte `+`
        Plus,
        /// The byte `?`
        Quest,
        /// The byte `(`
        Lpar,
        /// The byte `)`
        Rpar,
    }

    /// Reasons a byte can be refused by the tokenizer.
    #[derive(Debug, PartialEq)]
    pub enum TokenError {
        /// The byte is outside the ASCII range; carries the offending byte.
        NotAccepted(u8),
        /// Free-form error; not constructed anywhere in this module.
        Miscellaneous(String),
    }

    /// Converts one input byte into its [`Token`].
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::NotAccepted`] when `input` is not ASCII.
    pub fn map(input: u8) -> Result<Token, TokenError> {
        // Guard clause: everything past this point is ASCII.
        if !input.is_ascii() {
            return Err(TokenError::NotAccepted(input));
        }

        let token = match input {
            b'.' => Token::Dot,
            b'|' => Token::Alt,
            b'*' => Token::Star,
            b'+' => Token::Plus,
            b'?' => Token::Quest,
            b'(' => Token::Lpar,
            b')' => Token::Rpar,
            literal => Token::Char(literal),
        };
        Ok(token)
    }

    /// Tokenizes `input`, pairing every token with the index of its byte.
    ///
    /// # Errors
    ///
    /// Stops at the first byte rejected by [`map`] and returns that error
    /// together with the index of the byte.
    pub fn scan(input: &[u8]) -> Result<Vec<(Token, usize)>, (TokenError, usize)> {
        let mut tokens = Vec::with_capacity(input.len());

        for (pos, &byte) in input.iter().enumerate() {
            match map(byte) {
                Ok(tok) => tokens.push((tok, pos)),
                Err(err) => return Err((err, pos)),
            }
        }

        Ok(tokens)
    }

    #[cfg(test)]
    mod tests {
        use super::{map, Token};

        #[test]
        fn match_dot() {
            assert_eq!(map(b'.'), Ok(Token::Dot));
        }

        #[test]
        fn match_alternation() {
            assert_eq!(map(b'|'), Ok(Token::Alt));
        }

        #[test]
        fn match_star() {
            assert_eq!(map(b'*'), Ok(Token::Star));
        }

        #[test]
        fn match_plus() {
            assert_eq!(map(b'+'), Ok(Token::Plus));
        }

        #[test]
        fn match_question() {
            assert_eq!(map(b'?'), Ok(Token::Quest));
        }

        #[test]
        fn match_left_parenthesis() {
            assert_eq!(map(b'('), Ok(Token::Lpar));
        }

        #[test]
        fn match_right_parenthesis() {
            assert_eq!(map(b')'), Ok(Token::Rpar));
        }

        #[test]
        fn match_ascii() {
            assert_eq!(map(b'a'), Ok(Token::Char(b'a')));
        }
    }
}

/// Syntax tree produced by [`parse`] for a regular expression.
#[derive(Debug, PartialEq)]
pub enum Expr {
    /// Concatenation: the left operand followed by the right operand.
    Conc(Box<Expr>, Box<Expr>),
    /// Alternation: the operands on either side of a `|`.
    Alt(Box<Expr>, Box<Expr>),
    /// Operand of a postfix `*` (the tests expect `a*` to parse to `Star(Char(b'a'))`).
    Star(Box<Expr>),
    /// Operand of a postfix `?` (the tests expect `a?` to parse to `Quest(Char(b'a'))`).
    Quest(Box<Expr>),
    /// A literal byte taken from a `Token::Char`.
    Char(u8),
    /// The `.` token; presumably the any-character wildcard - confirm against the matcher.
    Dot
}

/// Errors that [`parse`] can return.
#[derive(Debug, PartialEq)]
pub enum ParseError {
    /// Tokenization failed; carries the tokenizer's error and the index of
    /// the offending input byte, exactly as returned by `tokenizer::scan`.
    TokenError((tokenizer::TokenError, usize)),
    /// The parser could not continue: the input ended, or the next token was
    /// not acceptable at that position.
    UnexpectedEnd
}

/// Result type shared by the parsing functions in this module: an [`Expr`] on
/// success, a [`ParseError`] on failure.
///
/// Note that this shadows the prelude's two-parameter `Result` within this
/// module; spell out `std::result::Result` when a different payload is needed.
type Result = std::result::Result<Expr, ParseError>;

/// Parses a regular expression given as raw bytes into an [`Expr`] tree.
///
/// The input is first tokenized with `tokenizer::scan`; the resulting tokens
/// are then handed to the recursive-descent parser starting at `parse_expr`.
/// The whole input must be consumed: tokens left over once `parse_expr`
/// returns (for example a stray `)`) are rejected instead of being silently
/// dropped.
///
/// # Errors
///
/// * [`ParseError::TokenError`] when the tokenizer rejects a byte; the payload
///   is the tokenizer error paired with the index of that byte.
/// * [`ParseError::UnexpectedEnd`] when the token stream is not a valid
///   expression, including the case where unconsumed tokens remain. This is
///   the only syntax-error variant `ParseError` offers.
pub fn parse(input: &[u8]) -> Result {
    let mut tokens = tokenizer::scan(input)
        .map_err(ParseError::TokenError)?
        .into_iter()
        .peekable();

    let expr = parse_expr(&mut tokens)?;

    // A successful parse that leaves tokens behind means the input as a whole
    // is not a valid expression.
    if tokens.peek().is_some() {
        return Err(ParseError::UnexpectedEnd);
    }

    Ok(expr)
}

/// Parses an expression: one term, then whatever `parse_expr_inner` attaches
/// to it.
///
/// Any error from `parse_term` or `parse_expr_inner` is returned unchanged.
fn parse_expr(tokens: &mut Peekable<IntoIter<(Token, usize)>>) -> Result {
    parse_term(tokens).and_then(|first_term| parse_expr_inner(tokens, first_term))
}

/// Extends `left` with every `| term` that follows in the token stream.
///
/// Alternatives are folded left to right, so `a|b|c` becomes
/// `Alt(Alt(a, b), c)`. When the next token is not `Token::Alt`, `left` is
/// returned as is.
///
/// An error from `parse_term` for any right-hand side is returned unchanged.
fn parse_expr_inner(tokens: &mut Peekable<IntoIter<(Token, usize)>>, left: Expr) -> Result {
    let mut alternatives = left;

    while matches!(tokens.peek(), Some((Token::Alt, _))) {
        // Discard the `|` itself; only the term after it ends up in the tree.
        tokens.next();

        let rhs = parse_term(tokens)?;
        alternatives = Expr::Alt(Box::new(alternatives), Box::new(rhs));
    }

    Ok(alternatives)
}

/// Parses a term: one factor, then whatever `parse_term_inner` attaches to it.
///
/// Any error from `parse_factor` or `parse_term_inner` is returned unchanged.
fn parse_term(tokens: &mut Peekable<IntoIter<(Token, usize)>>) -> Result {
    parse_factor(tokens).and_then(|first_factor| parse_term_inner(tokens, first_factor))
}

/// Extends `left` with every factor that directly follows it (concatenation).
///
/// A further factor is attempted whenever the next token is one that can
/// begin a factor: a literal character, `.`, or an opening parenthesis. The
/// token is only peeked here; consuming it (and deciding whether it really
/// forms a valid factor) is left to `parse_factor`. Factors are folded left
/// to right, so `abc` becomes `Conc(Conc(a, b), c)`.
///
/// An error from `parse_factor` is returned unchanged.
fn parse_term_inner(tokens: &mut Peekable<IntoIter<(Token, usize)>>, left: Expr) -> Result {
    let mut sequence = left;

    // `Token::Lpar` is part of the lookahead so that a group can continue a
    // concatenation (`a(b)`), not only appear at the start of a term.
    while matches!(
        tokens.peek(),
        Some((Token::Char(_), _)) | Some((Token::Dot, _)) | Some((Token::Lpar, _))
    ) {
        let next_factor = parse_factor(tokens)?;
        sequence = Expr::Conc(Box::new(sequence), Box::new(next_factor));
    }

    Ok(sequence)
}

fn parse_factor(tokens: &mut Peekable<IntoIter<(Token, usize)>>) -> Result {
    match tokens.next() {
        Some((Token::Star, _)) => {
            Ok(Expr::Star(()))
        }

        _ => Err(ParseError::UnexpectedEnd)
    }
}

/// Unit tests for the parser: expected trees for valid input and rejection of
/// operators that have no operand.
#[cfg(test)]
mod tests {
    // `super` keeps the tests independent of where this module is mounted in
    // the crate, matching the tokenizer's test module.
    use super::{parse, Expr};

    #[test]
    fn parse_ascii() {
        // A single literal parses to the same byte it was given.
        assert_eq!(parse(b"a"), Ok(Expr::Char(b'a')))
    }

    #[test]
    fn parse_dot() {
        assert_eq!(parse(b"."), Ok(Expr::Dot))
    }

    #[test]
    fn parse_concatenation() {
        assert_eq!(parse(b"ab"), Ok(Expr::Conc(Box::new(Expr::Char(b'a')), Box::new(Expr::Char(b'b')))));
    }

    #[test]
    fn parse_alternation() {
        assert_eq!(parse(b"a|b"), Ok(Expr::Alt(Box::new(Expr::Char(b'a')), Box::new(Expr::Char(b'b')))));
    }

    #[test]
    fn parse_star() {
        assert_eq!(parse(b"a*"), Ok(Expr::Star(Box::new(Expr::Char(b'a')))))
    }

    #[test]
    fn parse_plus() {
        // `a+` is represented as `a` followed by `a*`.
        assert_eq!(
            parse(b"a+"),
            Ok(Expr::Conc(
                Box::new(Expr::Char(b'a')),
                Box::new(Expr::Star(
                    Box::new(Expr::Char(b'a')))))))
    }

    #[test]
    fn parse_question() {
        assert_eq!(parse(b"a?"), Ok(Expr::Quest(Box::new(Expr::Char(b'a')))))
    }

    #[test]
    fn parse_star_before_concatenation() {
        // The postfix operator applies to `b` only, not to `ab`.
        assert_eq!(
            parse(b"ab*"),
            Ok(Expr::Conc(
                Box::new(Expr::Char(b'a')),
                Box::new(Expr::Star(
                    Box::new(Expr::Char(b'b')))
                ))
            )
        )
    }

    #[test]
    fn parse_quest_before_concatenation() {
        // The postfix operator applies to `b` only, not to `ab`.
        assert_eq!(
            parse(b"ab?"),
            Ok(Expr::Conc(
                Box::new(Expr::Char(b'a')),
                Box::new(Expr::Quest(
                    Box::new(Expr::Char(b'b')))
                ))
            )
        )
    }

    #[test]
    fn parse_plus_before_concatenation() {
        // `b+` uses the same representation as in `parse_plus`: `b` then `b*`.
        assert_eq!(
            parse(b"ab+"),
            Ok(Expr::Conc(
                Box::new(Expr::Char(b'a')),
                Box::new(Expr::Conc(
                    Box::new(Expr::Char(b'b')),
                    Box::new(Expr::Star(
                        Box::new(Expr::Char(b'b'))))
                ))
            ))
        )
    }

    #[test]
    fn no_star_after_alternation() {
        assert!(matches!(
            parse(b"|*"),
            Err(_)
        ))
    }

    #[test]
    fn no_plus_after_alternation() {
        assert!(matches!(
            parse(b"|+"),
            Err(_)
        ))
    }

    #[test]
    fn no_quest_after_alternation() {
        assert!(matches!(
            parse(b"|?"),
            Err(_)
        ))
    }

    #[test]
    fn no_lone_star() {
        assert!(matches!(
            parse(b"*"),
            Err(_)
        ))
    }

    #[test]
    fn no_lone_plus() {
        assert!(matches!(
            parse(b"+"),
            Err(_)
        ))
    }

    #[test]
    fn no_lone_question() {
        assert!(matches!(
            parse(b"?"),
            Err(_)
        ))
    }
}