diff options
| author | Cody <cody@codyq.dev> | 2023-06-07 03:28:40 -0500 |
|---|---|---|
| committer | Cody <cody@codyq.dev> | 2023-06-07 03:28:40 -0500 |
| commit | 6f6613419f1511c5637c9f69b3caa5ae838270b9 (patch) | |
| tree | e203d6cdc0eb2140ae6f0a430e76f2992de66bec /sloth | |
| parent | 25c5ccb29a6f2387a04bfb5d50874e00084c15d6 (diff) | |
| download | sloth-6f6613419f1511c5637c9f69b3caa5ae838270b9.tar.gz | |
Moving over from a VM interpreter to natively compiled w/ LLVM
Diffstat (limited to 'sloth')
| -rw-r--r-- | sloth/Cargo.toml | 11 | ||||
| -rw-r--r-- | sloth/src/compiler/mod.rs | 131 | ||||
| -rw-r--r-- | sloth/src/lexer.rs | 559 | ||||
| -rw-r--r-- | sloth/src/main.rs | 43 | ||||
| -rw-r--r-- | sloth/src/parser/ast.rs | 115 | ||||
| -rw-r--r-- | sloth/src/parser/expr.rs | 261 | ||||
| -rw-r--r-- | sloth/src/parser/mod.rs | 57 | ||||
| -rw-r--r-- | sloth/src/parser/stmt.rs | 646 |
8 files changed, 1823 insertions, 0 deletions
diff --git a/sloth/Cargo.toml b/sloth/Cargo.toml
new file mode 100644
index 0000000..4fabdb7
--- /dev/null
+++ b/sloth/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "sloth"
+
+license.workspace = true
+version.workspace = true
+edition.workspace = true
+
+[dependencies]
+inkwell = { version = "0.2.0", features = ["llvm15-0"] }
+itertools = "0.10.5"
+thiserror = "1.0.40"
diff --git a/sloth/src/compiler/mod.rs b/sloth/src/compiler/mod.rs
new file mode 100644
index 0000000..87c0618
--- /dev/null
+++ b/sloth/src/compiler/mod.rs
@@ -0,0 +1,131 @@
+#![allow(unused)]
+
+use std::collections::HashMap;
+use std::path::Path;
+use std::vec;
+
+use inkwell::builder::Builder;
+use inkwell::context::Context;
+use inkwell::module::Module;
+use inkwell::targets::{
+    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
+};
+use inkwell::values::IntValue;
+use inkwell::OptimizationLevel;
+
+use crate::parser::ast::{BinaryOp, Expr, FuncArgs, Literal, Stmt, UnaryOp};
+
+/// Lowers a parsed program into a single LLVM module and writes it out as a
+/// native object file.
+pub struct Compiler<'ctx> {
+    context: &'ctx Context,
+    builder: Builder<'ctx>,
+    module: Module<'ctx>,
+}
+
+impl<'ctx> Compiler<'ctx> {
+    /// Creates a compiler with a fresh builder and an empty module named
+    /// `"sloth"`, both owned by `context`.
+    pub fn new(context: &'ctx Context) -> Self {
+        let builder = context.create_builder();
+        let module = context.create_module("sloth");
+
+        Self {
+            context,
+            builder,
+            module,
+        }
+    }
+
+    /// Compiles every top-level function in `src` and writes the resulting
+    /// object code to `output.o` in the current directory.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `src` contains a top-level statement other than a function
+    /// definition, or if the native target cannot be initialised, the target
+    /// machine cannot be created, or the object file cannot be written.
+    pub fn compile(&self, src: Vec<Stmt>) {
+        for stmt in src {
+            match stmt {
+                Stmt::DefineFunction {
+                    ident,
+                    args,
+                    body,
+                    return_type,
+                } => {
+                    self.compile_function(&ident, &args, return_type.is_some(), body);
+                }
+                _ => panic!("You may only define a function top level"),
+            }
+        }
+
+        Target::initialize_native(&InitializationConfig::default()).unwrap();
+
+        let triple = TargetMachine::get_default_triple();
+        let target = Target::from_triple(&triple).unwrap();
+        let machine = target
+            .create_target_machine(
+                &triple,
+                // The triple is whatever the host defaults to, so the CPU name
+                // must not be tied to one architecture ("x86-64" only fits an
+                // x86-64 host).
+                "generic",
+                "",
+                OptimizationLevel::None,
+                RelocMode::Default,
+                CodeModel::Default,
+            )
+            .unwrap();
+
+        self.module.set_triple(&triple);
+        machine
+            .write_to_file(&self.module, FileType::Object, Path::new("output.o"))
+            .unwrap();
+    }
+
+    /// Emits one function named `identifier` into the module.
+    ///
+    /// Every parameter is typed as `i64`; the return type is `i64` when
+    /// `returns` is true and `void` otherwise. Only two statement shapes are
+    /// lowered so far: `return <arg>` and `return <arg> (+|-) <arg>`. Anything
+    /// else hits `unimplemented!()`, and a returned name that is not one of
+    /// `args` panics on the map lookup.
+    fn compile_function(&self, identifier: &str, args: &[FuncArgs], returns: bool, src: Vec<Stmt>) {
+        let void_type = self.context.void_type();
+        let i64_type = self.context.i64_type();
+
+        let function_type = if returns {
+            i64_type.fn_type(&vec![i64_type.into(); args.len()], false)
+        } else {
+            void_type.fn_type(&vec![i64_type.into(); args.len()], false)
+        };
+        let function = self.module.add_function(identifier, function_type, None);
+
+        let basic_block = self.context.append_basic_block(function, "body");
+
+        self.builder.position_at_end(basic_block);
+
+        // Argument name -> LLVM parameter value, so `Expr::Variable` can be resolved.
+        let mut arg_values = HashMap::<String, IntValue>::new();
+        for (i, arg) in args.iter().enumerate() {
+            arg_values.insert(
+                arg.name.clone(),
+                function.get_nth_param(i as u32).unwrap().into_int_value(),
+            );
+        }
+
+        for stmt in src {
+            match stmt {
+                Stmt::Return { value } => match value {
+                    Expr::BinaryOp { op, lhs, rhs } => {
+                        let lhs = match *lhs {
+                            Expr::Variable(a) => arg_values[&a],
+                            _ => unimplemented!(),
+                        };
+
+                        let rhs = match *rhs {
+                            Expr::Variable(a) => arg_values[&a],
+                            _ => unimplemented!(),
+                        };
+
+                        let res = match op {
+                            BinaryOp::Add => self.builder.build_int_add(lhs, rhs, "addop"),
+                            BinaryOp::Sub => self.builder.build_int_sub(lhs, rhs, "subop"),
+                            _ => unimplemented!(),
+                        };
+
+                        self.builder.build_return(Some(&res));
+                        return;
+                    }
+                    Expr::Variable(name) => {
+                        let var = arg_values[&name];
+                        self.builder.build_return(Some(&var));
+                        return;
+                    }
+                    _ => unimplemented!(),
+                },
+                _ => unimplemented!(),
+            }
+        }
+
+        // Reached only when the body contained no statements at all.
+        // NOTE(review): this emits `ret void` even when `returns` is true.
+        self.builder.build_return(None);
+    }
+}
diff --git a/sloth/src/lexer.rs b/sloth/src/lexer.rs
new file mode 100644
index 0000000..0afaf1c
--- /dev/null
+++ b/sloth/src/lexer.rs
@@ -0,0 +1,559 @@
+#![allow(dead_code)]
+
+//! 
TODO: Lexing Regex Literals + +use std::str::Chars; + +use thiserror::Error; + +#[derive(Debug, Clone, PartialEq, Error)] +pub enum LexerError { + #[error("Unexpected token")] + UnexpectedToken, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TokenType { + // Meta + DocComment, + Comment, + + // Brackets + OpeningParen, // ( + ClosingParen, // ) + OpeningBracket, // [ + ClosingBracket, // ] + OpeningBrace, // { + ClosingBrace, // } + + // Operators + Plus, // + + PlusPlus, // ++ + Minus, // - + Star, // * + StarStar, // ** + Slash, // / + Perc, // % + Tilde, // ~ + + PlusEq, // += + PlusPlusEq, // ++= + MinusEq, // -= + StarEq, // *= + StarStarEq, // **= + SlashEq, // /= + PercEq, // %= + TildeEq, // ~= + + Amp, // & + AmpAmp, // && + Pipe, // | + PipePipe, // || + Caret, // ^ + + Eq, // = + EqEq, // == + Bang, // ! + BangBang, // !! + BangEq, // != + + Lt, // < + LtLt, // << + LtEq, // <= + LtLtEq, // <<= + Gt, // > + GtGt, // >> + GtEq, // >= + GtGtEq, // >>= + + Comma, + + Question, // ? + QuestionDot, // ?. + QuestionQuestion, // ?? + Dot, // . + DotDot, // .. 
+ + Colon, // : + ColonColon, // :: + SemiColon, // ; + + Arrow, // -> + FatArrow, // => + + // Keywords + Val, + Var, + + Fn, + Return, + + If, + Else, + + While, + For, + In, + + Loop, + Break, + Continue, + + As, + + // Literals + Integer(i128), + Float(f64), + Boolean(bool), + Character(char), + String(String), + Regex(String), + + Identifier(String), + + // Utility + Error(LexerError), +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct Location { + index: usize, + pub row: u32, + pub col: u32, +} + +impl Location { + fn advance(&mut self, len: usize, newline: bool) { + if newline { + self.row += 1; + self.col = 0; + } else { + self.col += 1; + } + self.index += len; + } +} + +#[derive(Debug)] +pub struct Token<'a> { + pub tt: TokenType, + pub lexeme: &'a str, + + start: Location, + end: Location, +} + +pub struct Lexer<'a> { + source: &'a [u8], + window: [char; 3], + chars: Chars<'a>, + + start: Location, + current: Location, + + // Keep track if the lexer has encountered an error to stop lexing asap + errored: bool, +} + +impl<'a> Lexer<'a> { + pub(crate) fn new(source: &'a str) -> Self { + let mut chars = source.chars(); + let window = [ + chars.next().unwrap_or('\0'), + chars.next().unwrap_or('\0'), + chars.next().unwrap_or('\0'), + ]; + + Self { + source: source.as_bytes(), + window, + chars, + start: Default::default(), + current: Default::default(), + errored: false, + } + } +} + +impl<'a> Lexer<'a> { + fn pos(&self) -> usize { + self.current.index + } + + fn peek(&self) -> char { + self.window[0] + } + + fn eof(&self) -> bool { + self.peek() == '\0' + } + + fn advance(&mut self) -> char { + let current = self.window[0]; + self.window = [ + self.window[1], + self.window[2], + self.chars.next().unwrap_or('\0'), + ]; + self.current.advance(current.len_utf8(), current == '\n'); + current + } + + fn advance_with(&mut self, with: TokenType) -> TokenType { + self.advance(); + with + } + + fn advance_by(&mut self, amount: usize) { + for _ in 0..amount { + 
self.advance();
+        }
+    }
+
+    /// Advances `amount` characters, then hands back `with` so a match arm can
+    /// consume a multi-character token and yield its type in one expression.
+    fn advance_by_with(&mut self, amount: usize, with: TokenType) -> TokenType {
+        self.advance_by(amount);
+        with
+    }
+
+    /// Advances for as long as `predicate` accepts the current three-character
+    /// window, stopping early at end of input.
+    fn advance_while(&mut self, predicate: impl Fn([char; 3]) -> bool) {
+        while !self.eof() && predicate(self.window) {
+            self.advance();
+        }
+    }
+}
+
+impl<'a> Lexer<'a> {
+    /// Lexes an integer or float literal; the current character is taken as
+    /// the first digit without being inspected.
+    ///
+    /// A `.` is treated as a decimal point only when a digit follows it, so
+    /// `1..5` lexes as `1`, `..`, `5` and `1.` as `1`, `.`.
+    fn lex_number(&mut self) -> TokenType {
+        let mut value = self.advance().to_string();
+
+        while self.peek().is_ascii_digit() {
+            value.push(self.advance());
+        }
+
+        if self.peek() == '.' && self.window[1].is_ascii_digit() {
+            value.push(self.advance());
+
+            while self.peek().is_ascii_digit() {
+                value.push(self.advance());
+            }
+
+            TokenType::Float(value.parse::<f64>().expect("Expected float"))
+        } else {
+            // NOTE(review): a digit run too large for `i128` still panics here.
+            TokenType::Integer(value.parse::<i128>().expect("Expected integer"))
+        }
+    }
+
+    /// Lexes a double-quoted string literal, decoding the `\"`, `\t` and `\n`
+    /// escapes; any other character is copied through unchanged.
+    ///
+    /// If the input ends before the closing quote, the lexer is marked as
+    /// errored and an error token is returned instead of a string.
+    fn lex_string(&mut self) -> TokenType {
+        let mut value = String::new();
+
+        // Skip the opening quote.
+        self.advance();
+        loop {
+            match self.window {
+                ['\\', '"', ..] => {
+                    self.advance_by(2);
+                    value.push('"');
+                }
+                ['\\', 't', ..] => {
+                    self.advance_by(2);
+                    value.push('\t');
+                }
+                ['\\', 'n', ..] => {
+                    self.advance_by(2);
+                    value.push('\n');
+                }
+                ['"', ..] => {
+                    self.advance();
+                    break;
+                }
+                // Same condition as `eof()`. Without this arm an unterminated
+                // string would push '\0' forever, because `advance` keeps
+                // returning it once the input is exhausted.
+                ['\0', ..] => {
+                    self.errored = true;
+                    return TokenType::Error(LexerError::UnexpectedToken);
+                }
+                _ => {
+                    value.push(self.advance());
+                    continue;
+                }
+            }
+        }
+
+        TokenType::String(value)
+    }
+}
+
+impl<'a> Iterator for Lexer<'a> {
+    type Item = Token<'a>;
+
+    /// Skips whitespace and `#` comments, then lexes a single token.
+    ///
+    /// Returns `None` at end of input or once an error token has been produced.
+    fn next(&mut self) -> Option<Self::Item> {
+        // Skipping whitespace
+        self.advance_while(|it| it[0].is_whitespace());
+        self.start = self.current;
+
+        // If we're at the end of the file or an error has occurred return nothing
+        if self.eof() || self.errored {
+            return None;
+        }
+
+        // Figuring out the token type
+        let tt = match self.window {
+            ['#', '#', ..] => {
+                self.advance_while(|it| it[0] != '\n');
+                // TODO: TokenType::DocComment
+                return self.next();
+            }
+
+            ['#', ..] => {
+                self.advance_while(|it| it[0] != '\n');
+                // TODO: TokenType::Comment
+                return self.next();
+            }
+
+            // Blocks
+            ['(', ..] => self.advance_with(TokenType::OpeningParen),
+            [')', ..] 
=> self.advance_with(TokenType::ClosingParen),
+            ['[', ..] => self.advance_with(TokenType::OpeningBracket),
+            [']', ..] => self.advance_with(TokenType::ClosingBracket),
+            ['{', ..] => self.advance_with(TokenType::OpeningBrace),
+            ['}', ..] => self.advance_with(TokenType::ClosingBrace),
+
+            // Operators
+            // Longer operators are listed before their prefixes so the longest
+            // match wins (e.g. `++=` before `++` before `+=` before `+`).
+            ['-', '>', ..] => self.advance_by_with(2, TokenType::Arrow),
+            ['=', '>', ..] => self.advance_by_with(2, TokenType::FatArrow),
+
+            ['+', '+', '='] => self.advance_by_with(3, TokenType::PlusPlusEq),
+            ['*', '*', '='] => self.advance_by_with(3, TokenType::StarStarEq),
+            ['+', '+', ..] => self.advance_by_with(2, TokenType::PlusPlus),
+            ['*', '*', ..] => self.advance_by_with(2, TokenType::StarStar),
+
+            ['+', '=', ..] => self.advance_by_with(2, TokenType::PlusEq),
+            ['-', '=', ..] => self.advance_by_with(2, TokenType::MinusEq),
+            ['*', '=', ..] => self.advance_by_with(2, TokenType::StarEq),
+            ['/', '=', ..] => self.advance_by_with(2, TokenType::SlashEq),
+            ['%', '=', ..] => self.advance_by_with(2, TokenType::PercEq),
+            ['~', '=', ..] => self.advance_by_with(2, TokenType::TildeEq),
+
+            ['+', ..] => self.advance_with(TokenType::Plus),
+            ['-', ..] => self.advance_with(TokenType::Minus),
+            ['*', ..] => self.advance_with(TokenType::Star),
+            ['/', ..] => self.advance_with(TokenType::Slash), // TODO: Check for regex literals
+            ['%', ..] => self.advance_with(TokenType::Perc),
+            ['~', ..] => self.advance_with(TokenType::Tilde),
+
+            ['&', '&', ..] => self.advance_by_with(2, TokenType::AmpAmp),
+            ['&', ..] => self.advance_with(TokenType::Amp),
+
+            ['|', '|', ..] => self.advance_by_with(2, TokenType::PipePipe),
+            ['|', ..] => self.advance_with(TokenType::Pipe),
+
+            // `^` is a single character: advancing by two would also swallow
+            // whatever follows it (in `a^b` the `b` would be lost).
+            ['^', ..] => self.advance_with(TokenType::Caret),
+
+            ['=', '=', ..] => self.advance_by_with(2, TokenType::EqEq),
+            ['!', '=', ..] => self.advance_by_with(2, TokenType::BangEq),
+            ['!', '!', ..] => self.advance_by_with(2, TokenType::BangBang),
+            ['=', ..] => self.advance_with(TokenType::Eq),
+            ['!', ..] 
=> self.advance_with(TokenType::Bang), + + ['<', '<', '='] => self.advance_by_with(3, TokenType::LtLtEq), + ['<', '<', ..] => self.advance_by_with(2, TokenType::LtLt), + ['<', '=', ..] => self.advance_by_with(2, TokenType::LtEq), + ['<', ..] => self.advance_with(TokenType::Lt), + + ['>', '>', '='] => self.advance_by_with(3, TokenType::GtGtEq), + ['>', '>', ..] => self.advance_by_with(2, TokenType::GtGt), + ['>', '=', ..] => self.advance_by_with(2, TokenType::GtEq), + ['>', ..] => self.advance_with(TokenType::Gt), + + [',', ..] => self.advance_with(TokenType::Comma), + + ['.', '.', ..] => self.advance_by_with(2, TokenType::DotDot), + ['.', ..] => self.advance_with(TokenType::Dot), + ['?', '?', ..] => self.advance_by_with(2, TokenType::QuestionQuestion), + ['?', '.', ..] => self.advance_by_with(2, TokenType::QuestionDot), + ['?', ..] => self.advance_with(TokenType::Question), + + [';', ..] => self.advance_with(TokenType::SemiColon), + [':', ':', ..] => self.advance_by_with(2, TokenType::ColonColon), + [':', ..] => self.advance_with(TokenType::Colon), + + // Literals + ['\'', c, '\''] => self.advance_by_with(3, TokenType::Character(c)), + ['0'..='9', ..] => self.lex_number(), + ['"', ..] => self.lex_string(), + + ['a'..='z' | 'A'..='Z' | '_' | '$', ..] 
=> { + let mut value = String::new(); + while matches!(self.peek(), 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '$') { + value.push(self.advance()); + } + + match value.as_str() { + "val" => TokenType::Val, + "var" => TokenType::Var, + "fn" => TokenType::Fn, + "return" => TokenType::Return, + "if" => TokenType::If, + "else" => TokenType::Else, + "while" => TokenType::While, + "for" => TokenType::For, + "in" => TokenType::In, + "loop" => TokenType::Loop, + "break" => TokenType::Break, + "continue" => TokenType::Continue, + "as" => TokenType::As, + "true" => TokenType::Boolean(true), + "false" => TokenType::Boolean(false), + _ => TokenType::Identifier(value), + } + } + + _ => { + self.errored = true; + TokenType::Error(LexerError::UnexpectedToken) + } + }; + + let lexeme = unsafe { + // At this point it is already known that the string is valid UTF-8, might + // aswell not check again + std::str::from_utf8_unchecked(&self.source[self.start.index..self.pos()]) + }; + + let token = Token { + tt, + lexeme, + start: self.start, + end: self.current, + }; + + Some(token) + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::{Lexer, TokenType}; + use crate::lexer::LexerError; + + #[test] + fn lex_operators() { + let source = "+ ++ - * ** / % ~ += ++= -= *= **= /= %= ~= & && | || ^ = == ! !! != < << \ + <<= <= > >> >>= >= , ? ?. ?? . .. 
: :: ; -> =>"; + let tokens = Lexer::new(source).map(|it| it.tt).collect_vec(); + + assert_eq!(&tokens, &[ + TokenType::Plus, + TokenType::PlusPlus, + TokenType::Minus, + TokenType::Star, + TokenType::StarStar, + TokenType::Slash, + TokenType::Perc, + TokenType::Tilde, + TokenType::PlusEq, + TokenType::PlusPlusEq, + TokenType::MinusEq, + TokenType::StarEq, + TokenType::StarStarEq, + TokenType::SlashEq, + TokenType::PercEq, + TokenType::TildeEq, + TokenType::Amp, + TokenType::AmpAmp, + TokenType::Pipe, + TokenType::PipePipe, + TokenType::Caret, + TokenType::Eq, + TokenType::EqEq, + TokenType::Bang, + TokenType::BangBang, + TokenType::BangEq, + TokenType::Lt, + TokenType::LtLt, + TokenType::LtLtEq, + TokenType::LtEq, + TokenType::Gt, + TokenType::GtGt, + TokenType::GtGtEq, + TokenType::GtEq, + TokenType::Comma, + TokenType::Question, + TokenType::QuestionDot, + TokenType::QuestionQuestion, + TokenType::Dot, + TokenType::DotDot, + TokenType::Colon, + TokenType::ColonColon, + TokenType::SemiColon, + TokenType::Arrow, + TokenType::FatArrow, + ]); + } + + #[test] + fn lex_keywords() { + let source = "val var fn if else while for in loop break continue as true false"; + let tokens = Lexer::new(source).map(|it| it.tt).collect_vec(); + + assert_eq!(&tokens, &[ + TokenType::Val, + TokenType::Var, + TokenType::Fn, + TokenType::If, + TokenType::Else, + TokenType::While, + TokenType::For, + TokenType::In, + TokenType::Loop, + TokenType::Break, + TokenType::Continue, + TokenType::As, + TokenType::Boolean(true), + TokenType::Boolean(false), + ]); + } + + #[test] + fn lex_literals_a() { + let source = "foo bar _foo __bar $0 $$1 \"foo\" \"bar\" \"baz\" \"\\\"\" \"\\n\" \"\\t\" \ + 'a' 'b' '\"' 93 3252 238 -382 -832 83 -25 52.9 83.7 12.4 35.2 3.3"; + let tokens = Lexer::new(source).map(|it| it.tt).collect_vec(); + + assert_eq!(&tokens, &[ + TokenType::Identifier("foo".to_owned()), + TokenType::Identifier("bar".to_owned()), + TokenType::Identifier("_foo".to_owned()), + 
TokenType::Identifier("__bar".to_owned()), + TokenType::Identifier("$0".to_owned()), + TokenType::Identifier("$$1".to_owned()), + TokenType::String("foo".to_owned()), + TokenType::String("bar".to_owned()), + TokenType::String("baz".to_owned()), + TokenType::String("\"".to_owned()), + TokenType::String("\n".to_owned()), + TokenType::String("\t".to_owned()), + TokenType::Character('a'), + TokenType::Character('b'), + TokenType::Character('"'), + TokenType::Integer(93), + TokenType::Integer(3252), + TokenType::Integer(238), + TokenType::Minus, + TokenType::Integer(382), + TokenType::Minus, + TokenType::Integer(832), + TokenType::Integer(83), + TokenType::Minus, + TokenType::Integer(25), + TokenType::Float(52.9), + TokenType::Float(83.7), + TokenType::Float(12.4), + TokenType::Float(35.2), + TokenType::Float(3.3), + ]); + } + + #[test] + fn lex_errors() { + let source = "`"; + let tokens = Lexer::new(source).map(|it| it.tt).collect_vec(); + + assert_eq!(&tokens, &[TokenType::Error(LexerError::UnexpectedToken)]); + } +} diff --git a/sloth/src/main.rs b/sloth/src/main.rs new file mode 100644 index 0000000..a611156 --- /dev/null +++ b/sloth/src/main.rs @@ -0,0 +1,43 @@ +#![warn( + clippy::wildcard_imports, + clippy::string_add, + clippy::string_add_assign, + clippy::manual_ok_or, + unused_lifetimes +)] + +pub mod compiler; +pub mod lexer; +pub mod parser; + +use std::{env, fs}; + +use compiler::Compiler; +use inkwell::context::Context; +use itertools::Itertools; +use lexer::Lexer; +use parser::AstParser; + +fn main() { + let args = env::args().collect_vec(); + + if args.len() < 2 { + println!("Sloth programming language interpreter\n"); + println!("Usage: sloth <file>"); + return; + } + + let source_path = &args[1]; + let Ok(source) = fs::read_to_string(source_path) else { + println!("Error while reading '{source_path}'"); + return; + }; + + let tokens = Lexer::new(&source).collect_vec(); + let ast = AstParser::new(tokens).parse(); + + let context = Context::create(); + 
let compiler = Compiler::new(&context); + + compiler.compile(ast); +} diff --git a/sloth/src/parser/ast.rs b/sloth/src/parser/ast.rs new file mode 100644 index 0000000..543ea3a --- /dev/null +++ b/sloth/src/parser/ast.rs @@ -0,0 +1,115 @@ +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BinaryOp { + Add, + Con, + Sub, + Mul, + Pow, + Div, + Mod, + + BWSftRight, + BWSftLeft, + BWAnd, + BWOr, + BWXor, + + Lt, + Gt, + LtEq, + GtEq, + EqEq, + NotEq, + LogAnd, + LogOr, + Range, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum UnaryOp { + Not, + Neg, + + BWComp, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Literal { + Integer(i128), + Float(f64), + Bool(bool), + Char(char), + String(String), + Regex(String), + List(Vec<Expr>), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + Grouping(Box<Expr>), + BinaryOp { + op: BinaryOp, + lhs: Box<Expr>, + rhs: Box<Expr>, + }, + UnaryOp { + op: UnaryOp, + value: Box<Expr>, + }, + Call { + ident: Box<Expr>, + args: Vec<Expr>, + }, + Variable(String), + Literal(Literal), + Lambda, // TODO: Lambda +} + +#[derive(PartialEq, Clone, Debug)] +pub struct FuncArgs { + pub name: String, + pub typ: Option<String>, +} + +#[derive(PartialEq, Clone, Debug)] +pub enum Stmt { + ExprStmt(Expr), + DefineFunction { + ident: String, + args: Vec<FuncArgs>, + body: Vec<Stmt>, + return_type: Option<String>, + }, + DefineVariable { + name: String, + value: Expr, + typ: Option<String>, + }, + DefineValue { + name: String, + value: Expr, + typ: Option<String>, + }, + AssignVariable { + name: String, + value: Expr, + }, + If { + expr: Expr, + body: Vec<Stmt>, + else_if: Vec<(Expr, Stmt)>, + els: Option<Box<Stmt>>, + }, + For { + name: String, + iter: Expr, + body: Vec<Stmt>, + }, + While { + condition: Expr, + body: Vec<Stmt>, + }, + Return { + value: Expr, + }, +} diff --git a/sloth/src/parser/expr.rs b/sloth/src/parser/expr.rs new file mode 100644 index 0000000..9e81f7f --- /dev/null +++ b/sloth/src/parser/expr.rs @@ -0,0 +1,261 
@@ +use super::ast::{BinaryOp, Expr, Literal, UnaryOp}; +use super::AstParser; +use crate::lexer::TokenType; + +/// Implementation containing parsers internal components related to expressions +impl<'a> AstParser<'a> { + // FIXME: Should probably avoid cloning token types + + pub fn expression(&mut self) -> Expr { + self.logical_or() + } + + fn unary(&mut self) -> Expr { + if !self.eof() + && matches!( + self.peek().tt, + TokenType::Bang | TokenType::Plus | TokenType::Minus + ) + { + let operator = match self.advance().unwrap().tt.clone() { + TokenType::Bang => UnaryOp::Not, + TokenType::Tilde => UnaryOp::BWComp, + TokenType::Minus => UnaryOp::Neg, + _ => panic!(), + }; + + let rhs = self.unary(); + return Expr::UnaryOp { + op: (operator), + value: (Box::new(rhs)), + }; + } + + self.call() + } + + fn call(&mut self) -> Expr { + let mut expr = self.primary(); + + if self.advance_if_eq(&TokenType::OpeningParen) { + let mut arguments = Vec::<Expr>::new(); + + if self.peek().tt != TokenType::ClosingParen { + loop { + arguments.push(self.expression()); + if !self.advance_if_eq(&TokenType::Comma) { + break; + } + } + } + + self.consume( + TokenType::ClosingParen, + "Expected ')' to close off function call", + ); + + // let Expr::Variable(_ident) = expr else { panic!("uh oh spaghettio"); }; + + expr = Expr::Call { + ident: (Box::new(expr)), + args: (arguments), + } + } + + expr + } + + fn primary(&mut self) -> Expr { + match self.advance().unwrap().tt.clone() { + TokenType::Integer(literal) => Expr::Literal(Literal::Integer(literal)), + TokenType::Float(literal) => Expr::Literal(Literal::Float(literal)), + TokenType::Boolean(literal) => Expr::Literal(Literal::Bool(literal)), + TokenType::Character(literal) => Expr::Literal(Literal::Char(literal)), + TokenType::String(literal) => Expr::Literal(Literal::String(literal)), + TokenType::Regex(literal) => Expr::Literal(Literal::Regex(literal)), + TokenType::Identifier(ident) => Expr::Variable(ident), + TokenType::OpeningParen 
=> { + let expr = self.expression(); + self.consume(TokenType::ClosingParen, "Must end expression with ')'"); + Expr::Grouping(Box::new(expr)) + } + TokenType::OpeningBracket => { + let mut expr: Vec<Expr> = Vec::new(); + + while !self.eof() && self.peek().tt != TokenType::ClosingBracket { + let exp = self.expression(); + expr.push(exp); + + self.advance_if_eq(&TokenType::Comma); + } + self.consume(TokenType::ClosingBracket, "Expected ']' at end of list"); + Expr::Literal(Literal::List(expr)) + } + _ => unimplemented!("{:?}", self.peek()), + } + } +} + +// Macro to generate repetitive binary expressions. Things like addition, +// multiplication, exc. +macro_rules! binary_expr { + ($name:ident, $parent:ident, $pattern:pat) => { + fn $name(&mut self) -> Expr { + let mut expr = self.$parent(); + + while !self.eof() && matches!(self.peek().tt, $pattern) { + let operator = match self.advance().unwrap().tt.clone() { + TokenType::Plus => BinaryOp::Add, + TokenType::PlusPlus => BinaryOp::Con, + TokenType::Minus => BinaryOp::Sub, + TokenType::Star => BinaryOp::Mul, + TokenType::StarStar => BinaryOp::Pow, + TokenType::Slash => BinaryOp::Div, + TokenType::Perc => BinaryOp::Mod, + TokenType::DotDot => BinaryOp::Range, + + TokenType::LtLt => BinaryOp::BWSftRight, + TokenType::GtGt => BinaryOp::BWSftLeft, + TokenType::Amp => BinaryOp::BWAnd, + TokenType::Pipe => BinaryOp::BWOr, + TokenType::Caret => BinaryOp::BWXor, + + TokenType::Lt => BinaryOp::Lt, + TokenType::Gt => BinaryOp::Gt, + TokenType::LtEq => BinaryOp::LtEq, + TokenType::GtEq => BinaryOp::GtEq, + TokenType::EqEq => BinaryOp::EqEq, + TokenType::BangEq => BinaryOp::NotEq, + TokenType::AmpAmp => BinaryOp::LogAnd, + TokenType::PipePipe => BinaryOp::LogOr, + _ => panic!("uh oh spagghetio"), + }; + + let rhs = self.$parent(); + expr = Expr::BinaryOp { + op: (operator), + lhs: (Box::new(expr)), + rhs: (Box::new(rhs)), + } + } + + expr + } + }; +} + +#[rustfmt::skip] +#[allow(unused_parens)] +impl<'a> AstParser<'a> { + // 
Binary expressions in order of precedence from lowest to highest. + binary_expr!(logical_or , logical_and , (TokenType::PipePipe)); + binary_expr!(logical_and , range , (TokenType::AmpAmp)); + binary_expr!(range , equality , (TokenType::DotDot)); + binary_expr!(equality , comparison , (TokenType::BangEq | TokenType::EqEq)); + binary_expr!(comparison , bitwise_shifting, (TokenType::Lt | TokenType::Gt | TokenType::LtEq | TokenType::GtEq)); + binary_expr!(bitwise_shifting, additive , (TokenType::LtLt | TokenType::GtGt)); + binary_expr!(additive , multiplicative , (TokenType::Plus | TokenType::Minus)); + binary_expr!(multiplicative , unary , (TokenType::Star | TokenType::Slash | TokenType::Perc)); +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::{AstParser, BinaryOp, Expr, Literal}; + use crate::lexer::Lexer; + use crate::parser::ast::UnaryOp; + + #[test] + fn basic_expression_a() { + let lexer = Lexer::new("3 + 5 * 4"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::BinaryOp { + op: BinaryOp::Add, + lhs: Box::new(Expr::Literal(Literal::Integer(3))), + rhs: Box::new(Expr::BinaryOp { + op: BinaryOp::Mul, + lhs: Box::new(Expr::Literal(Literal::Integer(5))), + rhs: Box::new(Expr::Literal(Literal::Integer(4))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_expression_b() { + let lexer = Lexer::new("17 - (-5 + 5) / 6"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::BinaryOp { + op: BinaryOp::Sub, + lhs: Box::new(Expr::Literal(Literal::Integer(17))), + rhs: Box::new(Expr::BinaryOp { + op: BinaryOp::Div, + lhs: Box::new(Expr::Grouping(Box::new(Expr::BinaryOp { + op: BinaryOp::Add, + lhs: Box::new(Expr::UnaryOp { + op: UnaryOp::Neg, + value: Box::new(Expr::Literal(Literal::Integer(5))), 
+ }), + rhs: Box::new(Expr::Literal(Literal::Integer(5))), + }))), + rhs: Box::new(Expr::Literal(Literal::Integer(6))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + #[test] + fn basic_expression_c() { + let lexer = Lexer::new("[1, 2, 3]"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::Literal(Literal::List(vec![ + Expr::Literal(Literal::Integer(1)), + Expr::Literal(Literal::Integer(2)), + Expr::Literal(Literal::Integer(3)), + ])); + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + #[test] + fn basic_expression_d() { + let lexer = Lexer::new("1 .. 17"); + let tokens = lexer.collect_vec(); + + let expected_ast = Expr::BinaryOp { + op: (BinaryOp::Range), + lhs: (Box::new(Expr::Literal(Literal::Integer(1)))), + rhs: (Box::new(Expr::Literal(Literal::Integer(17)))), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.expression(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } +} diff --git a/sloth/src/parser/mod.rs b/sloth/src/parser/mod.rs new file mode 100644 index 0000000..9d77acc --- /dev/null +++ b/sloth/src/parser/mod.rs @@ -0,0 +1,57 @@ +pub mod ast; +pub mod expr; +pub mod stmt; + +use crate::lexer::{Token, TokenType}; +#[derive(Debug)] +pub struct AstParser<'a> { + tokens: Vec<Token<'a>>, + index: usize, +} + +/// Implementation containing utilities used by the parsers internal components +impl<'a> AstParser<'a> { + pub fn new(tokens: Vec<Token<'a>>) -> Self { + Self { tokens, index: 0 } + } + pub fn 
peek(&self) -> &Token { + &self.tokens[self.index] + } + + pub fn advance(&mut self) -> Option<&Token> { + if self.eof() { + return None; + } + + self.index += 1; + Some(&self.tokens[self.index - 1]) + } + + pub fn advance_if(&mut self, next: impl FnOnce(&Token) -> bool) -> bool { + if self.eof() { + return false; + } + + if next(self.peek()) { + self.advance(); + return true; + } + + false + } + + pub fn advance_if_eq(&mut self, next: &TokenType) -> bool { + self.advance_if(|it| it.tt == *next) + } + + pub fn consume(&mut self, next: TokenType, error: &str) { + if std::mem::discriminant(&self.peek().tt) != std::mem::discriminant(&next) { + panic!("{error} at index {:?}", self.index); + } + self.advance(); + } + + pub fn eof(&self) -> bool { + self.index >= self.tokens.len() + } +} diff --git a/sloth/src/parser/stmt.rs b/sloth/src/parser/stmt.rs new file mode 100644 index 0000000..1a961b1 --- /dev/null +++ b/sloth/src/parser/stmt.rs @@ -0,0 +1,646 @@ +use super::ast::{Expr, FuncArgs, Stmt}; +use super::AstParser; +use crate::lexer::TokenType; + +impl<'a> AstParser<'a> { + pub fn parse(&mut self) -> Vec<Stmt> { + let mut statements = Vec::new(); + + while !self.eof() { + statements.push(self.statement()); + } + + statements + } + + fn statement(&mut self) -> Stmt { + if self.advance_if_eq(&TokenType::Var) { + return self.var_statement(); + } + + if self.advance_if_eq(&TokenType::Val) { + return self.val_statement(); + } + + if self.advance_if_eq(&TokenType::If) { + return self.if_statement(); + } + + if self.advance_if_eq(&TokenType::For) { + return self.for_statement(); + } + + if self.advance_if_eq(&TokenType::While) { + return self.while_statement(); + } + + if self.advance_if_eq(&TokenType::Fn) { + return self.function_statement(); + } + + if self.advance_if_eq(&TokenType::Return) { + return self.return_statement(); + } + + self.mut_statement() + + // If we couldn't parse a statement return an expression statement + // self.expression_statement() + } + + fn 
mut_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier error {:?}", self.peek()); + }; + + self.advance(); + let next = self.advance().unwrap().tt.clone(); + if next == TokenType::Eq { + let value = self.expression(); + self.consume(TokenType::SemiColon, "No semi colon for me i guess"); + return Stmt::AssignVariable { + name: (ident), + value: (value), + }; + } else if next == TokenType::OpeningParen { + let mut arguments = Vec::<Expr>::new(); + + if self.peek().tt != TokenType::ClosingParen { + loop { + arguments.push(self.expression()); + if !self.advance_if_eq(&TokenType::Comma) { + break; + } + } + } + + self.consume( + TokenType::ClosingParen, + "Expected ')' to close off function call", + ); + + self.consume(TokenType::SemiColon, "No semi colon for me i guess"); + return Stmt::ExprStmt(Expr::Call { + ident: Box::new(Expr::Variable(ident)), + args: (arguments), + }); + } + self.expression_statement() + } + + fn var_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'var', not {:?}", self.peek()); + }; + + self.advance(); + + let mut typ: Option<String> = None; + if self.peek().tt.clone() == TokenType::Colon { + self.consume(TokenType::Colon, "How did you even get this error?"); + let TokenType::Identifier(name) = self.peek().tt.clone() else { + panic!("Type expected after identifier, not {:?}", self.peek()); + }; + self.advance(); + typ = Some(name); + } + + self.consume(TokenType::Eq, "Expected '=' after identifier at "); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::DefineVariable { + name: (ident), + value: (value), + typ: (typ), + } + } + + fn val_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.peek().tt.clone() else { + panic!("Identifier expected after 'val'"); + }; + + self.advance(); // Advancing from the 
identifier + + let mut typ: Option<String> = None; + if self.peek().tt.clone() == TokenType::Colon { + self.consume(TokenType::Colon, "How did you even get this error?"); + let TokenType::Identifier(name) = self.peek().tt.clone() else { + panic!("Type expected after identifier, not {:?}", self.peek()); + }; + self.advance(); + typ = Some(name); + } + + self.consume(TokenType::Eq, "Expected '=' after identifier"); + + let value = self.expression(); + + self.consume(TokenType::SemiColon, "Expected ';' at end of statement"); + + Stmt::DefineValue { + name: (ident), + value: (value), + typ: (typ), + } + } + + fn if_statement(&mut self) -> Stmt { + let condition = self.expression(); + + self.consume( + TokenType::OpeningBrace, + "Expected '{' at beggining of block", + ); + let mut body = Vec::new(); + while !self.eof() && self.peek().tt != TokenType::ClosingBrace { + body.push(self.statement()); + } + self.advance(); + Stmt::If { + expr: (condition), + body: (body), + else_if: (Vec::new()), + els: (None), + } // TODO: implement else if and else + } + + fn for_statement(&mut self) -> Stmt { + let binding = self.expression(); + let Expr::Variable(binding) = binding else { + panic!("Left side of for statement must be identifier"); + }; + + self.consume( + TokenType::In, + "Expected 'in' in between identifier and range", + ); + + // let range_start = self.expression(); + // self.consume( + // TokenType::DotDot, + // "Expected '..' 
denoting min and max of range", + // ); + // let range_end = self.expression(); + + let expr = self.expression(); + + self.consume(TokenType::OpeningBrace, "Expected '{' after iterator"); + + let mut body = Vec::new(); + while !self.eof() && self.peek().tt != TokenType::ClosingBrace { + body.push(self.statement()); + } + self.advance(); + + Stmt::For { + name: (binding), + iter: (expr), + body: (body), + } + } // TODO: Fix this garbage + + fn while_statement(&mut self) -> Stmt { + let condition = self.expression(); + + self.consume( + TokenType::OpeningBrace, + "Expected '{' at beggining of block", + ); + let mut body = Vec::new(); + while !self.eof() && self.peek().tt != TokenType::ClosingBrace { + println!("{:?}", self.peek().tt); + body.push(self.statement()); + } + self.consume( + TokenType::ClosingBrace, + "Expected '}' after block on while loop", + ); + + self.advance(); + Stmt::While { condition, body } + } + + fn expression_statement(&mut self) -> Stmt { + let expr = self.expression(); + + // FIXME: Move assignment handling + // if self.advance_if_eq(&TokenType::Eq) { + // if let Expr::Literal(_ident) = &expr { + // let value = self.expression(); + + // self.consume( + // TokenType::SemiColon, + // "Expected ';' at end of + // statement", + // ); // return Stmt::DefineVariable { + // // name: (ident.clone()), + // // value: (value), + // // typ: (None), + // // }; + // return Stmt::ExprStmt(expr); + // } + // } + + self.consume( + TokenType::SemiColon, + "Expected ';' at end of expr statement", + ); + Stmt::ExprStmt(expr) + } + + fn function_statement(&mut self) -> Stmt { + let TokenType::Identifier(ident) = self.advance().unwrap().tt.clone() else { + panic!("Identifier expected after 'fn'"); + }; + + self.consume(TokenType::OpeningParen, "Expected '(' after identifier"); + let mut args: Vec<FuncArgs> = Vec::new(); + while !self.eof() && self.peek().tt != TokenType::ClosingParen { + let TokenType::Identifier(name) = self.advance().unwrap().tt.clone() else { 
+ panic!("parameter expected after '('"); + }; + + let mut typ: Option<String> = None; + + if self.peek().tt.clone() == TokenType::Colon { + self.consume(TokenType::Colon, "How did you even get this error?"); + let TokenType::Identifier(name) = self.peek().tt.clone() else { + panic!("Type expected after ':', not {:?}", self.peek()); + }; + self.advance(); + typ = Some(name); + } + + self.advance_if_eq(&TokenType::Comma); + + let arg = FuncArgs { + name: (name), + typ: (typ), + }; + args.push(arg); + } + self.advance(); + let mut typ: Option<String> = None; + if self.peek().tt.clone() == TokenType::Arrow { + self.advance(); + let TokenType::Identifier(name) = self.peek().tt.clone() else { + panic!("Type expected after ':', not {:?}", self.peek()); + }; + typ = Some(name); + self.advance(); + } + self.consume(TokenType::OpeningBrace, "Expected '{' after parameters"); + let mut body = Vec::new(); + while !self.eof() && self.peek().tt != TokenType::ClosingBrace { + body.push(self.statement()); + } + self.consume(TokenType::ClosingBrace, "Expected '}' after body"); + + Stmt::DefineFunction { + ident: (ident), + args: (args), + body: (body), + return_type: (typ), + } + } + + fn return_statement(&mut self) -> Stmt { + let expr = self.expression(); + self.consume(TokenType::SemiColon, "Expected ';' after return statement"); + Stmt::Return { value: (expr) } + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::{AstParser, Stmt}; + use crate::lexer::Lexer; + use crate::parser::ast::{BinaryOp, Expr, FuncArgs, Literal, UnaryOp}; + + #[test] + fn basic_statement_a() { + let lexer = Lexer::new("var test_a: int = 5 + 3;"); + let tokens = lexer.collect_vec(); + + let expected_ast = Stmt::DefineVariable { + name: ("test_a".to_string()), + value: (Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Literal(Literal::Integer(5)))), + rhs: (Box::new(Expr::Literal(Literal::Integer(3)))), + }), + typ: Some("int".to_string()), + }; + + let mut parser = 
AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_statement_b() { + let lexer = Lexer::new("val test_b = \"Hello World\";"); + let tokens = lexer.collect_vec(); + + let expected_ast = Stmt::DefineValue { + name: ("test_b".to_string()), + value: (Expr::Literal(Literal::String("Hello World".to_string()))), + typ: (None), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_statement_c() { + let lexer = Lexer::new( + "\ + fn test_c (a, b, c) {\nreturn (a + b * c);\n}", + ); + let tokens = lexer.collect_vec(); + println!("{tokens:?}"); + + let expected_ast = Stmt::DefineFunction { + ident: ("test_c".to_string()), + args: (vec![ + FuncArgs { + name: ("a".to_string()), + typ: None, + }, + FuncArgs { + name: ("b".to_string()), + typ: None, + }, + FuncArgs { + name: ("c".to_string()), + typ: None, + }, + ]), + body: (vec![Stmt::Return { + value: (Expr::Grouping(Box::new(Expr::BinaryOp { + op: BinaryOp::Add, + lhs: Box::new(Expr::Variable("a".to_string())), + rhs: Box::new(Expr::BinaryOp { + op: BinaryOp::Mul, + lhs: Box::new(Expr::Variable("b".to_string())), + rhs: Box::new(Expr::Variable("c".to_string())), + }), + }))), + }]), + return_type: (None), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + #[test] + fn basic_statement_d() { + let lexer = Lexer::new( + "\ + while true {\nprint(\"Hello World\");\nprintln(5 + 7/-3);\n}", + ); + let tokens = lexer.collect_vec(); 
+ println!("{tokens:?}"); + + let expected_ast = Stmt::While { + condition: (Expr::Literal(Literal::Bool(true))), + body: (vec![ + Stmt::ExprStmt(Expr::Call { + ident: Box::new(Expr::Variable("print".to_string())), + args: (vec![Expr::Literal(Literal::String("Hello World".to_string()))]), + }), + Stmt::ExprStmt(Expr::Call { + ident: Box::new(Expr::Variable("println".to_string())), + args: (vec![Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Literal(Literal::Integer(5)))), + rhs: (Box::new(Expr::BinaryOp { + op: (BinaryOp::Div), + lhs: (Box::new(Expr::Literal(Literal::Integer(7)))), + rhs: (Box::new(Expr::UnaryOp { + op: (UnaryOp::Neg), + value: (Box::new(Expr::Literal(Literal::Integer(3)))), + })), + })), + }]), + }), + ]), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + #[test] + fn basic_statement_e() { + let lexer = Lexer::new( + "\ + if a+5 > 10 {\nprint(a);\n}\nif a+5 < 10 {\nprintln(10);\n}\nif a+5 == 10 \ + {\nprint(toString(10));\na = true;\n}", + ); + let tokens = lexer.collect_vec(); + // println!("{tokens:?}"); + + let expected_ast = vec![ + Stmt::If { + expr: (Expr::BinaryOp { + op: (BinaryOp::Gt), + lhs: (Box::new(Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Variable("a".to_string()))), + rhs: (Box::new(Expr::Literal(Literal::Integer(5)))), + })), + rhs: (Box::new(Expr::Literal(Literal::Integer(10)))), + }), + body: (vec![Stmt::ExprStmt(Expr::Call { + ident: (Box::new(Expr::Variable("print".to_string()))), + args: (vec![Expr::Variable("a".to_string())]), + })]), + else_if: (Vec::new()), + els: (None), + }, + Stmt::If { + expr: (Expr::BinaryOp { + op: (BinaryOp::Lt), + lhs: (Box::new(Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Variable("a".to_string()))), + rhs: 
(Box::new(Expr::Literal(Literal::Integer(5)))), + })), + rhs: (Box::new(Expr::Literal(Literal::Integer(10)))), + }), + body: (vec![Stmt::ExprStmt(Expr::Call { + ident: (Box::new(Expr::Variable("println".to_string()))), + args: (vec![Expr::Literal(Literal::Integer(10))]), + })]), + else_if: (Vec::new()), + els: (None), + }, + Stmt::If { + expr: (Expr::BinaryOp { + op: (BinaryOp::EqEq), + lhs: (Box::new(Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Variable("a".to_string()))), + rhs: (Box::new(Expr::Literal(Literal::Integer(5)))), + })), + rhs: (Box::new(Expr::Literal(Literal::Integer(10)))), + }), + body: (vec![ + Stmt::ExprStmt(Expr::Call { + ident: (Box::new(Expr::Variable("print".to_string()))), + // ident: (Box::new(Expr::Literal(Literal::String("print".to_string())))), + args: (vec![Expr::Call { + ident: (Box::new(Expr::Variable("toString".to_string()))), + // ident: Box::new(Expr::Literal(Literal::String("toString". + // to_string()))), + args: vec![Expr::Literal(Literal::Integer(10))], + }]), + }), + Stmt::AssignVariable { + name: ("a".to_string()), + value: (Expr::Literal(Literal::Bool(true))), + }, + ]), + + else_if: (Vec::new()), + els: (None), + }, + ]; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.parse(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_statement_f() { + let lexer = Lexer::new("test_a = 5 + 3;"); + let tokens = lexer.collect_vec(); + + let expected_ast = Stmt::AssignVariable { + name: ("test_a".to_string()), + value: (Expr::BinaryOp { + op: (BinaryOp::Add), + lhs: (Box::new(Expr::Literal(Literal::Integer(5)))), + rhs: (Box::new(Expr::Literal(Literal::Integer(3)))), + }), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated 
AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + #[test] + fn basic_statement_g() { + let lexer = Lexer::new( + "\ + fn times_two(x: int) -> int {\nval y: int = x*2;\nreturn y;\n}", + ); + let tokens = lexer.collect_vec(); + + let expected_ast = Stmt::DefineFunction { + ident: ("times_two".to_string()), + args: (vec![FuncArgs { + name: ("x".to_string()), + typ: (Some("int".to_string())), + }]), + body: (vec![ + Stmt::DefineValue { + name: "y".to_string(), + value: (Expr::BinaryOp { + op: (BinaryOp::Mul), + lhs: (Box::new(Expr::Variable("x".to_string()))), + rhs: (Box::new(Expr::Literal(Literal::Integer(2)))), + }), + typ: Some("int".to_string()), + }, + Stmt::Return { + value: (Expr::Variable("y".to_string())), + }, + ]), + + return_type: Some("int".to_string()), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } + + #[test] + fn basic_statement_h() { + let lexer = Lexer::new("for i in 1 .. 
3 {\nfor j in [1, 2, 3] {\nprint(j*i);}}"); + let tokens = lexer.collect_vec(); + + let expected_ast = Stmt::For { + name: ("i".to_string()), + iter: (Expr::BinaryOp { + op: (BinaryOp::Range), + lhs: (Box::new(Expr::Literal(Literal::Integer(1)))), + rhs: (Box::new(Expr::Literal(Literal::Integer(3)))), + }), + body: (vec![Stmt::For { + name: ("j".to_string()), + iter: (Expr::Literal(Literal::List(vec![ + Expr::Literal(Literal::Integer(1)), + Expr::Literal(Literal::Integer(2)), + Expr::Literal(Literal::Integer(3)), + ]))), + body: (vec![Stmt::ExprStmt(Expr::Call { + ident: Box::new(Expr::Variable("print".to_string())), + args: (vec![Expr::BinaryOp { + op: (BinaryOp::Mul), + lhs: (Box::new(Expr::Variable("j".to_string()))), + rhs: (Box::new(Expr::Variable("i".to_string()))), + }]), + })]), + }]), + }; + + let mut parser = AstParser::new(tokens); + let generated_ast = parser.statement(); + + println!("Expected AST:\n{expected_ast:#?}\n\n"); + println!("Generated AST:\n{generated_ast:#?}\n\n"); + + assert_eq!(expected_ast, generated_ast); + } +} |
