diff options
Diffstat (limited to 'src/ast')
| -rw-r--r-- | src/ast/display.rs | 45 | ||||
| -rw-r--r-- | src/ast/evaluate.rs | 5 | ||||
| -rw-r--r-- | src/ast/mod.rs | 95 | ||||
| -rw-r--r-- | src/ast/parser.rs | 364 | ||||
| -rw-r--r-- | src/ast/printer.rs | 38 |
5 files changed, 458 insertions, 89 deletions
diff --git a/src/ast/display.rs b/src/ast/display.rs deleted file mode 100644 index 25b4d23..0000000 --- a/src/ast/display.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::fmt::Display; - -use super::{Expression, Statement, Value}; - -impl<'a> Display for Statement<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{{")?; - let value = match self { - Statement::Val { - identifier, - initializer, - } => format!("val {} {}", identifier.lexeme, initializer), - Statement::Var { - identifier, - initializer, - } => format!("var {} {}", identifier.lexeme, initializer), - Statement::Expression { expr } => expr.to_string(), - }; - write!(f, "{value}")?; - write!(f, "}}")?; - - Ok(()) - } -} - -impl<'a> Display for Expression<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "(")?; - let value = match self { - Expression::Literal(value) => value.0.to_string(), - Expression::Unary { expr, .. } => format!("+ {}", expr), - Expression::Binary { lhs, rhs, .. } => format!("+ {lhs} {rhs}"), - }; - write!(f, "{value}")?; - write!(f, ")")?; - - Ok(()) - } -} - -impl Display for Value { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} diff --git a/src/ast/evaluate.rs b/src/ast/evaluate.rs deleted file mode 100644 index 4d47994..0000000 --- a/src/ast/evaluate.rs +++ /dev/null @@ -1,5 +0,0 @@ -impl<'a> Expression<'a> { - fn evaluate() -> Value { - Value(5) - } -} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 73e73b3..04b1d26 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,55 +1,72 @@ -#![allow(dead_code)] +use std::fmt::Display; -pub mod display; +use crate::lexer::{Literal, TokenType}; -use crate::lexer::Token; +pub mod parser; +pub mod printer; -#[derive(Clone)] -pub enum Statement<'a> { +#[derive(Debug, Eq, PartialEq)] +pub enum Stmt { + Block(Vec<Stmt>), + Expr(Expr), Val { - identifier: &'a Token<'a>, - initializer: &'a Expression<'a>, + ident: String, + value: Expr, }, Var { - identifier: &'a Token<'a>, - initializer: &'a Expression<'a>, + ident: String, + value: Expr, }, - Expression { - expr: &'a Expression<'a>, + Assignment { + ident: String, + value: Expr, }, -} - -#[derive(Clone)] -pub enum Expression<'a> { - // Basic - Literal(Value), - Unary { - operation: Operation, - expr: &'a Expression<'a>, + Function { + ident: String, + arguments: Vec<FunctionArgument>, + return_type: String, + body: Vec<Stmt>, }, - Binary { - operation: Operation, - lhs: &'a Expression<'a>, - rhs: &'a Expression<'a>, + If { + condition: Expr, + body: Vec<Stmt>, + }, + For { + binding: String, + range: (Expr, Expr), + body: Vec<Stmt>, + }, + Return { + value: Expr, + }, + Print { + value: Expr, }, - // Grouping } -#[derive(Clone)] -pub enum Operation { - Add, - Subtract, +#[derive(Debug, Eq, PartialEq)] +pub struct FunctionArgument { + name: String, + types: String, } -#[derive(Clone)] -pub struct Value(pub i32); +#[derive(Debug, Eq, PartialEq)] +pub enum Expr { + Literal(Literal), + Variable(String), + Grouping(Box<Expr>), + Binary { + operator: TokenType, + lhs: Box<Expr>, + rhs: Box<Expr>, + }, + Unary { + operator: TokenType, + expr: Box<Expr>, + }, +} -#[test] -fn test() { - let right = Expression::Literal(Value(7)); - let _ = Expression::Binary { - operation: Operation::Add, - lhs: &Expression::Literal(Value(5)), - rhs: &right, - }; +pub trait AstVisitor<T = ()> { + fn visit_stmt(&mut self, stmt: &Stmt) -> T; + fn visit_expr(&mut self, expr: &Expr) -> T; } diff --git a/src/ast/parser.rs b/src/ast/parser.rs new file mode 100644 index 0000000..64ed352 --- /dev/null +++ b/src/ast/parser.rs @@ -0,0 +1,364 @@ +use super::{Expr, Stmt}; +use crate::lexer::{Literal, Token, TokenType}; + +pub struct AstParser<'a> { + tokens: Vec<Token<'a>>, + index: usize, +} + +/// Implementation containing utilities used by the parsers internal components +impl<'a> AstParser<'a> { + pub fn new(tokens: Vec<Token<'a>>) -> Self { + Self { tokens, index: 0 } + } + + fn previous(&self) -> Option<&Token> { + self.tokens.get(self.index - 1) + } + + fn peek(&self) -> &Token { + &self.tokens[self.index] + } + + fn peek_nth(&self, nth: usize) -> Option<&Token> { + self.tokens.get(self.index + nth) + } + + fn advance(&mut self) -> Option<&Token> { + if self.eof() { + return None; + } + + self.index += 1; + Some(&self.tokens[self.index - 1]) + } + + fn advance_if(&mut self, next: impl FnOnce(&Token) -> bool) -> bool { + if self.eof() { + return false; + } + + if next(self.peek()) { + self.advance(); + return true; + } + + false + } + + fn advance_if_eq(&mut self, next: &TokenType) -> bool { + self.advance_if(|it| it.tt == *next) + } + + fn advance_seq(&mut self, seq: &[TokenType]) -> bool { + for token in seq { + if !self.advance_if_eq(token) { + return false; + } + } + + true + } + + fn consume(&mut self, next: TokenType, error: &str) { + if std::mem::discriminant(&self.peek().tt) != std::mem::discriminant(&next) { + panic!("{error}"); + } + self.advance(); + } + + fn eof(&self) -> bool { + self.index >= self.tokens.len() + } +} + +/// Implementation containing parsers internal components related to statements +impl<'a> AstParser<'a> { + pub fn parse(&mut self) -> Vec<Stmt> { + let mut statements = Vec::new(); + + while !self.eof() { + statements.push(self.statement()); + } + + statements + } + + fn block(&mut self) -> Vec<Stmt> { + self.consume(TokenType::LeftBrace, "Expected '{' at beggining of block"); + + let mut statements = Vec::new(); + + while !self.eof() && self.peek().tt != TokenType::RightBrace { + statements.push(self.statement()); + } + + self.consume(TokenType::RightBrace, "Expected '}' at end of block"); + statements + } + + fn statement(&mut self) -> Stmt { + if self.peek().tt == TokenType::LeftBrace { + return Stmt::Block(self.block()); + } + + if self.advance_if_eq(&TokenType::Print) { + return self.print_statement(); + } + + if self.advance_if_eq(&TokenType::Var) { + return self.var_statement(); + } + + if self.advance_if_eq(&TokenType::Val) { + return self.val_statement(); + } + + if self.advance_if_eq(&TokenType::If) { + return self.if_statement(); + } + + if self.advance_if_eq(&TokenType::For) { + return self.for_statement(); + } + + // If we couldn't parse a statement return an expression statement + self.expression_statement() + } + + fn print_statement(&mut self) -> Stmt { + let value = self.expression(); + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + Stmt::Print { value } + } + + fn var_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'var'"); + }; + + self.advance(); // Advancing from the identifier + self.consume(TokenType::Eq, "Expected '=' after identifier"); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::Var { ident, value } + } + + fn val_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'val'"); + }; + + self.advance(); // Advancing from the identifier + self.consume(TokenType::Eq, "Expected '=' after identifier"); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::Val { ident, value } + } + + fn if_statement(&mut self) -> Stmt { + let condition = self.expression(); + let body = self.block(); + + Stmt::If { condition, body } + } + + fn for_statement(&mut self) -> Stmt { + let binding = self.expression(); + let Expr::Variable(binding) = binding else { + panic!("Left side of for statement must be identifier"); + }; + + self.consume( + TokenType::In, + "Expected 'in' in between identifier and range", + ); + + let range_start = self.expression(); + self.consume( + TokenType::DotDot, + "Expected '..' denoting min and max of range", + ); + let range_end = self.expression(); + + let body = self.block(); + + Stmt::For { + binding, + range: (range_start, range_end), + body, + } + } + + fn expression_statement(&mut self) -> Stmt { + let expr = self.expression(); + + // FIXME: Move assignment handling + if self.advance_if_eq(&TokenType::Eq) { + if let Expr::Variable(ident) = &expr { + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + return Stmt::Assignment { + ident: ident.clone(), + value, + }; + } + } + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + Stmt::Expr(expr) + } +} + +/// Implementation containing parsers internal components related to expressions +impl<'a> AstParser<'a> { + // FIXME: Should probably avoid cloning token types + + fn expression(&mut self) -> Expr { + self.logical_or() + } + + fn unary(&mut self) -> Expr { + if !self.eof() + && matches!( + self.peek().tt, + TokenType::Bang | TokenType::Plus | TokenType::Minus + ) + { + let operator = self.advance().unwrap().tt.clone(); + let rhs = self.unary(); + return Expr::Unary { + operator, + expr: Box::new(rhs), + }; + } + + self.primary() + } + + fn primary(&mut self) -> Expr { + match self.advance().unwrap().tt.clone() { + TokenType::Literal(literal) => Expr::Literal(literal), + TokenType::Identifier(ident) => Expr::Variable(ident), + TokenType::LeftParen => { + let expr = self.expression(); + self.consume(TokenType::RightParen, "Must end expression with ')'"); + Expr::Grouping(Box::new(expr)) + } + _ => unimplemented!("{:?}", self.peek()), + } + } +} + +// Macro to generate repetitive binary expressions. Things like addition, +// multiplication, exc. +macro_rules! binary_expr { + ($name:ident, $parent:ident, $pattern:pat) => { + fn $name(&mut self) -> Expr { + let mut expr = self.$parent(); + + while !self.eof() && matches!(self.peek().tt, $pattern) { + let operator = self.advance().unwrap().tt.clone(); + let rhs = self.$parent(); + expr = Expr::Binary { + operator, + lhs: Box::new(expr), + rhs: Box::new(rhs), + }; + } + + expr + } + }; +} + +#[rustfmt::skip] +impl<'a> AstParser<'a> { + // Binary expressions in order of precedence from lowest to highest. + binary_expr!(logical_or , logical_and , (TokenType::PipePipe)); + binary_expr!(logical_and , equality , (TokenType::AmpAmp)); + binary_expr!(equality , comparison , (TokenType::BangEq | TokenType::EqEq)); + binary_expr!(comparison , bitwise_shifting, (TokenType::Lt | TokenType::Gt | TokenType::LtEq | TokenType::GtEq)); + binary_expr!(bitwise_shifting, additive , (TokenType::LtLt | TokenType::GtGt)); + binary_expr!(additive , multiplicative , (TokenType::Plus | TokenType::Minus)); + binary_expr!(multiplicative , unary , (TokenType::Star | TokenType::Slash | TokenType::Perc)); +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::AstParser; + use crate::ast::Expr; + use crate::lexer::{Lexer, Literal, TokenType}; + + #[test] + fn basic_expression_a() { + let lexer = Lexer::new("3 + 5 * 4"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::Binary { + operator: TokenType::Plus, + lhs: Box::new(Expr::Literal(Literal::Number(3))), + rhs: Box::new(Expr::Binary { + operator: TokenType::Star, + lhs: Box::new(Expr::Literal(Literal::Number(5))), + rhs: Box::new(Expr::Literal(Literal::Number(4))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_expression_b() { + let lexer = Lexer::new("17 - (-5 + 5) / 6"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::Binary { + operator: TokenType::Minus, + lhs: Box::new(Expr::Literal(Literal::Number(17))), + rhs: Box::new(Expr::Binary { + operator: TokenType::Slash, + lhs: Box::new(Expr::Grouping(Box::new(Expr::Binary { + operator: TokenType::Plus, + lhs: Box::new(Expr::Unary { + operator: TokenType::Minus, + expr: Box::new(Expr::Literal(Literal::Number(5))), + }), + rhs: Box::new(Expr::Literal(Literal::Number(5))), + }))), + rhs: Box::new(Expr::Literal(Literal::Number(6))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_expression_c() { + let lexer = Lexer::new("9 > 6 && 5 + 7 == 32 || \"apple\" != \"banana\""); + let tokens = lexer.collect_vec(); + + // TODO: + } +} diff --git a/src/ast/printer.rs b/src/ast/printer.rs new file mode 100644 index 0000000..1aa32ae --- /dev/null +++ b/src/ast/printer.rs @@ -0,0 +1,38 @@ +// use super::{AstVisitor, Expr, Stmt}; + +// pub struct AstPrettyPrinter; +// impl AstVisitor<String> for AstPrettyPrinter { +// fn visit_stmt(&self, stmt: &Stmt) -> String { +// match stmt { +// Stmt::Expr(expr) => self.visit_expr(expr), +// Stmt::Val(name, expr) => format!("(val '{}' <- {})", name, +// self.visit_expr(expr)), Stmt::Var(name, expr) => format!("(var +// '{}' <- {})", name, self.visit_expr(expr)), } +// } + +// fn visit_expr(&self, expr: &Expr) -> String { +// match expr { +// Expr::Literal(i) => i.to_string(), +// Expr::Add(lhs, rhs) => { +// let lhs = self.visit_expr(lhs); +// let rhs = self.visit_expr(rhs); +// format!("({lhs} + {rhs})") +// } +// Expr::Sub(lhs, rhs) => { +// let lhs = self.visit_expr(lhs); +// let rhs = self.visit_expr(rhs); +// format!("({lhs} - {rhs})") +// } +// Expr::Mul(lhs, rhs) => { +// let lhs = self.visit_expr(lhs); +// let rhs = self.visit_expr(rhs); +// format!("({lhs} * {rhs})") +// } +// Expr::Div(lhs, rhs) => { +// let lhs = self.visit_expr(lhs); +// let rhs = self.visit_expr(rhs); +// format!("({lhs} / {rhs})") +// } +// } +// } +// } |
