Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Parse HyperTransition syntax #295

Merged
merged 7 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/compiler/abepi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ impl<F: From<u64> + TryInto<u32> + Clone + Debug, V: Clone + Debug> CompilationU
Statement::Transition(dsym, id, stmt) => {
self.compiler_statement_transition(dsym, id, *stmt)
}
Statement::HyperTransition(dsym, ids, call, state) => {
self.compiler_statement_hyper_transition(dsym, ids, call, state)
}
_ => vec![],
}
}
Expand Down Expand Up @@ -420,6 +423,16 @@ impl<F: From<u64> + TryInto<u32> + Clone + Debug, V: Clone + Debug> CompilationU

result
}

fn compiler_statement_hyper_transition(
&self,
_dsym: DebugSymRef,
_ids: Vec<V>,
_call: Expression<F, V>,
_state: V,
) -> Vec<CompilationResult<F, V>> {
todo!("Compile expressions? Needs specs")
}
}

fn flatten_bin_op<F: Clone, V: Clone>(
Expand Down
86 changes: 45 additions & 41 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,57 +558,22 @@ mod test {

use crate::{
compiler::{compile, compile_file, compile_legacy},
parser::ast::debug_sym_factory::DebugSymRefFactory,
parser::{ast::debug_sym_factory::DebugSymRefFactory, lang::TLDeclsParser},
wit_gen::TraceGenerator,
};

use super::Config;

// TODO rewrite the test after machines are able to call other machines
// TODO improve the test for HyperTransition
#[test]
fn test_compiler_fibo_multiple_machines() {
// Source code containing two machines
let circuit = "
machine fibo1 (signal n) (signal b: field) {
// n and be are created automatically as shared
// signals
signal a: field, i;

// there is always a state called initial
// input signals get bound to the signal
// in the initial state (first instance)
state initial {
signal c;

i, a, b, c <== 1, 1, 1, 2;

-> middle {
i', a', b', n' <== i + 1, b, c, n;
}
}

state middle {
signal c;

c <== a + b;

if i + 1 == n {
-> final {
i', b', n' <== i + 1, c, n;
}
} else {
-> middle {
i', a', b', n' <== i + 1, b, c, n;
}
}
}

// There is always a state called final.
// Output signals get automatically bound to the signals
// with the same name in the final step (last instance).
// This state can be implicit if there are no constraints in it.
machine caller (signal n) (signal b: field) {
signal b_1: field;
b_1' <== fibo(n) -> final;
}
machine fibo2 (signal n) (signal b: field) {
machine fibo (signal n) (signal b: field) {
// n and be are created automatically as shared
// signals
signal a: field, i;
Expand Down Expand Up @@ -839,4 +804,43 @@ mod test {
r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"#
)
}

#[test]
fn test_parse_hyper_transition() {
let circuit = "
machine caller (signal n) (signal b: field) {
a', b, c' <== fibo(d, e, f + g) -> final;
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit);

assert!(result.is_ok());

let circuit = "
machine caller (signal n) (signal b: field) {
-> final {
a', b, c' <== fibo(d, e, f + g);
}
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit);

assert!(result.is_ok());

// TODO should no-arg calls be allowed? Needs more specs for function/machine calls
let circuit = "
machine caller (signal n) (signal b: field) {
smth <== a() -> final;
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit);

assert!(result.is_ok());
}
}
11 changes: 11 additions & 0 deletions src/compiler/semantic/analyser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ impl Analyser {

Statement::SignalDecl(_, _) => {}
Statement::WGVarDecl(_, _) => {}
Statement::HyperTransition(_, ids, call, state) => {
self.analyse_expression(call);
self.collect_id_usages(&[state]);
self.collect_id_usages(&ids);
}
}
}

Expand Down Expand Up @@ -308,6 +313,12 @@ impl Analyser {
} => {
self.extract_usages_expression(&sub);
}
Expression::Call(_, fun, exprs) => {
self.collect_id_usages(&[fun]);
exprs
.into_iter()
.for_each(|expr| self.extract_usages_expression(&expr));
}
_ => {}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/compiler/semantic/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ fn undeclared_rule(analyser: &mut Analyser, expr: &Expression<BigInt, Identifier
undeclared_rule(analyser, when_false);
}
Expression::Const(_, _) | Expression::True(_) | Expression::False(_) => {}
Expression::Call(_, _, args) => {
args.iter().for_each(|arg| undeclared_rule(analyser, arg));
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/compiler/setup_inter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ impl SetupInterpreter {

SignalAssignment(_, _, _) | WGAssignment(_, _, _) => vec![],
SignalDecl(_, _) | WGVarDecl(_, _) => vec![],
HyperTransition(_, _, _, _) => todo!("Implement compilation for hyper transitions"),
};

self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect());
Expand Down
3 changes: 3 additions & 0 deletions src/interpreter/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub(crate) fn eval_expr<F: Field + Hash, V: Identifiable>(
Const(_, v) => Ok(Value::Field(F::from_big_int(v))),
True(_) => Ok(Value::Bool(true)),
False(_) => Ok(Value::Bool(false)),
Call(_, _, _) => {
todo!("Needs specs. Evaluate the argument expressions, evaluate the function output?")
}
}
.map_err(|msg| Message::RuntimeErr {
msg,
Expand Down
1 change: 1 addition & 0 deletions src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> {
Block(_, stmts) => self.exec_step_block(stmts),
Assert(_, _) => Ok(None),
StateDecl(_, _, _) => Ok(None),
HyperTransition(_, _, _, _) => todo!("Needs specs"),
}
}

Expand Down
15 changes: 15 additions & 0 deletions src/parser/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ pub enum Expression<F, V> {
Const(DebugSymRef, F),
True(DebugSymRef),
False(DebugSymRef),
/// Function or machine call.
/// Tuple values:
/// - debug symbol reference;
/// - function/machine ID;
/// - call argument expressions vector.
Call(DebugSymRef, V, Vec<Expression<F, V>>),
}

// Shorthand for BigInt expression
Expand All @@ -217,6 +223,9 @@ impl<F, V> Expression<F, V> {
Const(_, _) => true,
True(_) => false,
False(_) => false,
Call(_, _, _) => {
todo!("Needs specs. For a function call, depends on the function return type?")
}
}
}

Expand All @@ -234,6 +243,9 @@ impl<F, V> Expression<F, V> {

when_true.is_logic()
}
Expression::Call { .. } => {
todo!("Needs specs. For a function call, depends on the function return type?")
}
_ => false,
}
}
Expand All @@ -247,6 +259,7 @@ impl<F, V> Expression<F, V> {
Expression::Const(dsym, _) => dsym,
Expression::True(dsym) => dsym,
Expression::False(dsym) => dsym,
Expression::Call(dsym, _, _) => dsym,
}
}

