From 2949f0705a294c8c4baed126c9d2019b352cde3e Mon Sep 17 00:00:00 2001 From: DavidLee18 Date: Fri, 21 Feb 2025 22:03:30 +0900 Subject: [PATCH] impl arithmetics --- src/compiler/mod.rs | 125 ++++++++++++++++++++++++++++++------- src/compiler/primitives.rs | 26 ++++++++ src/core/mod.rs | 8 +-- src/main.rs | 6 +- src/test.rs | 18 ++++++ 5 files changed, 156 insertions(+), 27 deletions(-) create mode 100644 src/compiler/primitives.rs diff --git a/src/compiler/mod.rs b/src/compiler/mod.rs index 972b2d3..90d526a 100644 --- a/src/compiler/mod.rs +++ b/src/compiler/mod.rs @@ -1,11 +1,14 @@ +use crate::compiler::primitives::Primitive; use crate::core::{map_accuml, Addr, Heap, ASSOC}; use crate::lang; use std::cmp::max; +pub mod primitives; + type TiState = (TiStack, TiDump, TiHeap, TiGlobals, TiStats); type TiStack = Vec; -type TiDump = (); +type TiDump = Vec; type TiHeap = Heap; @@ -15,6 +18,16 @@ pub(crate) enum Node { SuperComb(lang::Name, Vec, lang::CoreExpr), Num(i64), Ind(Addr), + Prim(lang::Name, Primitive), +} + +impl Node { + pub fn is_data_node(&self) -> bool { + match self { + Node::Num(_) => true, + _ => false, + } + } } type TiGlobals = ASSOC; @@ -39,7 +52,7 @@ pub(crate) fn compile(p: lang::CoreProgram) -> TiState { let alloc_count = init_heap.alloc_count(); ( vec![*main_addr], - (), + vec![], init_heap, globals, TiStats { @@ -58,8 +71,20 @@ const EXTRA_PRELUDE_DEFS: &'static str = ""; fn build_init_heap(sc_defs: Vec) -> (TiHeap, TiGlobals) { let mut init_heap = Heap::new(); - let globals = map_accuml(allocate_sc, &mut init_heap, sc_defs); - (init_heap, globals) + let mut sc_addrs = map_accuml(allocate_sc, &mut init_heap, sc_defs); + let mut prim_addrs = map_accuml( + allocate_prim, + &mut init_heap, + Vec::from(primitives::PRIMITIVES), + ); + sc_addrs.append(&mut prim_addrs); + (init_heap, sc_addrs) +} + +fn allocate_prim(heap: &mut TiHeap, primitive: (&'static str, Primitive)) -> (lang::Name, Addr) { + let (name, prim) = primitive; + let addr = heap.alloc(Node::Prim(lang::Name::from(name), prim)); + (lang::Name::from(name), addr) } fn allocate_sc(heap: &mut TiHeap, sc_defs: lang::CoreScDefn) -> (lang::Name, Addr) { @@ -84,23 +109,87 @@ pub(crate) fn eval(state: TiState) -> Vec { } fn step(state: &mut TiState) { - let (stack, _, heap, _, _) = state; + let (stack, dump, heap, _, _) = state; // println!("Stack: {:?}", stack); // println!("{:?}", heap); let last_stack = *stack.last().expect("Empty stack"); match heap.lookup(last_stack).expect("cannot be found on heap") { - Node::Ap(a1, _) => stack.push(*a1), - Node::SuperComb(sc, args, body) => { - let (_, args, body) = (sc.clone(), args.clone(), body.clone()); + Node::Ap(a1, a2) => { + let a1 = *a1; + if let Some(Node::Ind(a3)) = heap.lookup(*a2) { + heap.update(last_stack, Node::Ap(a1, *a3)); + } + stack.push(a1); + } + Node::SuperComb(_, args, body) => { + let (args, body) = (args.clone(), body.clone()); sc_step(state, last_stack, args, body) } - Node::Num(_) => panic!("Number applied as a function"), + Node::Num(_) => { + assert_eq!(stack.len(), 1, "Number applied as a function"); + *stack = dump.pop().expect("empty dump"); + } Node::Ind(r) => { stack.pop(); stack.push(*r); } + Node::Prim(_, p) => { + let p = p.clone(); + prim_step(state, p) + } + } +} + +fn prim_step(state: &mut TiState, prim: Primitive) { + let (stack, dump, heap, _, _) = state; + let args_len = match &prim { + Primitive::Neg => 1, + _ => 2, + }; + assert_eq!( + stack.len(), + args_len + 1, + "args length mismatch: expected {}, got {}", + args_len + 1, + stack.len() + ); + let args = heap.get_args(stack, args_len); + match &prim { + Primitive::Neg => { + let arg = heap.lookup(args[0]).expect("cannot find arg"); + match arg { + Node::Num(n) => { + stack.pop(); + heap.update(*stack.last().unwrap(), Node::Num(-n)); + } + _ => { + dump.push(stack.clone()); + *stack = args; + } + } + } + _ => { + let arg1 = heap.lookup(args[0]).expect("cannot find arg1"); + let Node::Num(n) = arg1 else { + dump.push(stack.clone()); + *stack = args; + return; + }; + let arg2 = heap.lookup(args[0]).expect("cannot find arg2"); + let Node::Num(m) = arg2 else { + dump.push(stack.clone()); + *stack = args; + return; + }; + stack.pop(); + stack.pop(); + heap.update( + *stack.last().unwrap(), + Node::Num(primitives::arith(&prim, *n, *m)), + ); + } } } @@ -133,24 +222,18 @@ fn sc_step(state: &mut TiState, sc_addr: Addr, arg_names: Vec, body: } fn ti_final(state: &TiState) -> bool { - match state.0.len() { - 1 => state - .2 - .lookup(state.0[0]) - .map(is_data_node) + let (stack, dump, heap, _, _) = state; + dump.is_empty() + && match stack.len() { + 1 => heap + .lookup(stack[0]) + .map(Node::is_data_node) .unwrap_or(false), 0 => panic!("Empty stack!"), _ => false, } } -fn is_data_node(node: &Node) -> bool { - match node { - Node::Num(_) => true, - _ => false, - } -} - fn do_admin(state: &mut TiState) { let (stack, _, heap, _, stat) = state; stat.heap_alloc_count = heap.alloc_count(); diff --git a/src/compiler/primitives.rs b/src/compiler/primitives.rs new file mode 100644 index 0000000..e84add3 --- /dev/null +++ b/src/compiler/primitives.rs @@ -0,0 +1,26 @@ +#[derive(Debug, Clone, PartialEq)] +pub enum Primitive { + Neg, + Add, + Sub, + Mul, + Div, +} + +pub const PRIMITIVES: [(&'static str, Primitive); 5] = [ + ("negate", Primitive::Neg), + ("+", Primitive::Add), + ("-", Primitive::Sub), + ("*", Primitive::Mul), + ("/", Primitive::Div), +]; + +pub fn arith(p: &Primitive, a: i64, b: i64) -> i64 { + match p { + Primitive::Neg => panic!("not supported"), + Primitive::Add => a + b, + Primitive::Sub => a - b, + Primitive::Mul => a * b, + Primitive::Div => a / b, + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 4c68786..0d86614 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -96,14 +96,11 @@ impl Heap { i += 1; res.push(*arg) } - _ => continue, + _ => panic!("arg is not an Ap node"), } } if i < len { - panic!( - "not enough args: expected {}, got {}\nargs: {:?}", - len, i, res - ); + panic!("not enough args: expected {}, got {}", len, i); } res } @@ -141,6 +138,7 @@ impl Heap { root_addr: Addr, env: &mut ASSOC, ) { + // println!("env: {:?}", env); match body { CoreExpr::Var(a) => { let a_addr = env diff --git a/src/main.rs b/src/main.rs index 0b90ad0..20b3c76 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,10 @@ fn run(s: String) -> String { } fn main() { - let program = String::from("oct g x = let h = twice g in let k = twice h in k (k x); main = oct I 4"); + let program = String::from( + r#" + main = negate (I 3) + "#, + ); println!("{}", run(program)); } diff --git a/src/test.rs b/src/test.rs index bc45788..00be23b 100644 --- a/src/test.rs +++ b/src/test.rs @@ -130,3 +130,21 @@ fn letrec() { let res_stack = compiler::get_stack_results(res.last().expect("Empty states")); assert_eq!(res_stack, vec![compiler::Node::Num(4)]) } + +#[test] +fn negate() { + let res = compiler::eval(compiler::compile(lang::parse_raw(String::from( + "main = negate 3", + )))); + let res_stack = compiler::get_stack_results(res.last().expect("Empty states")); + assert_eq!(res_stack, vec![compiler::Node::Num(-3)]) +} + +#[test] +fn simple_arithmetic() { + let res = compiler::eval(compiler::compile(lang::parse_raw(String::from( + "main = 4*5+(2-5)", + )))); + let res_stack = compiler::get_stack_results(res.last().expect("Empty states")); + assert_eq!(res_stack, vec![compiler::Node::Num(17)]) +} \ No newline at end of file