From ebfd74ddf0ef6372624ea171e06f8460d0e1351b Mon Sep 17 00:00:00 2001 From: Cody Date: Mon, 27 Feb 2023 07:21:50 -0600 Subject: pain --- src/ast/parser.rs | 364 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 src/ast/parser.rs (limited to 'src/ast/parser.rs') diff --git a/src/ast/parser.rs b/src/ast/parser.rs new file mode 100644 index 0000000..64ed352 --- /dev/null +++ b/src/ast/parser.rs @@ -0,0 +1,364 @@ +use super::{Expr, Stmt}; +use crate::lexer::{Literal, Token, TokenType}; + +pub struct AstParser<'a> { + tokens: Vec>, + index: usize, +} + +/// Implementation containing utilities used by the parsers internal components +impl<'a> AstParser<'a> { + pub fn new(tokens: Vec>) -> Self { + Self { tokens, index: 0 } + } + + fn previous(&self) -> Option<&Token> { + self.tokens.get(self.index - 1) + } + + fn peek(&self) -> &Token { + &self.tokens[self.index] + } + + fn peek_nth(&self, nth: usize) -> Option<&Token> { + self.tokens.get(self.index + nth) + } + + fn advance(&mut self) -> Option<&Token> { + if self.eof() { + return None; + } + + self.index += 1; + Some(&self.tokens[self.index - 1]) + } + + fn advance_if(&mut self, next: impl FnOnce(&Token) -> bool) -> bool { + if self.eof() { + return false; + } + + if next(self.peek()) { + self.advance(); + return true; + } + + false + } + + fn advance_if_eq(&mut self, next: &TokenType) -> bool { + self.advance_if(|it| it.tt == *next) + } + + fn advance_seq(&mut self, seq: &[TokenType]) -> bool { + for token in seq { + if !self.advance_if_eq(token) { + return false; + } + } + + true + } + + fn consume(&mut self, next: TokenType, error: &str) { + if std::mem::discriminant(&self.peek().tt) != std::mem::discriminant(&next) { + panic!("{error}"); + } + self.advance(); + } + + fn eof(&self) -> bool { + self.index >= self.tokens.len() + } +} + +/// Implementation containing parsers internal components related to statements +impl<'a> AstParser<'a> { + pub fn parse(&mut self) -> Vec { + let mut statements = Vec::new(); + + while !self.eof() { + statements.push(self.statement()); + } + + statements + } + + fn block(&mut self) -> Vec { + self.consume(TokenType::LeftBrace, "Expected '{' at beggining of block"); + + let mut statements = Vec::new(); + + while !self.eof() && self.peek().tt != TokenType::RightBrace { + statements.push(self.statement()); + } + + self.consume(TokenType::RightBrace, "Expected '}' at end of block"); + statements + } + + fn statement(&mut self) -> Stmt { + if self.peek().tt == TokenType::LeftBrace { + return Stmt::Block(self.block()); + } + + if self.advance_if_eq(&TokenType::Print) { + return self.print_statement(); + } + + if self.advance_if_eq(&TokenType::Var) { + return self.var_statement(); + } + + if self.advance_if_eq(&TokenType::Val) { + return self.val_statement(); + } + + if self.advance_if_eq(&TokenType::If) { + return self.if_statement(); + } + + if self.advance_if_eq(&TokenType::For) { + return self.for_statement(); + } + + // If we couldn't parse a statement return an expression statement + self.expression_statement() + } + + fn print_statement(&mut self) -> Stmt { + let value = self.expression(); + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + Stmt::Print { value } + } + + fn var_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'var'"); + }; + + self.advance(); // Advancing from the identifier + self.consume(TokenType::Eq, "Expected '=' after identifier"); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::Var { ident, value } + } + + fn val_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'val'"); + }; + + self.advance(); // Advancing from the identifier + self.consume(TokenType::Eq, "Expected '=' after identifier"); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::Val { ident, value } + } + + fn if_statement(&mut self) -> Stmt { + let condition = self.expression(); + let body = self.block(); + + Stmt::If { condition, body } + } + + fn for_statement(&mut self) -> Stmt { + let binding = self.expression(); + let Expr::Variable(binding) = binding else { + panic!("Left side of for statement must be identifier"); + }; + + self.consume( + TokenType::In, + "Expected 'in' in between identifier and range", + ); + + let range_start = self.expression(); + self.consume( + TokenType::DotDot, + "Expected '..' denoting min and max of range", + ); + let range_end = self.expression(); + + let body = self.block(); + + Stmt::For { + binding, + range: (range_start, range_end), + body, + } + } + + fn expression_statement(&mut self) -> Stmt { + let expr = self.expression(); + + // FIXME: Move assignment handling + if self.advance_if_eq(&TokenType::Eq) { + if let Expr::Variable(ident) = &expr { + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + return Stmt::Assignment { + ident: ident.clone(), + value, + }; + } + } + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + Stmt::Expr(expr) + } +} + +/// Implementation containing parsers internal components related to expressions +impl<'a> AstParser<'a> { + // FIXME: Should probably avoid cloning token types + + fn expression(&mut self) -> Expr { + self.logical_or() + } + + fn unary(&mut self) -> Expr { + if !self.eof() + && matches!( + self.peek().tt, + TokenType::Bang | TokenType::Plus | TokenType::Minus + ) + { + let operator = self.advance().unwrap().tt.clone(); + let rhs = self.unary(); + return Expr::Unary { + operator, + expr: Box::new(rhs), + }; + } + + self.primary() + } + + fn primary(&mut self) -> Expr { + match self.advance().unwrap().tt.clone() { + TokenType::Literal(literal) => Expr::Literal(literal), + TokenType::Identifier(ident) => Expr::Variable(ident), + TokenType::LeftParen => { + let expr = self.expression(); + self.consume(TokenType::RightParen, "Must end expression with ')'"); + Expr::Grouping(Box::new(expr)) + } + _ => unimplemented!("{:?}", self.peek()), + } + } +} + +// Macro to generate repetitive binary expressions. Things like addition, +// multiplication, exc. +macro_rules! binary_expr { + ($name:ident, $parent:ident, $pattern:pat) => { + fn $name(&mut self) -> Expr { + let mut expr = self.$parent(); + + while !self.eof() && matches!(self.peek().tt, $pattern) { + let operator = self.advance().unwrap().tt.clone(); + let rhs = self.$parent(); + expr = Expr::Binary { + operator, + lhs: Box::new(expr), + rhs: Box::new(rhs), + }; + } + + expr + } + }; +} + +#[rustfmt::skip] +impl<'a> AstParser<'a> { + // Binary expressions in order of precedence from lowest to highest. + binary_expr!(logical_or , logical_and , (TokenType::PipePipe)); + binary_expr!(logical_and , equality , (TokenType::AmpAmp)); + binary_expr!(equality , comparison , (TokenType::BangEq | TokenType::EqEq)); + binary_expr!(comparison , bitwise_shifting, (TokenType::Lt | TokenType::Gt | TokenType::LtEq | TokenType::GtEq)); + binary_expr!(bitwise_shifting, additive , (TokenType::LtLt | TokenType::GtGt)); + binary_expr!(additive , multiplicative , (TokenType::Plus | TokenType::Minus)); + binary_expr!(multiplicative , unary , (TokenType::Star | TokenType::Slash | TokenType::Perc)); +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::AstParser; + use crate::ast::Expr; + use crate::lexer::{Lexer, Literal, TokenType}; + + #[test] + fn basic_expression_a() { + let lexer = Lexer::new("3 + 5 * 4"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::Binary { + operator: TokenType::Plus, + lhs: Box::new(Expr::Literal(Literal::Number(3))), + rhs: Box::new(Expr::Binary { + operator: TokenType::Star, + lhs: Box::new(Expr::Literal(Literal::Number(5))), + rhs: Box::new(Expr::Literal(Literal::Number(4))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_expression_b() { + let lexer = Lexer::new("17 - (-5 + 5) / 6"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::Binary { + operator: TokenType::Minus, + lhs: Box::new(Expr::Literal(Literal::Number(17))), + rhs: Box::new(Expr::Binary { + operator: TokenType::Slash, + lhs: Box::new(Expr::Grouping(Box::new(Expr::Binary { + operator: TokenType::Plus, + lhs: Box::new(Expr::Unary { + operator: TokenType::Minus, + expr: Box::new(Expr::Literal(Literal::Number(5))), + }), + rhs: Box::new(Expr::Literal(Literal::Number(5))), + }))), + rhs: Box::new(Expr::Literal(Literal::Number(6))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_expression_c() { + let lexer = Lexer::new("9 > 6 && 5 + 7 == 32 || \"apple\" != \"banana\""); + let tokens = lexer.collect_vec(); + + // TODO: + } +} -- cgit v1.2.3