Expand All @@ -260,6 +273,7 @@ impl<F, V> Expression<F, V> {
Expression::Query(_, _) => false,
Expression::True(_) => false,
Expression::False(_) => false,
Expression::Call(_, _, _) => false,
}
}
}
Expand Down Expand Up @@ -315,6 +329,7 @@ impl<F: Debug, V: Debug> Debug for Expression<F, V> {

Expression::True(_) => write!(f, "true"),
Expression::False(_) => write!(f, "false"),
Expression::Call(_, fun, exprs) => write!(f, "{:?}({:?})", fun, exprs),
}
}
}
61 changes: 44 additions & 17 deletions src/parser/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,42 @@ pub struct TypedIdDecl<V> {

#[derive(Clone)]
pub enum Statement<F, V> {
Assert(DebugSymRef, Expression<F, V>), // assert x;

SignalAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x <-- y;
SignalAssignmentAssert(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x <== y;
WGAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x = y;

IfThen(DebugSymRef, Box<Expression<F, V>>, Box<Statement<F, V>>), // if x { y }
/// assert x;
Assert(DebugSymRef, Expression<F, V>),
/// x <-- y;
SignalAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// x <== y;
SignalAssignmentAssert(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// x = y;
WGAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// if x { y }
IfThen(DebugSymRef, Box<Expression<F, V>>, Box<Statement<F, V>>),
/// if x { y } else { z }
IfThenElse(
DebugSymRef,
Box<Expression<F, V>>,
Box<Statement<F, V>>,
Box<Statement<F, V>>,
), // if x { y } else { z }

SignalDecl(DebugSymRef, Vec<TypedIdDecl<V>>), // signal x;
WGVarDecl(DebugSymRef, Vec<TypedIdDecl<V>>), // var x;

StateDecl(DebugSymRef, V, Box<Statement<F, V>>), // state x { y }

Transition(DebugSymRef, V, Box<Statement<F, V>>), // -> x { y }

Block(DebugSymRef, Vec<Statement<F, V>>), // { x }
),
/// signal x;
SignalDecl(DebugSymRef, Vec<TypedIdDecl<V>>),
/// var x;
WGVarDecl(DebugSymRef, Vec<TypedIdDecl<V>>),
/// state x { y }
StateDecl(DebugSymRef, V, Box<Statement<F, V>>),
/// Transition to another state.
/// -> x { y }
Transition(DebugSymRef, V, Box<Statement<F, V>>),
/// { x }
Block(DebugSymRef, Vec<Statement<F, V>>),
/// Call into another machine with assertion and subsequent transition to another
/// state.
/// Tuple values:
/// - debug symbol reference;
/// - assigned signal IDs;
/// - call expression;
/// - next state ID;
HyperTransition(DebugSymRef, Vec<V>, Expression<F, V>, V),
}

impl<F: Debug> Debug for Statement<F, Identifier> {
Expand Down Expand Up @@ -84,6 +98,18 @@ impl<F: Debug> Debug for Statement<F, Identifier> {
.join(" ")
)
}
Statement::HyperTransition(_, ids, call, state) => {
write!(
f,
"{:?} <== {:?} -> {:?};",
ids.iter()
.map(|id| id.name())
.collect::<Vec<_>>()
.join(", "),
call,
state
)
}
}
}
}
Expand All @@ -102,6 +128,7 @@ impl<F, V> Statement<F, V> {
Statement::StateDecl(dsym, _, _) => dsym.clone(),
Statement::Transition(dsym, _, _) => dsym.clone(),
Statement::Block(dsym, _) => dsym.clone(),
Statement::HyperTransition(dsym, _, _, _) => dsym.clone(),
}
}
}
29 changes: 8 additions & 21 deletions src/parser/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use num_bigint::BigInt;

