Skip to content

Commit

Permalink
impl arithmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLee18 committed Feb 21, 2025
1 parent 9396340 commit 2949f07
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 27 deletions.
125 changes: 104 additions & 21 deletions src/compiler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use crate::compiler::primitives::Primitive;
use crate::core::{map_accuml, Addr, Heap, ASSOC};
use crate::lang;
use std::cmp::max;

pub mod primitives;

type TiState = (TiStack, TiDump, TiHeap, TiGlobals, TiStats);

type TiStack = Vec<Addr>;
type TiDump = ();
type TiDump = Vec<TiStack>;

type TiHeap = Heap<Node>;

Expand All @@ -15,6 +18,16 @@ pub(crate) enum Node {
SuperComb(lang::Name, Vec<lang::Name>, lang::CoreExpr),
Num(i64),
Ind(Addr),
Prim(lang::Name, Primitive),
}

impl Node {
pub fn is_data_node(&self) -> bool {
match self {
Node::Num(_) => true,
_ => false,
}
}
}
type TiGlobals = ASSOC<lang::Name, Addr>;

Expand All @@ -39,7 +52,7 @@ pub(crate) fn compile(p: lang::CoreProgram) -> TiState {
let alloc_count = init_heap.alloc_count();
(
vec![*main_addr],
(),
vec![],
init_heap,
globals,
TiStats {
Expand All @@ -58,8 +71,20 @@ const EXTRA_PRELUDE_DEFS: &'static str = "";

fn build_init_heap(sc_defs: Vec<lang::CoreScDefn>) -> (TiHeap, TiGlobals) {
let mut init_heap = Heap::new();
let globals = map_accuml(allocate_sc, &mut init_heap, sc_defs);
(init_heap, globals)
let mut sc_addrs = map_accuml(allocate_sc, &mut init_heap, sc_defs);
let mut prim_addrs = map_accuml(
allocate_prim,
&mut init_heap,
Vec::from(primitives::PRIMITIVES),
);
sc_addrs.append(&mut prim_addrs);
(init_heap, sc_addrs)
}

fn allocate_prim(heap: &mut TiHeap, primitive: (&'static str, Primitive)) -> (lang::Name, Addr) {
let (name, prim) = primitive;
let addr = heap.alloc(Node::Prim(lang::Name::from(name), prim));
(lang::Name::from(name), addr)
}

fn allocate_sc(heap: &mut TiHeap, sc_defs: lang::CoreScDefn) -> (lang::Name, Addr) {
Expand All @@ -84,23 +109,87 @@ pub(crate) fn eval(state: TiState) -> Vec<TiState> {
}

fn step(state: &mut TiState) {
let (stack, _, heap, _, _) = state;
let (stack, dump, heap, _, _) = state;
// println!("Stack: {:?}", stack);
// println!("{:?}", heap);

let last_stack = *stack.last().expect("Empty stack");

match heap.lookup(last_stack).expect("cannot be found on heap") {
Node::Ap(a1, _) => stack.push(*a1),
Node::SuperComb(sc, args, body) => {
let (_, args, body) = (sc.clone(), args.clone(), body.clone());
Node::Ap(a1, a2) => {
let a1 = *a1;
if let Some(Node::Ind(a3)) = heap.lookup(*a2) {
heap.update(last_stack, Node::Ap(a1, *a3));
}
stack.push(a1);
}
Node::SuperComb(_, args, body) => {
let (args, body) = (args.clone(), body.clone());
sc_step(state, last_stack, args, body)
}
Node::Num(_) => panic!("Number applied as a function"),
Node::Num(_) => {
assert_eq!(stack.len(), 1, "Number applied as a function");
*stack = dump.pop().expect("empty dump");
}
Node::Ind(r) => {
stack.pop();
stack.push(*r);
}
Node::Prim(_, p) => {
let p = p.clone();
prim_step(state, p)
}
}
}

fn prim_step(state: &mut TiState, prim: Primitive) {
let (stack, dump, heap, _, _) = state;
let args_len = match &prim {
Primitive::Neg => 1,
_ => 2,
};
assert_eq!(
stack.len(),
args_len + 1,
"args length mismatch: expected {}, got {}",
args_len + 1,
stack.len()
);
let args = heap.get_args(stack, args_len);
match &prim {
Primitive::Neg => {
let arg = heap.lookup(args[0]).expect("cannot find arg");
match arg {
Node::Num(n) => {
stack.pop();
heap.update(*stack.last().unwrap(), Node::Num(-n));
}
_ => {
dump.push(stack.clone());
*stack = args;
}
}
}
_ => {
let arg1 = heap.lookup(args[0]).expect("cannot find arg1");
let Node::Num(n) = arg1 else {
dump.push(stack.clone());
*stack = args;
return;
};
let arg2 = heap.lookup(args[0]).expect("cannot find arg2");
let Node::Num(m) = arg2 else {
dump.push(stack.clone());
*stack = args;
return;
};
stack.pop();
stack.pop();
heap.update(
*stack.last().unwrap(),
Node::Num(primitives::arith(&prim, *n, *m)),
);
}
}
}

Expand Down Expand Up @@ -133,24 +222,18 @@ fn sc_step(state: &mut TiState, sc_addr: Addr, arg_names: Vec<lang::Name>, body:
}

fn ti_final(state: &TiState) -> bool {
match state.0.len() {
1 => state
.2
.lookup(state.0[0])
.map(is_data_node)
let (stack, dump, heap, _, _) = state;
dump.is_empty()
&& match stack.len() {
1 => heap
.lookup(stack[0])
.map(Node::is_data_node)
.unwrap_or(false),
0 => panic!("Empty stack!"),
_ => false,
}
}

fn is_data_node(node: &Node) -> bool {
match node {
Node::Num(_) => true,
_ => false,
}
}

fn do_admin(state: &mut TiState) {
let (stack, _, heap, _, stat) = state;
stat.heap_alloc_count = heap.alloc_count();
Expand Down
26 changes: 26 additions & 0 deletions src/compiler/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#[derive(Debug, Clone, PartialEq)]
pub enum Primitive {
Neg,
Add,
Sub,
Mul,
Div,
}

pub const PRIMITIVES: [(&'static str, Primitive); 5] = [
("negate", Primitive::Neg),
("+", Primitive::Add),
("-", Primitive::Sub),
("*", Primitive::Mul),
("/", Primitive::Div),
];

pub fn arith(p: &Primitive, a: i64, b: i64) -> i64 {
match p {
Primitive::Neg => panic!("not supported"),
Primitive::Add => a + b,
Primitive::Sub => a - b,
Primitive::Mul => a * b,
Primitive::Div => a / b,
}
}
8 changes: 3 additions & 5 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,11 @@ impl Heap<Node> {
i += 1;
res.push(*arg)
}
_ => continue,
_ => panic!("arg is not an Ap node"),
}
}
if i < len {
panic!(
"not enough args: expected {}, got {}\nargs: {:?}",
len, i, res
);
panic!("not enough args: expected {}, got {}", len, i);
}
res
}
Expand Down Expand Up @@ -141,6 +138,7 @@ impl Heap<Node> {
root_addr: Addr,
env: &mut ASSOC<Name, Addr>,
) {
// println!("env: {:?}", env);
match body {
CoreExpr::Var(a) => {
let a_addr = env
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ fn run(s: String) -> String {
}

fn main() {
let program = String::from("oct g x = let h = twice g in let k = twice h in k (k x); main = oct I 4");
let program = String::from(
r#"
main = negate (I 3)
"#,
);
println!("{}", run(program));
}
18 changes: 18 additions & 0 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,21 @@ fn letrec() {
let res_stack = compiler::get_stack_results(res.last().expect("Empty states"));
assert_eq!(res_stack, vec![compiler::Node::Num(4)])
}

#[test]
fn negate() {
let res = compiler::eval(compiler::compile(lang::parse_raw(String::from(
"main = negate 3",
))));
let res_stack = compiler::get_stack_results(res.last().expect("Empty states"));
assert_eq!(res_stack, vec![compiler::Node::Num(-3)])
}

#[test]
fn simple_arithmetic() {
let res = compiler::eval(compiler::compile(lang::parse_raw(String::from(
"main = 4*5+(2-5)",
))));
let res_stack = compiler::get_stack_results(res.last().expect("Empty states"));
assert_eq!(res_stack, vec![compiler::Node::Num(17)])
}

0 comments on commit 2949f07

Please sign in to comment.