use super::ast::{expression::Expression, statement::Statement, DebugSymRef, Identifier};

pub fn build_bin_op<S: Into<String>, F, V>(
Expand Down Expand Up @@ -64,25 +62,14 @@ pub fn build_transition<F>(
Statement::Transition(dsym, id, Box::new(block))
}

pub fn add_dsym(
pub fn build_hyper_transition<F: Clone>(
dsym: DebugSymRef,
stmt: Statement<BigInt, Identifier>,
) -> Statement<BigInt, Identifier> {
match stmt {
Statement::Assert(_, expr) => Statement::Assert(dsym, expr),
Statement::SignalAssignment(_, ids, exprs) => Statement::SignalAssignment(dsym, ids, exprs),
Statement::SignalAssignmentAssert(_, ids, exprs) => {
Statement::SignalAssignmentAssert(dsym, ids, exprs)
}
Statement::WGAssignment(_, ids, exprs) => Statement::WGAssignment(dsym, ids, exprs),
Statement::StateDecl(_, id, block) => Statement::StateDecl(dsym, id, block),
Statement::IfThen(_, cond, then_block) => Statement::IfThen(dsym, cond, then_block),
Statement::IfThenElse(_, cond, then_block, else_block) => {
Statement::IfThenElse(dsym, cond, then_block, else_block)
}
Statement::Block(_, stmts) => Statement::Block(dsym, stmts),
Statement::SignalDecl(_, ids) => Statement::SignalDecl(dsym, ids),
Statement::WGVarDecl(_, ids) => Statement::WGVarDecl(dsym, ids),
Statement::Transition(_, id, stmt) => Statement::Transition(dsym, id, stmt),
ids: Vec<Identifier>,
call: Expression<F, Identifier>,
state: Identifier,
) -> Statement<F, Identifier> {
match call {
Expression::Call(_, _, _) => Statement::HyperTransition(dsym, ids, call, state),
_ => unreachable!("Hyper transition must include a call statement"),
}
}
Loading
Loading