diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 0d3d27bdc..979a5ba9e 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -2,7 +2,9 @@ use crate::{ extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc}, hugr::{IdentList, NodeMetadataMap}, - ops::{constant::CustomSerialized, DataflowBlock, OpName, OpTrait, OpType, Value}, + ops::{ + constant::CustomSerialized, DataflowBlock, DataflowOpTrait, OpName, OpTrait, OpType, Value, + }, std_extensions::{ arithmetic::{float_types::ConstF64, int_types::ConstInt}, collections::array::ArrayValue, @@ -21,9 +23,6 @@ use hugr_model::v0::{self as model}; use petgraph::unionfind::UnionFind; use std::fmt::Write; -pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; -const TERM_PARAM_TUPLE: &str = "param.tuple"; - /// Export a [`Hugr`] graph to its representation in the model. pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> { let mut ctx = Context::new(hugr, bump); @@ -313,55 +312,46 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, constraints, signature) = this.export_poly_func_type(&func.signature); - let decl = this.bump.alloc(model::FuncDecl { - name, - params, - constraints, - signature, - }); + let symbol = this.export_poly_func_type(name, &func.signature); let extensions = this.export_ext_set(&func.signature.body().runtime_reqs); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, extensions, model::ScopeClosure::Closed, )]); - model::Operation::DefineFunc { decl } + model::Operation::DefineFunc(symbol) }), OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, constraints, func) = this.export_poly_func_type(&func.signature); - let decl = this.bump.alloc(model::FuncDecl { - name, - params, - constraints, - signature: func, - }); - model::Operation::DeclareFunc { decl } + let symbol = this.export_poly_func_type(name, &func.signature); + model::Operation::DeclareFunc(symbol) }), OpType::AliasDecl(alias) => self.with_local_scope(node_id, |this| { // TODO: We should support aliases with different types and with parameters - let r#type = this.make_term(model::Term::Type); - let decl = this.bump.alloc(model::AliasDecl { + let signature = this.make_term_apply(model::CORE_TYPE, &[]); + let symbol = this.bump.alloc(model::Symbol { name: &alias.name, params: &[], - r#type, + constraints: &[], + signature, }); - model::Operation::DeclareAlias { decl } + model::Operation::DeclareAlias(symbol) }), OpType::AliasDefn(alias) => self.with_local_scope(node_id, |this| { let value = this.export_type(&alias.definition); // TODO: We should support aliases with different types and with parameters - let r#type = this.make_term(model::Term::Type); - let decl = this.bump.alloc(model::AliasDecl { + let signature = this.make_term_apply(model::CORE_TYPE, &[]); + let symbol = this.bump.alloc(model::Symbol { name: &alias.name, params: &[], - r#type, + constraints: &[], + signature, }); - model::Operation::DefineAlias { decl, value } + params = self.bump.alloc_slice_copy(&[value]); + model::Operation::DefineAlias(symbol) }), OpType::Call(call) => { @@ -371,22 +361,29 @@ impl<'a> Context<'a> { let mut args = BumpVec::new_in(self.bump); args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); + let func = self.make_term(model::Term::Apply(symbol, args)); + + // TODO PERFORMANCE: Avoid exporting the signature here again. + let signature = call.signature(); + let inputs = self.export_type_row(&signature.input); + let outputs = self.export_type_row(&signature.output); + let ext = self.export_ext_set(&signature.runtime_reqs); - let func = self.make_term(model::Term::ApplyFull { symbol, args }); - model::Operation::CallFunc { func } + params = self.bump.alloc_slice_copy(&[inputs, outputs, ext, func]); + model::Operation::Custom(self.resolve_symbol(model::CORE_CALL_INDIRECT)) } OpType::LoadFunction(load) => { - // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; - let mut args = BumpVec::new_in(self.bump); args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - - let func = self.make_term(model::Term::ApplyFull { symbol, args }); - model::Operation::LoadFunc { func } + let func = self.make_term(model::Term::Apply(symbol, args)); + let runtime_type = self.make_term(model::Term::Wildcard); + let ext = self.make_term(model::Term::Wildcard); + params = self.bump.alloc_slice_copy(&[runtime_type, ext, func]); + model::Operation::Custom(self.resolve_symbol(model::CORE_LOAD_CONST)) } OpType::Const(_) => { @@ -404,15 +401,28 @@ impl<'a> Context<'a> { // TODO: Share the constant value between all nodes that load it. + let runtime_type = self.make_term(model::Term::Wildcard); + let ext = self.make_term(model::Term::Wildcard); let value = self.export_value(&const_node_data.value); - model::Operation::Const { value } + params = self.bump.alloc_slice_copy(&[runtime_type, ext, value]); + model::Operation::Custom(self.resolve_symbol(model::CORE_LOAD_CONST)) } - OpType::CallIndirect(_) => model::Operation::CustomFull { - operation: self.resolve_symbol(OP_FUNC_CALL_INDIRECT), - }, + OpType::CallIndirect(call) => { + let inputs = self.export_type_row(&call.signature.input); + let outputs = self.export_type_row(&call.signature.output); + let ext = self.export_ext_set(&call.signature.runtime_reqs); + params = self.bump.alloc_slice_copy(&[inputs, outputs, ext]); + model::Operation::Custom(self.resolve_symbol(model::CORE_CALL_INDIRECT)) + } - OpType::Tag(tag) => model::Operation::Tag { tag: tag.tag as _ }, + OpType::Tag(tag) => { + let variants = self.make_term(model::Term::Wildcard); + let types = self.make_term(model::Term::Wildcard); + let tag = self.make_term(model::Term::Nat(tag.tag as u64)); + params = self.bump.alloc_slice_copy(&[variants, types, tag]); + model::Operation::Custom(self.resolve_symbol(model::CORE_MAKE_ADT)) + } OpType::TailLoop(tail_loop) => { let extensions = self.export_ext_set(&tail_loop.extension_delta); @@ -439,7 +449,7 @@ impl<'a> Context<'a> { .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - model::Operation::CustomFull { operation } + model::Operation::Custom(operation) } OpType::OpaqueOp(op) => { @@ -449,7 +459,7 @@ impl<'a> Context<'a> { .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - model::Operation::CustomFull { operation } + model::Operation::Custom(operation) } }; @@ -525,16 +535,9 @@ impl<'a> Context<'a> { } }; - let decl = self.with_local_scope(node, |this| { + let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); - let decl = this.bump.alloc(model::OperationDecl { - name, - params, - constraints, - r#type, - }); - decl + this.export_poly_func_type(name, poly_func_type) }); let meta = { @@ -555,7 +558,7 @@ impl<'a> Context<'a> { }; let node_data = self.module.get_node_mut(node).unwrap(); - node_data.operation = model::Operation::DeclareOperation { decl }; + node_data.operation = model::Operation::DeclareOperation(symbol); node_data.meta = meta; node @@ -566,10 +569,10 @@ impl<'a> Context<'a> { pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId { let inputs = { let inputs = self.export_type_row(&block.inputs); - let inputs = self.make_term(model::Term::Control { values: inputs }); - self.make_term(model::Term::List { - parts: self.bump.alloc_slice_copy(&[model::ListPart::Item(inputs)]), - }) + let inputs = self.make_term_apply(model::CORE_CTRL, &[inputs]); + self.make_term(model::Term::List( + self.bump.alloc_slice_copy(&[model::ListPart::Item(inputs)]), + )) }; let tail = self.export_type_row(&block.other_outputs); @@ -578,20 +581,14 @@ impl<'a> Context<'a> { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in block.sum_rows.iter() { let variant = self.export_type_row_with_tail(sum_row, Some(tail)); - let control = self.make_term(model::Term::Control { values: variant }); + let control = self.make_term_apply(model::CORE_CTRL, &[variant]); outputs.push(model::ListPart::Item(control)); } - self.make_term(model::Term::List { - parts: outputs.into_bump_slice(), - }) + self.make_term(model::Term::List(outputs.into_bump_slice())) }; let extensions = self.export_ext_set(&block.extension_delta); - self.make_term(model::Term::FuncType { - inputs, - outputs, - extensions, - }) + self.make_term_apply(model::CORE_FN, &[inputs, outputs, extensions]) } /// Creates a data flow region from the given node's children. @@ -643,12 +640,7 @@ impl<'a> Context<'a> { let signature = { let inputs = self.export_type_row(input_types.unwrap()); let outputs = self.export_type_row(output_types.unwrap()); - - Some(self.make_term(model::Term::FuncType { - inputs, - outputs, - extensions, - })) + Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs, extensions])) }; let scope = match closure { @@ -715,25 +707,17 @@ impl<'a> Context<'a> { let mut wrap_ctrl = |types: &TypeRow| { let types = self.export_type_row(types); - let types_ctrl = self.make_term(model::Term::Control { values: types }); - self.make_term(model::Term::List { - parts: self - .bump + let types_ctrl = self.make_term_apply(model::CORE_CTRL, &[types]); + self.make_term(model::Term::List( + self.bump .alloc_slice_copy(&[model::ListPart::Item(types_ctrl)]), - }) + )) }; let inputs = wrap_ctrl(node_signature.input()); let outputs = wrap_ctrl(node_signature.output()); let extensions = self.export_ext_set(&node_signature.runtime_reqs); - - let func_type = self.make_term(model::Term::FuncType { - inputs, - outputs, - extensions, - }); - - Some(func_type) + Some(self.make_term_apply(model::CORE_FN, &[inputs, outputs, extensions])) }; let scope = match closure { @@ -776,15 +760,11 @@ impl<'a> Context<'a> { } /// Exports a polymorphic function type. - /// - /// The returned triple consists of: - /// - The static parameters of the polymorphic function type. - /// - The constraints of the polymorphic function type. - /// - The function type itself. pub fn export_poly_func_type( &mut self, + name: &'a str, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { + ) -> &'a model::Symbol<'a> { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); let scope = self .local_scope @@ -793,18 +773,19 @@ impl<'a> Context<'a> { for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); let r#type = self.export_type_param(param, Some((scope, i as _))); - let param = model::Param { - name, - r#type, - sort: model::ParamSort::Implicit, - }; + let param = model::Param { name, r#type }; params.push(param) } let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), constraints, body) + self.bump.alloc(model::Symbol { + name, + params: params.into_bump_slice(), + constraints, + signature: body, + }) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -815,12 +796,8 @@ impl<'a> Context<'a> { match t { TypeEnum::Extension(ext) => self.export_custom_type(ext), TypeEnum::Alias(alias) => { - let global = self.resolve_symbol(self.bump.alloc_str(alias.name())); - let args = &[]; - self.make_term(model::Term::ApplyFull { - symbol: global, - args, - }) + let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); + self.make_term(model::Term::Apply(symbol, &[])) } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { @@ -836,11 +813,7 @@ impl<'a> Context<'a> { let inputs = self.export_type_row(t.input()); let outputs = self.export_type_row(t.output()); let extensions = self.export_ext_set(&t.runtime_reqs); - self.make_term(model::Term::FuncType { - inputs, - outputs, - extensions, - }) + self.make_term_apply(model::CORE_FN, &[inputs, outputs, extensions]) } pub fn export_custom_type(&mut self, t: &CustomType) -> model::TermId { @@ -849,7 +822,7 @@ impl<'a> Context<'a> { let args = self .bump .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); - let term = model::Term::ApplyFull { symbol, args }; + let term = model::Term::Apply(symbol, args); self.make_term(term) } @@ -865,7 +838,7 @@ impl<'a> Context<'a> { .iter() .map(|elem| model::ListPart::Item(self.export_type_arg(elem))), ); - self.make_term(model::Term::List { parts }) + self.make_term(model::Term::List(parts)) } TypeArg::Extensions { es } => self.export_ext_set(es), TypeArg::Variable { v } => self.export_type_arg_var(v), @@ -882,27 +855,31 @@ impl<'a> Context<'a> { self.make_term(model::Term::Var(model::VarId(node, t.0 as _))) } - pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { + pub fn export_sum_variants(&mut self, t: &SumType) -> model::TermId { match t { SumType::Unit { size } => { - let parts = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { - model::ListPart::Item(self.make_term(model::Term::List { parts: &[] })) - })); - let variants = self.make_term(model::Term::List { parts }); - self.make_term(model::Term::Adt { variants }) + let parts = + self.bump + .alloc_slice_fill_iter((0..*size).map(|_| { + model::ListPart::Item(self.make_term(model::Term::List(&[]))) + })); + self.make_term(model::Term::List(parts)) } SumType::General { rows } => { let parts = self.bump.alloc_slice_fill_iter( rows.iter() .map(|row| model::ListPart::Item(self.export_type_row(row))), ); - let list = model::Term::List { parts }; - let variants = { self.make_term(list) }; - self.make_term(model::Term::Adt { variants }) + self.make_term(model::Term::List(parts)) } } } + pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { + let variants = self.export_sum_variants(t); + self.make_term_apply(model::CORE_ADT, &[variants]) + } + #[inline] pub fn export_type_row(&mut self, row: &TypeRowBase) -> model::TermId { self.export_type_row_with_tail(row, None) @@ -931,7 +908,7 @@ impl<'a> Context<'a> { } let parts = parts.into_bump_slice(); - self.make_term(model::Term::List { parts }) + self.make_term(model::Term::List(parts)) } /// Exports a `TypeParam` to a term. @@ -949,18 +926,18 @@ impl<'a> Context<'a> { TypeParam::Type { b } => { if let (Some((node, index)), TypeBound::Copyable) = (var, b) { let term = self.make_term(model::Term::Var(model::VarId(node, index))); - let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]); self.local_constraints.push(non_linear); } - self.make_term(model::Term::Type) + self.make_term_apply(model::CORE_TYPE, &[]) } // This ignores the bound on the natural for now. - TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), - TypeParam::String => self.make_term(model::Term::StrType), + TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]), + TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]), TypeParam::List { param } => { let item_type = self.export_type_param(param, None); - self.make_term(model::Term::ListType { item_type }) + self.make_term_apply(model::CORE_LIST_TYPE, &[item_type]) } TypeParam::Tuple { params } => { let parts = self.bump.alloc_slice_fill_iter( @@ -968,17 +945,10 @@ impl<'a> Context<'a> { .iter() .map(|param| model::ListPart::Item(self.export_type_param(param, None))), ); - let types = self.make_term(model::Term::List { parts }); - let symbol = self.resolve_symbol(TERM_PARAM_TUPLE); - self.make_term(model::Term::ApplyFull { - symbol, - args: self.bump.alloc_slice_copy(&[types]), - }) - } - TypeParam::Extensions => { - let term = model::Term::ExtSetType; - self.make_term(term) + let types = self.make_term(model::Term::List(parts)); + self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } + TypeParam::Extensions => self.make_term_apply(model::CORE_EXT_SET, &[]), } } @@ -998,9 +968,7 @@ impl<'a> Context<'a> { } } - self.make_term(model::Term::ExtSet { - parts: parts.into_bump_slice(), - }) + self.make_term(model::Term::ExtSet(parts.into_bump_slice())) } fn export_value(&mut self, value: &'a Value) -> model::TermId { @@ -1019,13 +987,11 @@ impl<'a> Context<'a> { contents.push(model::ListPart::Item(self.export_value(element))); } - let contents = self.make_term(model::Term::List { - parts: contents.into_bump_slice(), - }); + let contents = self.make_term(model::Term::List(contents.into_bump_slice())); let symbol = self.resolve_symbol(ArrayValue::CTR_NAME); let args = self.bump.alloc_slice_copy(&[len, element_type, contents]); - return self.make_term(model::Term::ApplyFull { symbol, args }); + return self.make_term(model::Term::Apply(symbol, args)); } if let Some(v) = e.value().downcast_ref::() { @@ -1034,16 +1000,14 @@ impl<'a> Context<'a> { let symbol = self.resolve_symbol(ConstInt::CTR_NAME); let args = self.bump.alloc_slice_copy(&[bitwidth, literal]); - return self.make_term(model::Term::ApplyFull { symbol, args }); + return self.make_term(model::Term::Apply(symbol, args)); } if let Some(v) = e.value().downcast_ref::() { - let literal = self.make_term(model::Term::Float { - value: v.value().into(), - }); + let literal = self.make_term(model::Term::Float(v.value().into())); let symbol = self.resolve_symbol(ConstF64::CTR_NAME); let args = self.bump.alloc_slice_copy(&[literal]); - return self.make_term(model::Term::ApplyFull { symbol, args }); + return self.make_term(model::Term::Apply(symbol, args)); } let json = match e.value().downcast_ref::() { @@ -1057,9 +1021,9 @@ impl<'a> Context<'a> { let extensions = self.export_ext_set(&e.extension_reqs()); let args = self .bump - .alloc_slice_copy(&[runtime_type, json, extensions]); + .alloc_slice_copy(&[runtime_type, extensions, json]); let symbol = self.resolve_symbol(model::COMPAT_CONST_JSON); - self.make_term(model::Term::ApplyFull { symbol, args }) + self.make_term(model::Term::Apply(symbol, args)) } Value::Function { hugr } => { @@ -1077,22 +1041,26 @@ impl<'a> Context<'a> { self.node_to_id = outer_node_to_id; self.hugr = outer_hugr; - self.make_term(model::Term::ConstFunc { region }) + self.make_term(model::Term::ConstFunc(region)) } Value::Sum(sum) => { - let tag = sum.tag as _; - let mut values = BumpVec::with_capacity_in(sum.values.len(), self.bump); + let variants = self.export_sum_variants(&sum.sum_type); + let ext = self.make_term(model::Term::Wildcard); + let types = self.make_term(model::Term::Wildcard); + let tag = self.make_term(model::Term::Nat(sum.tag as u64)); - for value in &sum.values { - values.push(model::ListPart::Item(self.export_value(value))); - } + let values = { + let mut values = BumpVec::with_capacity_in(sum.values.len(), self.bump); - let values = self.make_term(model::Term::List { - parts: values.into_bump_slice(), - }); + for value in &sum.values { + values.push(model::TuplePart::Item(self.export_value(value))); + } + + self.make_term(model::Term::Tuple(values.into_bump_slice())) + }; - self.make_term(model::Term::ConstAdt { tag, values }) + self.make_term_apply(model::CORE_CONST_ADT, &[variants, ext, types, tag, values]) } } } @@ -1131,7 +1099,7 @@ impl<'a> Context<'a> { fn make_term_apply(&mut self, name: &'a str, args: &[model::TermId]) -> model::TermId { let symbol = self.resolve_symbol(name); let args = self.bump.alloc_slice_copy(args); - self.make_term(model::Term::ApplyFull { symbol, args }) + self.make_term(model::Term::Apply(symbol, args)) } } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index c1b7d1cff..a87acf5ac 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use crate::{ - export::OP_FUNC_CALL_INDIRECT, extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, hugr::{HugrMut, IdentList}, ops::{ @@ -280,13 +279,13 @@ impl<'a> Context<'a> { &mut self, func_node: model::NodeId, ) -> Result { - let decl = match self.get_node(func_node)?.operation { - model::Operation::DefineFunc { decl } => decl, - model::Operation::DeclareFunc { decl } => decl, + let symbol = match self.get_node(func_node)?.operation { + model::Operation::DefineFunc(symbol) => symbol, + model::Operation::DeclareFunc(symbol) => symbol, _ => return Err(model::ModelError::UnexpectedOperation(func_node).into()), }; - self.import_poly_func_type(func_node, *decl, |_, signature| Ok(signature)) + self.import_poly_func_type(func_node, *symbol, |_, signature| Ok(signature)) } /// Import the root region of the module. @@ -341,10 +340,10 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::DefineFunc { decl } => { - self.import_poly_func_type(node_id, *decl, |ctx, signature| { + model::Operation::DefineFunc(symbol) => { + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { - name: decl.name.to_string(), + name: symbol.name.to_string(), signature, }); @@ -360,10 +359,10 @@ impl<'a> Context<'a> { }) } - model::Operation::DeclareFunc { decl } => { - self.import_poly_func_type(node_id, *decl, |ctx, signature| { + model::Operation::DeclareFunc(symbol) => { + self.import_poly_func_type(node_id, *symbol, |ctx, signature| { let optype = OpType::FuncDecl(FuncDecl { - name: decl.name.to_string(), + name: symbol.name.to_string(), signature, }); @@ -373,45 +372,6 @@ impl<'a> Context<'a> { }) } - model::Operation::CallFunc { func } => { - let model::Term::ApplyFull { symbol, args } = self.get_term(func)? else { - return Err(model::ModelError::TypeError(func).into()); - }; - - let func_sig = self.get_func_signature(*symbol)?; - - let type_args = args - .iter() - .map(|term| self.import_type_arg(*term)) - .collect::, _>>()?; - - self.static_edges.push((*symbol, node_id)); - let optype = OpType::Call(Call::try_new(func_sig, type_args)?); - - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } - - model::Operation::LoadFunc { func } => { - let model::Term::ApplyFull { symbol, args } = self.get_term(func)? else { - return Err(model::ModelError::TypeError(func).into()); - }; - - let func_sig = self.get_func_signature(*symbol)?; - - let type_args = args - .iter() - .map(|term| self.import_type_arg(*term)) - .collect::, _>>()?; - - self.static_edges.push((*symbol, node_id)); - - let optype = OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); - - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } - model::Operation::TailLoop => { let node = self.import_tail_loop(node_id, parent)?; Ok(Some(node)) @@ -421,16 +381,125 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::CustomFull { operation } => { + model::Operation::Custom(operation) => { let name = self.get_symbol_name(operation)?; - if name == OP_FUNC_CALL_INDIRECT { + if name == model::CORE_CALL_INDIRECT { let signature = self.get_node_signature(node_id)?; let optype = OpType::CallIndirect(CallIndirect { signature }); let node = self.make_node(node_id, optype, parent)?; return Ok(Some(node)); } + if name == model::CORE_CALL { + let &[_, _, _, func] = node_data.params else { + return Err(model::ModelError::InvalidOperation(node_id).into()); + }; + + let model::Term::Apply(symbol, args) = self.get_term(func)? else { + return Err(model::ModelError::TypeError(func).into()); + }; + + let func_sig = self.get_func_signature(*symbol)?; + + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((*symbol, node_id)); + let optype = OpType::Call(Call::try_new(func_sig, type_args)?); + + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + + if name == model::CORE_LOAD_CONST { + let &[_, _, value] = node_data.params else { + return Err(model::ModelError::InvalidOperation(node_id).into()); + }; + + // If the constant refers directly to a function, import this as the `LoadFunc` operation. + if let model::Term::Apply(symbol, args) = self.get_term(value)? { + let func_node_data = self + .module + .get_node(*symbol) + .ok_or(model::ModelError::NodeNotFound(*symbol))?; + + if let model::Operation::DefineFunc(_) | model::Operation::DeclareFunc(_) = + func_node_data.operation + { + let func_sig = self.get_func_signature(*symbol)?; + let type_args = args + .iter() + .map(|term| self.import_type_arg(*term)) + .collect::, _>>()?; + + self.static_edges.push((*symbol, node_id)); + + let optype = + OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); + + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + } + + // Otherwise use const nodes + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs, _] = self.get_func_type(signature)?; + let outputs = self.import_closed_list(outputs)?; + let output = outputs + .first() + .ok_or(model::ModelError::TypeError(signature))?; + let datatype = self.import_type(*output)?; + + let imported_value = self.import_value(value, *output)?; + + let load_const_node = self.make_node( + node_id, + OpType::LoadConstant(LoadConstant { + datatype: datatype.clone(), + }), + parent, + )?; + + let const_node = self + .hugr + .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); + + self.hugr.connect(const_node, 0, load_const_node, 0); + + return Ok(Some(load_const_node)); + } + + if name == model::CORE_MAKE_ADT { + let &[_, _, tag] = node_data.params else { + return Err(model::ModelError::InvalidOperation(node_id).into()); + }; + + let model::Term::Nat(tag) = self.get_term(tag)? else { + return Err(model::ModelError::TypeError(tag).into()); + }; + + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let [_, outputs, _] = self.get_func_type(signature)?; + let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; + let node = self.make_node( + node_id, + OpType::Tag(Tag { + variants, + tag: *tag as usize, + }), + parent, + )?; + return Ok(Some(node)); + } + let signature = self.get_node_signature(node_id)?; let args = node_data .params @@ -458,19 +527,19 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::Custom { .. } => Err(error_unsupported!( - "custom operation with implicit parameters" - )), - - model::Operation::DefineAlias { decl, value } => { - if !decl.params.is_empty() { + model::Operation::DefineAlias(symbol) => { + if !symbol.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias definition" )); } + let &[value] = node_data.params else { + return Err(model::ModelError::InvalidOperation(node_id).into()); + }; + let optype = OpType::AliasDefn(AliasDefn { - name: decl.name.to_smolstr(), + name: symbol.name.to_smolstr(), definition: self.import_type(value)?, }); @@ -478,15 +547,15 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::DeclareAlias { decl } => { - if !decl.params.is_empty() { + model::Operation::DeclareAlias(symbol) => { + if !symbol.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias declaration" )); } let optype = OpType::AliasDecl(AliasDecl { - name: decl.name.to_smolstr(), + name: symbol.name.to_smolstr(), bound: TypeBound::Copyable, }); @@ -494,57 +563,10 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::Tag { tag } => { - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let (_, outputs, _) = self.get_func_type(signature)?; - let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; - let node = self.make_node( - node_id, - OpType::Tag(Tag { - variants, - tag: tag as _, - }), - parent, - )?; - Ok(Some(node)) - } - model::Operation::Import { .. } => Ok(None), model::Operation::DeclareConstructor { .. } => Ok(None), model::Operation::DeclareOperation { .. } => Ok(None), - - model::Operation::Const { value } => { - let signature = node_data - .signature - .ok_or_else(|| error_uninferred!("node signature"))?; - let (_, outputs, _) = self.get_func_type(signature)?; - let outputs = self.import_closed_list(outputs)?; - let output = outputs - .first() - .ok_or(model::ModelError::TypeError(signature))?; - let datatype = self.import_type(*output)?; - - let imported_value = self.import_value(value, *output)?; - - let load_const_node = self.make_node( - node_id, - OpType::LoadConstant(LoadConstant { - datatype: datatype.clone(), - }), - parent, - )?; - - let const_node = self - .hugr - .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); - - self.hugr.connect(const_node, 0, load_const_node, 0); - - Ok(Some(load_const_node)) - } } } @@ -610,11 +632,8 @@ impl<'a> Context<'a> { }; let sum_rows: Vec<_> = { - let model::Term::Adt { variants } = self.get_term(*first)? else { - return Err(model::ModelError::TypeError(*first).into()); - }; - - self.import_type_rows(*variants)? + let [variants] = self.expect_symbol(*first, model::CORE_ADT)?; + self.import_type_rows(variants)? }; let rest = rest @@ -639,7 +658,7 @@ impl<'a> Context<'a> { }; let region_data = self.get_region(*region)?; - let (_, region_outputs, _) = self.get_func_type( + let [_, region_outputs, _] = self.get_func_type( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -680,7 +699,7 @@ impl<'a> Context<'a> { ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, model::Operation::Conditional); - let (inputs, outputs, _) = self.get_func_type( + let [inputs, outputs, _] = self.get_func_type( node_data .signature .ok_or_else(|| error_uninferred!("node signature"))?, @@ -732,7 +751,7 @@ impl<'a> Context<'a> { self.region_scope = region; } - let (region_source, region_targets, _) = self.get_func_type( + let [region_source, region_targets, _] = self.get_func_type( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -750,11 +769,8 @@ impl<'a> Context<'a> { return Err(model::ModelError::TypeError(region_source).into()); }; - let model::Term::Control { values: types } = self.get_term(*ctrl_type)? else { - return Err(model::ModelError::TypeError(*ctrl_type).into()); - }; - - self.import_type_row(*types)? + let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?; + self.import_type_row(types)? }; let entry = self.hugr.add_node_with_parent( @@ -825,11 +841,8 @@ impl<'a> Context<'a> { return Err(model::ModelError::TypeError(region_targets).into()); }; - let model::Term::Control { values: types } = self.get_term(*ctrl_type)? else { - return Err(model::ModelError::TypeError(*ctrl_type).into()); - }; - - self.import_type_row(*types)? + let [types] = self.expect_symbol(*ctrl_type, model::CORE_CTRL)?; + self.import_type_row(types)? }; let exit = self @@ -855,7 +868,7 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); }; let region_data = self.get_region(*region)?; - let (inputs, outputs, extensions) = self.get_func_type( + let [inputs, outputs, extensions] = self.get_func_type( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -879,41 +892,40 @@ impl<'a> Context<'a> { fn import_poly_func_type( &mut self, node: model::NodeId, - decl: model::FuncDecl<'a>, + symbol: model::Symbol<'a>, in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, ) -> Result { - let mut imported_params = Vec::with_capacity(decl.params.len()); + let mut imported_params = Vec::with_capacity(symbol.params.len()); - for (index, param) in decl.params.iter().enumerate() { + for (index, param) in symbol.params.iter().enumerate() { self.local_vars .insert(model::VarId(node, index as _), LocalVar::new(param.r#type)); } - for constraint in decl.constraints { - match self.get_term(*constraint)? { - model::Term::NonLinearConstraint { term } => { - let model::Term::Var(var) = self.get_term(*term)? else { - return Err(error_unsupported!( - "constraint on term that is not a variable" - )); - }; + for constraint in symbol.constraints { + if let Some([term]) = self.match_symbol(*constraint, model::CORE_NON_LINEAR)? { + let model::Term::Var(var) = self.get_term(term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; - self.local_vars - .get_mut(var) - .ok_or(model::ModelError::InvalidVar(*var))? - .bound = TypeBound::Copyable; - } - _ => return Err(error_unsupported!("constraint other than copy or discard")), + self.local_vars + .get_mut(var) + .ok_or(model::ModelError::InvalidVar(*var))? + .bound = TypeBound::Copyable; + } else { + return Err(error_unsupported!("constraint other than copy or discard")); } } - for (index, param) in decl.params.iter().enumerate() { + for (index, param) in symbol.params.iter().enumerate() { // NOTE: `PolyFuncType` only has explicit type parameters at present. let bound = self.local_vars[&model::VarId(node, index as _)].bound; imported_params.push(self.import_type_param(param.r#type, bound)?); } - let body = self.import_func_type::(decl.signature)?; + let body = self.import_func_type::(symbol.signature)?; in_scope(self, PolyFuncTypeBase::new(imported_params, body)) } @@ -923,59 +935,163 @@ impl<'a> Context<'a> { term_id: model::TermId, bound: TypeBound, ) -> Result { + if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { + return Ok(TypeParam::String); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { + return Ok(TypeParam::max_nat()); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_BYTES_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_FLOAT_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { + return Ok(TypeParam::Type { b: bound }); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_STATIC + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_CONSTRAINT + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { + return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST)); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeParam`", + model::CORE_CTRL_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_EXT_SET)? { + return Ok(TypeParam::Extensions); + } + + if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(item_type, TypeBound::Any)?); + return Ok(TypeParam::List { param }); + } + + if let Some([_]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { + // At present `hugr-model` has no way to express that the item + // types of a tuple must be copyable. Therefore we import it as `Any`. + todo!("import tuple type"); + } + match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => Ok(TypeParam::Type { b: bound }), - - model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), - model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), model::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), - model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")), - model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), - model::Term::BytesType { .. } => Err(error_unsupported!("`bytes` as `TypeParam`")), - model::Term::FloatType { .. } => Err(error_unsupported!("`float` as `TypeParam`")), - model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")), - model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), - - model::Term::ListType { item_type } => { - // At present `hugr-model` has no way to express that the item - // type of a list must be copyable. Therefore we import it as `Any`. - let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); - Ok(TypeParam::List { param }) + model::Term::Apply(symbol, _) => { + let name = self.get_symbol_name(*symbol)?; + Err(error_unsupported!("custom type `{}` as `TypeParam`", name)) } - model::Term::StrType => Ok(TypeParam::String), - model::Term::ExtSetType => Ok(TypeParam::Extensions), - - model::Term::NatType => Ok(TypeParam::max_nat()), - model::Term::Nat(_) + | model::Term::Tuple(_) | model::Term::Str(_) | model::Term::List { .. } | model::Term::ExtSet { .. } - | model::Term::Adt { .. } - | model::Term::Control { .. } - | model::Term::NonLinearConstraint { .. } | model::Term::ConstFunc { .. } | model::Term::Bytes { .. } - | model::Term::Meta - | model::Term::Float { .. } - | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), - - model::Term::ControlType => { - Err(error_unsupported!("type of control types as `TypeParam`")) - } + | model::Term::Float { .. } => Err(model::ModelError::TypeError(term_id).into()), } } /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { + if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_STR_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_NAT_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_BYTES_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_FLOAT_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_TYPE)); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_CONSTRAINT + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC)); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_EXT_SET)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_EXT_SET)); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CTRL_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_CTRL_TYPE + )); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { + return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST)); + } + + if let Some([]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { + return Err(error_unsupported!( + "`{}` as `TypeArg`", + model::CORE_LIST_TYPE + )); + } + match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Apply { .. } => { - Err(error_uninferred!("application with implicit parameters")) - } model::Term::Var(var) => { let var_info = self @@ -996,6 +1112,13 @@ impl<'a> Context<'a> { Ok(TypeArg::Sequence { elems }) } + model::Term::Tuple { .. } => { + // NOTE: While `TypeArg`s can represent tuples as + // `TypeArg::Sequence`s, this conflates lists and tuples. To + // avoid ambiguity we therefore report an error here for now. + Err(error_unsupported!("tuples as `TypeArg`")) + } + model::Term::Str(value) => Ok(TypeArg::String { arg: value.to_string(), }), @@ -1005,36 +1128,16 @@ impl<'a> Context<'a> { es: self.import_extension_set(term_id)?, }), - model::Term::StrType => Err(error_unsupported!("`str` as `TypeArg`")), - model::Term::NatType => Err(error_unsupported!("`nat` as `TypeArg`")), - model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")), - model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")), - model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")), - model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), - model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), - model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), - model::Term::BytesType => Err(error_unsupported!("`bytes` as `TypeArg`")), - model::Term::FloatType => Err(error_unsupported!("`float` as `TypeArg`")), model::Term::Bytes { .. } => Err(error_unsupported!("`(bytes ..)` as `TypeArg`")), - model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")), model::Term::Float { .. } => Err(error_unsupported!("float literal as `TypeArg`")), - model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")), model::Term::ConstFunc { .. } => { Err(error_unsupported!("function constant as `TypeArg`")) } - model::Term::FuncType { .. } - | model::Term::Adt { .. } - | model::Term::ApplyFull { .. } => { + model::Term::Apply { .. } => { let ty = self.import_type(term_id)?; Ok(TypeArg::Type { ty }) } - - model::Term::Control { .. } - | model::Term::Meta - | model::Term::NonLinearConstraint { .. } => { - Err(model::ModelError::TypeError(term_id).into()) - } } } @@ -1053,7 +1156,7 @@ impl<'a> Context<'a> { es.insert_type_var(*index as _); } - model::Term::ExtSet { parts } => { + model::Term::ExtSet(parts) => { for part in *parts { match part { model::ExtSetPart::Extension(ext) => { @@ -1081,13 +1184,24 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result, ImportError> { + if let Some([_, _, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type::(term_id)?; + return Ok(TypeBase::new_function(func_type)); + } + + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = self.import_closed_list(variants)?; + let variants = variants + .iter() + .map(|variant| self.import_type_row::(*variant)) + .collect::, _>>()?; + return Ok(TypeBase::new_sum(variants)); + } + match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Apply { .. } => { - Err(error_uninferred!("application with implicit parameters")) - } - model::Term::ApplyFull { symbol, args } => { + model::Term::Apply(symbol, args) => { let args = args .iter() .map(|arg| self.import_type_arg(*arg)) @@ -1127,66 +1241,29 @@ impl<'a> Context<'a> { Ok(TypeBase::new_var_use(*index as _, TypeBound::Copyable)) } - model::Term::FuncType { .. } => { - let func_type = self.import_func_type::(term_id)?; - Ok(TypeBase::new_function(func_type)) - } - - model::Term::Adt { variants } => { - let variants = self.import_closed_list(*variants)?; - let variants = variants - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>()?; - Ok(TypeBase::new_sum(variants)) - } - // The following terms are not runtime types, but the core `Type` only contains runtime types. // We therefore report a type error here. - model::Term::ListType { .. } - | model::Term::StrType - | model::Term::NatType - | model::Term::ExtSetType - | model::Term::StaticType - | model::Term::Type - | model::Term::Constraint - | model::Term::Const { .. } - | model::Term::Str(_) + model::Term::Str(_) | model::Term::ExtSet { .. } | model::Term::List { .. } - | model::Term::Control { .. } - | model::Term::ControlType + | model::Term::Tuple { .. } | model::Term::Nat(_) - | model::Term::NonLinearConstraint { .. } | model::Term::Bytes { .. } - | model::Term::BytesType - | model::Term::FloatType | model::Term::Float { .. } - | model::Term::ConstFunc { .. } - | model::Term::Meta - | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::ConstFunc { .. } => Err(model::ModelError::TypeError(term_id).into()), } } - fn get_func_type( - &mut self, - term_id: model::TermId, - ) -> Result<(model::TermId, model::TermId, model::TermId), ImportError> { - match self.get_term(term_id)? { - model::Term::FuncType { - inputs, - outputs, - extensions, - } => Ok((*inputs, *outputs, *extensions)), - _ => Err(model::ModelError::TypeError(term_id).into()), - } + fn get_func_type(&mut self, term_id: model::TermId) -> Result<[model::TermId; 3], ImportError> { + self.match_symbol(term_id, model::CORE_FN)? + .ok_or(model::ModelError::TypeError(term_id).into()) } fn import_func_type( &mut self, term_id: model::TermId, ) -> Result, ImportError> { - let (inputs, outputs, extensions) = self.get_func_type(term_id)?; + let [inputs, outputs, extensions] = self.get_func_type(term_id)?; let inputs = self.import_type_row(inputs)?; let outputs = self.import_type_row(outputs)?; let extensions = self.import_extension_set(extensions)?; @@ -1203,7 +1280,7 @@ impl<'a> Context<'a> { types: &mut Vec, ) -> Result<(), ImportError> { match ctx.get_term(term_id)? { - model::Term::List { parts } => { + model::Term::List(parts) => { types.reserve(parts.len()); for part in *parts { @@ -1228,6 +1305,41 @@ impl<'a> Context<'a> { Ok(types) } + fn import_closed_tuple( + &mut self, + term_id: model::TermId, + ) -> Result, ImportError> { + fn import_into( + ctx: &mut Context, + term_id: model::TermId, + types: &mut Vec, + ) -> Result<(), ImportError> { + match ctx.get_term(term_id)? { + model::Term::Tuple(parts) => { + types.reserve(parts.len()); + + for part in *parts { + match part { + model::TuplePart::Item(term_id) => { + types.push(*term_id); + } + model::TuplePart::Splice(term_id) => { + import_into(ctx, *term_id, types)?; + } + } + } + } + _ => return Err(model::ModelError::TypeError(term_id).into()), + } + + Ok(()) + } + + let mut types = Vec::new(); + import_into(self, term_id, &mut types)?; + Ok(types) + } + fn import_type_rows( &mut self, term_id: model::TermId, @@ -1248,7 +1360,7 @@ impl<'a> Context<'a> { types: &mut Vec>, ) -> Result<(), ImportError> { match ctx.get_term(term_id)? { - model::Term::List { parts } => { + model::Term::List(parts) => { types.reserve(parts.len()); for item in *parts { @@ -1303,27 +1415,15 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result<(&'a str, serde_json::Value), ImportError> { - let (global, args) = match self.get_term(term_id)? { - model::Term::Apply { symbol, args } | model::Term::ApplyFull { symbol, args } => { - (symbol, args) - } - _ => return Err(model::ModelError::TypeError(term_id).into()), - }; - - let global = self.get_symbol_name(*global)?; - if global != model::COMPAT_META_JSON { - return Err(model::ModelError::TypeError(term_id).into()); - } - - let [name_arg, json_arg] = args else { - return Err(model::ModelError::TypeError(term_id).into()); - }; + let [name_arg, json_arg] = self + .match_symbol(term_id, model::COMPAT_META_JSON)? + .ok_or(model::ModelError::TypeError(term_id))?; - let model::Term::Str(name) = self.get_term(*name_arg)? else { + let model::Term::Str(name) = self.get_term(name_arg)? else { return Err(model::ModelError::TypeError(term_id).into()); }; - let model::Term::Str(json_str) = self.get_term(*json_arg)? else { + let model::Term::Str(json_str) = self.get_term(json_arg)? else { return Err(model::ModelError::TypeError(term_id).into()); }; @@ -1340,169 +1440,166 @@ impl<'a> Context<'a> { ) -> Result { let term_data = self.get_term(term_id)?; - match term_data { - model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Apply { .. } => { - Err(error_uninferred!("application with implicit parameters")) - } - model::Term::Var(_) => Err(error_unsupported!("constant value containing a variable")), + // NOTE: We have special cased arrays, integers, and floats for now. + // TODO: Allow arbitrary extension values to be imported from terms. - model::Term::ApplyFull { symbol, args } => { - let symbol_name = self.get_symbol_name(*symbol)?; + if let Some([runtime_type, extensions, json]) = + self.match_symbol(term_id, model::COMPAT_CONST_JSON)? + { + let model::Term::Str(json) = self.get_term(json)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; - if symbol_name == model::COMPAT_CONST_JSON { - let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?; + // We attempt to deserialize as the custom const directly. + // This might fail due to the custom const struct not being included when + // this code was compiled; in that case, we fall back to the serialized form. + let value: Option> = serde_json::from_str(json).ok(); + + if let Some(value) = value { + let opaque_value = OpaqueValue::from(value); + return Ok(Value::Extension { e: opaque_value }); + } else { + let runtime_type = self.import_type(runtime_type)?; + let extensions = self.import_extension_set(extensions)?; + + let value: serde_json::Value = serde_json::from_str(json) + .map_err(|_| model::ModelError::TypeError(term_id))?; + let custom_const = CustomSerialized::new(runtime_type, value, extensions); + let opaque_value = OpaqueValue::new(custom_const); + return Ok(Value::Extension { e: opaque_value }); + } + } - let model::Term::Str(json) = self.get_term(*value)? else { - return Err(model::ModelError::TypeError(term_id).into()); - }; + if let Some([_, element_type_term, contents]) = + self.match_symbol(term_id, ArrayValue::CTR_NAME)? + { + let element_type = self.import_type(element_type_term)?; + let contents = self.import_closed_list(contents)?; + let contents = contents + .iter() + .map(|item| self.import_value(*item, element_type_term)) + .collect::, _>>()?; + return Ok(ArrayValue::new(element_type, contents).into()); + } - // We attempt to deserialize as the custom const directly. - // This might fail due to the custom const struct not being included when - // this code was compiled; in that case, we fall back to the serialized form. - let value: Option> = serde_json::from_str(json).ok(); - - if let Some(value) = value { - let opaque_value = OpaqueValue::from(value); - return Ok(Value::Extension { e: opaque_value }); - } else { - let runtime_type = - args.first().ok_or(model::ModelError::TypeError(term_id))?; - let runtime_type = self.import_type(*runtime_type)?; - - let extensions = - args.get(2).ok_or(model::ModelError::TypeError(term_id))?; - let extensions = self.import_extension_set(*extensions)?; - - let value: serde_json::Value = serde_json::from_str(json) - .map_err(|_| model::ModelError::TypeError(term_id))?; - let custom_const = CustomSerialized::new(runtime_type, value, extensions); - let opaque_value = OpaqueValue::new(custom_const); - return Ok(Value::Extension { e: opaque_value }); - } + if let Some([bitwidth, value]) = self.match_symbol(term_id, ConstInt::CTR_NAME)? { + let bitwidth = { + let model::Term::Nat(bitwidth) = self.get_term(bitwidth)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + if *bitwidth > 6 { + return Err(model::ModelError::TypeError(term_id).into()); } + *bitwidth as u8 + }; - // NOTE: We have special cased arrays, integers, and floats for now. - // TODO: Allow arbitrary extension values to be imported from terms. - - if symbol_name == ArrayValue::CTR_NAME { - let element_type_term = - args.get(1).ok_or(model::ModelError::TypeError(term_id))?; - let element_type = self.import_type(*element_type_term)?; - - let contents = { - let contents = args.get(2).ok_or(model::ModelError::TypeError(term_id))?; - let contents = self.import_closed_list(*contents)?; - contents - .iter() - .map(|item| self.import_value(*item, *element_type_term)) - .collect::, _>>()? - }; + let value = { + let model::Term::Nat(value) = self.get_term(value)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + *value + }; - return Ok(ArrayValue::new(element_type, contents).into()); - } + return Ok(ConstInt::new_u(bitwidth, value) + .map_err(|_| model::ModelError::TypeError(term_id))? + .into()); + } - if symbol_name == ConstInt::CTR_NAME { - let bitwidth = { - let bitwidth = args.first().ok_or(model::ModelError::TypeError(term_id))?; - let model::Term::Nat(bitwidth) = self.get_term(*bitwidth)? else { - return Err(model::ModelError::TypeError(term_id).into()); - }; - if *bitwidth > 6 { - return Err(model::ModelError::TypeError(term_id).into()); - } - *bitwidth as u8 - }; + if let Some([value]) = self.match_symbol(term_id, ConstF64::CTR_NAME)? { + let model::Term::Float(value) = self.get_term(value)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; - let value = { - let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?; - let model::Term::Nat(value) = self.get_term(*value)? else { - return Err(model::ModelError::TypeError(term_id).into()); - }; - *value - }; + return Ok(ConstF64::new(value.into_inner()).into()); + } - return Ok(ConstInt::new_u(bitwidth, value) - .map_err(|_| model::ModelError::TypeError(term_id))? - .into()); - } + if let Some([_, _, _, tag, values]) = self.match_symbol(term_id, model::CORE_CONST_ADT)? { + let [variants] = self.expect_symbol(type_id, model::CORE_ADT)?; + let values = self.import_closed_tuple(values)?; + let variants = self.import_closed_list(variants)?; - if symbol_name == ConstF64::CTR_NAME { - let value = { - let value = args.first().ok_or(model::ModelError::TypeError(term_id))?; - let model::Term::Float { value } = self.get_term(*value)? else { - return Err(model::ModelError::TypeError(term_id).into()); - }; - value.into_inner() - }; + let model::Term::Nat(tag) = self.get_term(tag)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + let variant = variants + .get(*tag as usize) + .ok_or(model::ModelError::TypeError(term_id))?; + + let variant = self.import_closed_list(*variant)?; + + let items = values + .iter() + .zip(variant.iter()) + .map(|(value, typ)| self.import_value(*value, *typ)) + .collect::, _>>()?; - return Ok(ConstF64::new(value).into()); + let typ = { + // TODO: Import as a `SumType` directly and avoid the copy. + let typ: Type = self.import_type(type_id)?; + match typ.as_type_enum() { + TypeEnum::Sum(sum) => sum.clone(), + _ => unreachable!(), } + }; + + return Ok(Value::sum(*tag as _, items, typ).unwrap()); + } - Err(error_unsupported!("unknown custom constant value")) + match term_data { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + model::Term::Var(_) => Err(error_unsupported!("constant value containing a variable")), + + model::Term::Apply(symbol, _) => { + let symbol_name = self.get_symbol_name(*symbol)?; + Err(error_unsupported!( + "unknown custom constant value `{}`", + symbol_name + )) // TODO: This should ultimately include the following cases: // - function definitions // - custom constructors for values } - model::Term::StaticType - | model::Term::Constraint - | model::Term::Const { .. } - | model::Term::List { .. } - | model::Term::ListType { .. } + model::Term::List { .. } | model::Term::Str(_) - | model::Term::StrType | model::Term::Nat(_) - | model::Term::NatType | model::Term::ExtSet { .. } - | model::Term::ExtSetType - | model::Term::Adt { .. } - | model::Term::FuncType { .. } - | model::Term::Control { .. } - | model::Term::ControlType - | model::Term::Type | model::Term::Bytes { .. } - | model::Term::BytesType - | model::Term::Meta - | model::Term::Float { .. } - | model::Term::FloatType - | model::Term::NonLinearConstraint { .. } => { - Err(model::ModelError::TypeError(term_id).into()) - } + | model::Term::Tuple(_) + | model::Term::Float { .. } => Err(model::ModelError::TypeError(term_id).into()), model::Term::ConstFunc { .. } => Err(error_unsupported!("constant function value")), + } + } - model::Term::ConstAdt { tag, values } => { - let model::Term::Adt { variants } = self.get_term(type_id)? else { - return Err(model::ModelError::TypeError(term_id).into()); - }; + fn match_symbol( + &self, + term_id: model::TermId, + name: &str, + ) -> Result, ImportError> { + let term = self.get_term(term_id)?; - let values = self.import_closed_list(*values)?; - let variants = self.import_closed_list(*variants)?; + // TODO: Follow alias chains? - let variant = variants - .get(*tag as usize) - .ok_or(model::ModelError::TypeError(term_id))?; - let variant = self.import_closed_list(*variant)?; + let model::Term::Apply(symbol, args) = term else { + return Ok(None); + }; - let items = values - .iter() - .zip(variant.iter()) - .map(|(value, typ)| self.import_value(*value, *typ)) - .collect::, _>>()?; + if name != self.get_symbol_name(*symbol)? { + return Ok(None); + } - let typ = { - // TODO: Import as a `SumType` directly and avoid the copy. - let typ: Type = self.import_type(type_id)?; - match typ.as_type_enum() { - TypeEnum::Sum(sum) => sum.clone(), - _ => unreachable!(), - } - }; + Ok((*args).try_into().ok()) + } - Ok(Value::sum(*tag as _, items, typ).unwrap()) - } - } + fn expect_symbol( + &self, + term_id: model::TermId, + name: &str, + ) -> Result<[model::TermId; N], ImportError> { + self.match_symbol(term_id, name)? + .ok_or(model::ModelError::TypeError(term_id).into()) } } diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index 288bce4ba..8e05ed3f7 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -4,24 +4,21 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add. --- (hugr 0) -(import arithmetic.int.iadd) - -(import arithmetic.int.types.int) - (define-func example.add - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) + (core.fn + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext)) (dfg [%0 %1] [%2] (signature - (-> - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int] (ext))) - ((@ arithmetic.int.iadd) [%0 %1] [%2] + (arithmetic.int.iadd [%0 %1] [%2] (signature - (-> - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index de4d36952..e525888ae 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -4,10 +4,8 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alia --- (hugr 0) -(import arithmetic.int.types.int) +(declare-alias local.float core.type) -(declare-alias local.float type) +(define-alias local.int core.type arithmetic.int.types.int) -(define-alias local.int type (@ arithmetic.int.types.int)) - -(define-alias local.endo type (-> [] [] (ext))) +(define-alias local.endo core.type (core.fn [] [] (ext))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index c9d735ba0..ddec67530 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -4,66 +4,68 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call --- (hugr 0) -(import compat.meta-json) - -(import arithmetic.int.types.int) - (declare-func example.callee - (forall ?0 ext-set) - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext ?0 ... arithmetic.int) - (meta - (@ compat.meta-json "description" "\"This is a function declaration.\"")) - (meta (@ compat.meta-json "title" "\"Callee\""))) + (param ?0 core.ext_set) + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext ?0 ... arithmetic.int)) + (meta (compat.meta_json "description" "\"This is a function declaration.\"")) + (meta (compat.meta_json "title" "\"Callee\""))) (define-func example.caller - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext arithmetic.int) + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext arithmetic.int)) (meta - (@ - compat.meta-json + (compat.meta_json "description" "\"This defines a function that calls the function which we declared earlier.\"")) - (meta (@ compat.meta-json "title" "\"Caller\"")) + (meta (compat.meta_json "title" "\"Caller\"")) (dfg [%0] [%1] (signature - (-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int))) - (call (@ example.callee (ext)) [%0] [%1] + ((core.call_indirect + [arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext arithmetic.int) + (example.callee (ext))) + [%0] [%1] (signature - (-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int)))))) (define-func example.load - [] - [(-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext arithmetic.int))] - (ext) + (core.fn + [] + [(core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext arithmetic.int))] + (ext)) (dfg [] [%0] (signature - (-> + (core.fn [] - [(-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + [(core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int))] (ext))) - (load-func (@ example.caller) [] [%0] + ((core.load_const _ _ example.caller) [] [%0] (signature - (-> + (core.fn [] - [(-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + [(core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int))] (ext)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index 0e9d11b4e..975902a24 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -5,28 +5,31 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (hugr 0) (define-func example.cfg - (forall ?0 type) - [?0] [?0] (ext) + (param ?0 core.type) + (core.fn [?0] [?0] (ext)) (dfg [%0] [%1] - (signature (-> [?0] [?0] (ext))) + (signature (core.fn [?0] [?0] (ext))) (cfg [%0] [%1] - (signature (-> [?0] [?0] (ext))) + (signature (core.fn [?0] [?0] (ext))) (cfg [%2] [%3] - (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])] (ext))) (block [%2] [%6] - (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) + (signature (core.fn [(core.ctrl [?0])] [(core.ctrl [?0])] (ext))) (dfg [%4] [%5] - (signature (-> [?0] [(adt [[?0]])] (ext))) - (tag 0 [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext)))))) + (signature (core.fn [?0] [(core.adt [[?0]])] (ext))) + ((core.make_adt _ _ 0) [%4] [%5] + (signature (core.fn [?0] [(core.adt [[?0]])] (ext)))))) (block [%6] [%3 %6] - (signature (-> [(ctrl [?0])] [(ctrl [?0]) (ctrl [?0])] (ext))) + (signature + (core.fn + [(core.ctrl [?0])] + [(core.ctrl [?0]) (core.ctrl [?0])] + (ext))) (dfg [%7] [%8] - (signature (-> [?0] [(adt [[?0] [?0]])] (ext))) - (tag - 0 - [%7] [%8] - (signature (-> [?0] [(adt [[?0] [?0]])] (ext)))))))))) + (signature (core.fn [?0] [(core.adt [[?0] [?0]])] (ext))) + ((core.make_adt _ _ 0) [%7] [%8] + (signature (core.fn [?0] [(core.adt [[?0] [?0]])] (ext)))))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index 45c654ac4..e1fdd57fa 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -4,45 +4,39 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond --- (hugr 0) -(import arithmetic.int.types.int) - -(import arithmetic.int.ineg) - (define-func example.cond - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) + (core.fn + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext)) (dfg [%0 %1] [%2] (signature - (-> - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int] (ext))) (cond [%0 %1] [%2] (signature - (-> - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int] (ext))) (dfg [%3] [%3] (signature - (-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext)))) (dfg [%4] [%5] (signature - (-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext))) - ((@ arithmetic.int.ineg) [%4] [%5] + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + (arithmetic.int.ineg [%4] [%5] (signature - (-> - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] + (core.fn + [arithmetic.int.types.int] + [arithmetic.int.types.int] (ext arithmetic.int)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_const.snap b/hugr-core/tests/snapshots/model__roundtrip_const.snap index 78630ea50..1a8e0d394 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_const.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_const.snap @@ -4,83 +4,77 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons --- (hugr 0) -(import collections.array.array) - -(import collections.array.const) - -(import compat.const-json) - -(import arithmetic.float.types.float64) - -(import arithmetic.int.const) - -(import arithmetic.int.types.int) - -(import arithmetic.float.const-f64) - (define-func example.bools - [] [(adt [[] []]) (adt [[] []])] (ext) + (core.fn [] [(core.adt [[] []]) (core.adt [[] []])] (ext)) (dfg [] [%0 %1] - (signature (-> [] [(adt [[] []]) (adt [[] []])] (ext))) - (const (tag 0 []) [] [%0] (signature (-> [] [(adt [[] []])] (ext)))) - (const (tag 1 []) [] [%1] (signature (-> [] [(adt [[] []])] (ext)))))) + (signature (core.fn [] [(core.adt [[] []]) (core.adt [[] []])] (ext))) + ((core.load_const _ _ (core.const.adt [[] []] _ _ 0 (tuple))) [] [%0] + (signature (core.fn [] [(core.adt [[] []])] (ext)))) + ((core.load_const _ _ (core.const.adt [[] []] _ _ 1 (tuple))) [] [%1] + (signature (core.fn [] [(core.adt [[] []])] (ext)))))) (define-func example.make-pair - [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] - (ext) + (core.fn + [] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] + (ext)) (dfg [] [%0] (signature - (-> + (core.fn [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] (ext))) - (const - (tag - 0 - [(@ - collections.array.const - 5 - (@ arithmetic.int.types.int 6) - [(@ arithmetic.int.const 6 1) - (@ arithmetic.int.const 6 2) - (@ arithmetic.int.const 6 3) - (@ arithmetic.int.const 6 4) - (@ arithmetic.int.const 6 5)]) - (@ arithmetic.float.const-f64 -3.0)]) + ((core.load_const + _ + _ + (core.const.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]] + _ + _ + 0 + (tuple + (collections.array.const + 5 + (arithmetic.int.types.int 6) + [(arithmetic.int.const 6 1) + (arithmetic.int.const 6 2) + (arithmetic.int.const 6 3) + (arithmetic.int.const 6 4) + (arithmetic.int.const 6 5)]) + (arithmetic.float.const-f64 -3.0)))) [] [%0] (signature - (-> + (core.fn [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] (ext)))))) (define-func example.f64-json - [] [(@ arithmetic.float.types.float64)] (ext) + (core.fn [] [arithmetic.float.types.float64] (ext)) (dfg [] [%0 %1] (signature - (-> + (core.fn [] - [(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)] + [arithmetic.float.types.float64 arithmetic.float.types.float64] (ext))) - (const - (@ arithmetic.float.const-f64 1.0) - [] [%0] - (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))) - (const - (@ - compat.const-json - (@ arithmetic.float.types.float64) - "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}" - (ext)) + ((core.load_const _ _ (arithmetic.float.const-f64 1.0)) [] [%0] + (signature (core.fn [] [arithmetic.float.types.float64] (ext)))) + ((core.load_const + _ + _ + (compat.const_json + arithmetic.float.types.float64 + (ext) + "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}")) [] [%1] - (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))))) + (signature (core.fn [] [arithmetic.float.types.float64] (ext)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index abac5878f..9868a730a 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -4,24 +4,23 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons --- (hugr 0) -(import collections.array.array) - (declare-func array.replicate - (forall ?0 nat) - (forall ?1 type) - (where (nonlinear ?1)) - [?1] [(@ collections.array.array ?0 ?1)] (ext)) + (param ?0 core.nat) + (param ?1 core.type) + (where (core.nonlinear ?1)) + (core.fn [?1] [(collections.array.array ?0 ?1)] (ext))) (declare-func array.copy - (forall ?0 nat) - (forall ?1 type) - (where (nonlinear ?1)) - [(@ collections.array.array ?0 ?1)] - [(@ collections.array.array ?0 ?1) (@ collections.array.array ?0 ?1)] - (ext)) + (param ?0 core.nat) + (param ?1 core.type) + (where (core.nonlinear ?1)) + (core.fn + [(collections.array.array ?0 ?1)] + [(collections.array.array ?0 ?1) (collections.array.array ?0 ?1)] + (ext))) (define-func util.copy - (forall ?0 type) - (where (nonlinear ?0)) - [?0] [?0 ?0] (ext) - (dfg [%0] [%0 %0] (signature (-> [?0] [?0 ?0] (ext))))) + (param ?0 core.type) + (where (core.nonlinear ?0)) + (core.fn [?0] [?0 ?0] (ext)) + (dfg [%0] [%0 %0] (signature (core.fn [?0] [?0 ?0] (ext))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap index a7c21dfec..2c82107ce 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_loop.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -5,15 +5,16 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-loop (hugr 0) (define-func example.loop - (forall ?0 type) - [?0] [?0] (ext) + (param ?0 core.type) + (core.fn [?0] [?0] (ext)) (dfg [%0] [%1] - (signature (-> [?0] [?0] (ext))) + (signature (core.fn [?0] [?0] (ext))) (tail-loop [%0] [%1] - (signature (-> [?0] [?0] (ext))) + (signature (core.fn [?0] [?0] (ext))) (dfg [%2] [%3] - (signature (-> [?0] [(adt [[?0] [?0]])] (ext))) - (tag 0 [%2] [%3] (signature (-> [?0] [(adt [[?0] [?0]])] (ext)))))))) + (signature (core.fn [?0] [(core.adt [[?0] [?0]])] (ext))) + ((core.make_adt _ _ 0) [%2] [%3] + (signature (core.fn [?0] [(core.adt [[?0] [?0]])] (ext)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap index 214cb9755..48d7bd191 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_params.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -5,7 +5,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-para (hugr 0) (define-func example.swap - (forall ?0 type) - (forall ?1 type) - [?0 ?1] [?1 ?0] (ext) - (dfg [%0 %1] [%1 %0] (signature (-> [?0 ?1] [?1 ?0] (ext))))) + (param ?0 core.type) + (param ?1 core.type) + (core.fn [?0 ?1] [?1 ?0] (ext)) + (dfg [%0 %1] [%1 %0] (signature (core.fn [?0 ?1] [?1 ?0] (ext))))) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 2bbff0f1b..5f9c367cd 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -37,67 +37,28 @@ struct Node { struct Operation { union { - invalid @0 :Void; + custom @0 :NodeId; dfg @1 :Void; cfg @2 :Void; block @3 :Void; - funcDefn @4 :FuncDecl; - funcDecl @5 :FuncDecl; - aliasDefn @6 :AliasDefn; - aliasDecl @7 :AliasDecl; - custom @8 :NodeId; - customFull @9 :NodeId; - tag @10 :UInt16; - tailLoop @11 :Void; - conditional @12 :Void; - callFunc @13 :TermId; - loadFunc @14 :TermId; - constructorDecl @15 :ConstructorDecl; - operationDecl @16 :OperationDecl; - import @17 :Text; - const @18 :TermId; - } - - struct FuncDefn { - name @0 :Text; - params @1 :List(Param); - constraints @2 :List(TermId); - signature @3 :TermId; - } - - struct FuncDecl { - name @0 :Text; - params @1 :List(Param); - constraints @2 :List(TermId); - signature @3 :TermId; - } - - struct AliasDefn { - name @0 :Text; - params @1 :List(Param); - type @2 :TermId; - value @3 :TermId; - } - - struct AliasDecl { - name @0 :Text; - params @1 :List(Param); - type @2 :TermId; - } - - struct ConstructorDecl { - name @0 :Text; - params @1 :List(Param); - constraints @2 :List(TermId); - type @3 :TermId; + funcDefn @4 :Symbol; + funcDecl @5 :Symbol; + aliasDefn @6 :Symbol; + aliasDecl @7 :Symbol; + invalid @8 :Void; + tailLoop @9 :Void; + conditional @10 :Void; + import @11 :Text; + constructorDecl @12 :Symbol; + operationDecl @13 :Symbol; } +} - struct OperationDecl { - name @0 :Text; - params @1 :List(Param); - constraints @2 :List(TermId); - type @3 :TermId; - } +struct Symbol { + name @0 :Text; + params @1 :List(Param); + constraints @2 :List(TermId); + signature @3 :TermId; } struct Region { @@ -115,9 +76,6 @@ struct RegionScope { ports @1 :UInt32; } -# Either `0` for an open scope, or the number of links in the closed scope incremented by `1`. -using LinkScope = UInt32; - enum RegionKind { dataFlow @0; controlFlow @1; @@ -126,51 +84,23 @@ enum RegionKind { struct Term { union { - wildcard @0 :Void; - runtimeType @1 :Void; - staticType @2 :Void; - constraint @3 :Void; + apply :group { + symbol @0 :NodeId; + args @1 :List(TermId); + } variable :group { - variableNode @4 :NodeId; - variableIndex @21 :UInt16; + node @2 :NodeId; + index @3 :UInt16; } - apply @5 :Apply; - applyFull @6 :ApplyFull; - const @7 :Const; - list @8 :ListTerm; - listType @9 :TermId; - string @10 :Text; - stringType @11 :Void; - nat @12 :UInt64; - natType @13 :Void; - extSet @14 :ExtSet; - extSetType @15 :Void; - adt @16 :TermId; - funcType @17 :FuncType; - control @18 :TermId; - controlType @19 :Void; - nonLinearConstraint @20 :TermId; - constFunc @22 :RegionId; - constAdt @23 :ConstAdt; - bytes @24 :Data; - bytesType @25 :Void; - meta @26 :Void; - float @27 :Float64; - floatType @28 :Void; - } - - struct Apply { - symbol @0 :NodeId; - args @1 :List(TermId); - } - - struct ApplyFull { - symbol @0 :NodeId; - args @1 :List(TermId); - } - - struct ListTerm { - items @0 :List(ListPart); + list @4 :List(ListPart); + string @5 :Text; + nat @6 :UInt64; + extSet @7 :List(ExtSetPart); + bytes @8 :Data; + float @9 :Float64; + constFunc @10 :RegionId; + wildcard @11 :Void; + tuple @12 :List(TuplePart); } struct ListPart { @@ -180,10 +110,6 @@ struct Term { } } - struct ExtSet { - items @0 :List(ExtSetPart); - } - struct ExtSetPart { union { extension @0 :Text; @@ -191,30 +117,15 @@ struct Term { } } - struct ConstAdt { - tag @0 :UInt16; - values @1 :TermId; - } - - struct FuncType { - inputs @0 :TermId; - outputs @1 :TermId; - extensions @2 :TermId; - } - - struct Const { - type @0 :TermId; - extensions @1 :TermId; + struct TuplePart { + union { + item @0 :TermId; + splice @1 :TermId; + } } } struct Param { name @0 :Text; type @1 :TermId; - sort @2 :ParamSort; -} - -enum ParamSort { - implicit @0; - explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 9118e3799..29dd3c4f0 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -15,8 +15,8 @@ pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult {{ - let mut __list_reader = $reader.$get()?; + ($bump:expr, $reader:expr, $read:expr) => {{ + let mut __list_reader = $reader; let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump); for __item_reader in __list_reader.iter() { __list.push($read($bump, __item_reader)?); @@ -102,105 +102,91 @@ fn read_operation<'a>( Which::FuncDefn(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); + let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); - let decl = bump.alloc(model::FuncDecl { + let symbol = bump.alloc(model::Symbol { name, params, constraints, signature, }); - model::Operation::DefineFunc { decl } + model::Operation::DefineFunc(symbol) } Which::FuncDecl(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); + let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); - let decl = bump.alloc(model::FuncDecl { + let symbol = bump.alloc(model::Symbol { name, params, constraints, signature, }); - model::Operation::DeclareFunc { decl } + model::Operation::DeclareFunc(symbol) } Which::AliasDefn(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); - let r#type = model::TermId(reader.get_type()); - let value = model::TermId(reader.get_value()); - let decl = bump.alloc(model::AliasDecl { + let params = read_list!(bump, reader.get_params()?, read_param); + let signature = model::TermId(reader.get_signature()); + let symbol = bump.alloc(model::Symbol { name, params, - r#type, + constraints: &[], + signature, }); - model::Operation::DefineAlias { decl, value } + model::Operation::DefineAlias(symbol) } Which::AliasDecl(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); - let r#type = model::TermId(reader.get_type()); - let decl = bump.alloc(model::AliasDecl { + let params = read_list!(bump, reader.get_params()?, read_param); + let signature = model::TermId(reader.get_signature()); + let symbol = bump.alloc(model::Symbol { name, params, - r#type, + constraints: &[], + signature, }); - model::Operation::DeclareAlias { decl } + model::Operation::DeclareAlias(symbol) } Which::ConstructorDecl(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); + let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); - let r#type = model::TermId(reader.get_type()); - let decl = bump.alloc(model::ConstructorDecl { + let signature = model::TermId(reader.get_signature()); + let symbol = bump.alloc(model::Symbol { name, params, constraints, - r#type, + signature, }); - model::Operation::DeclareConstructor { decl } + model::Operation::DeclareConstructor(symbol) } Which::OperationDecl(reader) => { let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); - let params = read_list!(bump, reader, get_params, read_param); + let params = read_list!(bump, reader.get_params()?, read_param); let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); - let r#type = model::TermId(reader.get_type()); - let decl = bump.alloc(model::OperationDecl { + let signature = model::TermId(reader.get_signature()); + let symbol = bump.alloc(model::Symbol { name, params, constraints, - r#type, + signature, }); - model::Operation::DeclareOperation { decl } + model::Operation::DeclareOperation(symbol) } - Which::Custom(operation) => model::Operation::Custom { - operation: model::NodeId(operation), - }, - Which::CustomFull(operation) => model::Operation::CustomFull { - operation: model::NodeId(operation), - }, - Which::Tag(tag) => model::Operation::Tag { tag }, + Which::Custom(operation) => model::Operation::Custom(model::NodeId(operation)), Which::TailLoop(()) => model::Operation::TailLoop, Which::Conditional(()) => model::Operation::Conditional, - Which::CallFunc(func) => model::Operation::CallFunc { - func: model::TermId(func), - }, - Which::LoadFunc(func) => model::Operation::LoadFunc { - func: model::TermId(func), - }, Which::Import(name) => model::Operation::Import { name: bump.alloc_str(name?.to_str()?), }, - Which::Const(value) => model::Operation::Const { - value: model::TermId(value), - }, }) } @@ -247,105 +233,40 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult use hugr_capnp::term::Which; Ok(match reader.which()? { Which::Wildcard(()) => model::Term::Wildcard, - Which::RuntimeType(()) => model::Term::Type, - Which::StaticType(()) => model::Term::StaticType, - Which::Constraint(()) => model::Term::Constraint, Which::String(value) => model::Term::Str(bump.alloc_str(value?.to_str()?)), - Which::StringType(()) => model::Term::StrType, Which::Nat(value) => model::Term::Nat(value), - Which::NatType(()) => model::Term::NatType, - Which::ExtSetType(()) => model::Term::ExtSetType, - Which::ControlType(()) => model::Term::ControlType, - Which::Meta(()) => model::Term::Meta, Which::Variable(reader) => { - let node = model::NodeId(reader.get_variable_node()); - let index = reader.get_variable_index(); + let node = model::NodeId(reader.get_node()); + let index = reader.get_index(); model::Term::Var(model::VarId(node, index)) } Which::Apply(reader) => { - let reader = reader?; - let symbol = model::NodeId(reader.get_symbol()); - let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::Apply { symbol, args } - } - - Which::ApplyFull(reader) => { - let reader = reader?; let symbol = model::NodeId(reader.get_symbol()); let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::ApplyFull { symbol, args } - } - - Which::Const(reader) => { - let reader = reader?; - model::Term::Const { - r#type: model::TermId(reader.get_type()), - extensions: model::TermId(reader.get_extensions()), - } + model::Term::Apply(symbol, args) } Which::List(reader) => { - let reader = reader?; - let parts = read_list!(bump, reader, get_items, read_list_part); - model::Term::List { parts } + let parts = read_list!(bump, reader?, read_list_part); + model::Term::List(parts) } - Which::ListType(item_type) => model::Term::ListType { - item_type: model::TermId(item_type), - }, - Which::ExtSet(reader) => { - let reader = reader?; - let parts = read_list!(bump, reader, get_items, read_ext_set_part); - model::Term::ExtSet { parts } - } - - Which::Adt(variants) => model::Term::Adt { - variants: model::TermId(variants), - }, - - Which::FuncType(reader) => { - let reader = reader?; - let inputs = model::TermId(reader.get_inputs()); - let outputs = model::TermId(reader.get_outputs()); - let extensions = model::TermId(reader.get_extensions()); - model::Term::FuncType { - inputs, - outputs, - extensions, - } + let parts = read_list!(bump, reader?, read_ext_set_part); + model::Term::ExtSet(parts) } - Which::Control(values) => model::Term::Control { - values: model::TermId(values), - }, - - Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { - term: model::TermId(term), - }, - - Which::ConstFunc(region) => model::Term::ConstFunc { - region: model::RegionId(region), - }, - - Which::ConstAdt(reader) => { - let reader = reader?; - let tag = reader.get_tag(); - let values = model::TermId(reader.get_values()); - model::Term::ConstAdt { tag, values } + Which::Tuple(reader) => { + let parts = read_list!(bump, reader?, read_tuple_part); + model::Term::Tuple(parts) } - Which::Bytes(bytes) => model::Term::Bytes { - data: bump.alloc_slice_copy(bytes?), - }, - Which::BytesType(()) => model::Term::BytesType, + Which::ConstFunc(region) => model::Term::ConstFunc(model::RegionId(region)), - Which::Float(value) => model::Term::Float { - value: value.into(), - }, - Which::FloatType(()) => model::Term::FloatType, + Which::Bytes(bytes) => model::Term::Bytes(bump.alloc_slice_copy(bytes?)), + Which::Float(value) => model::Term::Float(value.into()), }) } @@ -360,6 +281,17 @@ fn read_list_part( }) } +fn read_tuple_part( + _: &Bump, + reader: hugr_capnp::term::tuple_part::Reader, +) -> ReadResult { + use hugr_capnp::term::tuple_part::Which; + Ok(match reader.which()? { + Which::Item(term) => model::TuplePart::Item(model::TermId(term)), + Which::Splice(list) => model::TuplePart::Splice(model::TermId(list)), + }) +} + fn read_ext_set_part<'a>( bump: &'a Bump, reader: hugr_capnp::term::ext_set_part::Reader, @@ -377,11 +309,5 @@ fn read_param<'a>( ) -> ReadResult> { let name = bump.alloc_str(reader.get_name()?.to_str()?); let r#type = model::TermId(reader.get_type()); - - let sort = match reader.get_sort()? { - hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, - hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, - }; - - Ok(model::Param { name, r#type, sort }) + Ok(model::Param { name, r#type }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index b753807a2..eb644c58f 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -46,56 +46,33 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode model::Operation::Block => builder.set_block(()), model::Operation::TailLoop => builder.set_tail_loop(()), model::Operation::Conditional => builder.set_conditional(()), - model::Operation::Tag { tag } => builder.set_tag(*tag), - model::Operation::Custom { operation } => builder.set_custom(operation.0), - model::Operation::CustomFull { operation } => { - builder.set_custom_full(operation.0); - } - model::Operation::CallFunc { func } => builder.set_call_func(func.0), - model::Operation::LoadFunc { func } => builder.set_load_func(func.0), - - model::Operation::DefineFunc { decl } => { - let mut builder = builder.init_func_defn(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); - builder.set_signature(decl.signature.0); + model::Operation::Custom(operation) => builder.set_custom(operation.0), + + model::Operation::DefineFunc(symbol) => { + let builder = builder.init_func_defn(); + write_symbol(builder, symbol); } - model::Operation::DeclareFunc { decl } => { - let mut builder = builder.init_func_decl(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); - builder.set_signature(decl.signature.0); + model::Operation::DeclareFunc(symbol) => { + let builder = builder.init_func_decl(); + write_symbol(builder, symbol); } - model::Operation::DefineAlias { decl, value } => { - let mut builder = builder.init_alias_defn(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - builder.set_type(decl.r#type.0); - builder.set_value(value.0); + model::Operation::DefineAlias(symbol) => { + let builder = builder.init_alias_defn(); + write_symbol(builder, symbol); } - model::Operation::DeclareAlias { decl } => { - let mut builder = builder.init_alias_decl(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - builder.set_type(decl.r#type.0); + model::Operation::DeclareAlias(symbol) => { + let builder = builder.init_alias_decl(); + write_symbol(builder, symbol); } - model::Operation::DeclareConstructor { decl } => { - let mut builder = builder.init_constructor_decl(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); - builder.set_type(decl.r#type.0); + model::Operation::DeclareConstructor(symbol) => { + let builder = builder.init_constructor_decl(); + write_symbol(builder, symbol); } - model::Operation::DeclareOperation { decl } => { - let mut builder = builder.init_operation_decl(); - builder.set_name(decl.name); - write_list!(builder, init_params, write_param, decl.params); - let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); - builder.set_type(decl.r#type.0); + model::Operation::DeclareOperation(symbol) => { + let builder = builder.init_operation_decl(); + write_symbol(builder, symbol); } model::Operation::Import { name } => { @@ -103,18 +80,19 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } model::Operation::Invalid => builder.set_invalid(()), - - model::Operation::Const { value } => builder.set_const(value.0), } } +fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &model::Symbol) { + builder.set_name(symbol.name); + write_list!(builder, init_params, write_param, symbol.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(symbol.constraints)); + builder.set_signature(symbol.signature.0); +} + fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { builder.set_name(param.name); builder.set_type(param.r#type.0); - builder.set_sort(match param.sort { - model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, - model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, - }); } fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region) { @@ -143,104 +121,56 @@ fn write_region_scope(mut builder: hugr_capnp::region_scope::Builder, scope: &mo fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { match term { model::Term::Wildcard => builder.set_wildcard(()), - model::Term::Type => builder.set_runtime_type(()), - model::Term::StaticType => builder.set_static_type(()), - model::Term::Constraint => builder.set_constraint(()), model::Term::Var(model::VarId(node, index)) => { let mut builder = builder.init_variable(); - builder.set_variable_node(node.0); - builder.set_variable_index(*index); + builder.set_node(node.0); + builder.set_index(*index); } - model::Term::ListType { item_type } => builder.set_list_type(item_type.0), model::Term::Str(value) => builder.set_string(value), - model::Term::StrType => builder.set_string_type(()), model::Term::Nat(value) => builder.set_nat(*value), - model::Term::NatType => builder.set_nat_type(()), - model::Term::ExtSetType => builder.set_ext_set_type(()), - model::Term::Adt { variants } => builder.set_adt(variants.0), - model::Term::Const { r#type, extensions } => { - let mut builder = builder.init_const(); - builder.set_type(r#type.0); - builder.set_extensions(extensions.0); - } - model::Term::Control { values } => builder.set_control(values.0), - model::Term::ControlType => builder.set_control_type(()), + model::Term::ConstFunc(region) => builder.set_const_func(region.0), + model::Term::Bytes(data) => builder.set_bytes(data), + model::Term::Float(value) => builder.set_float(value.into_inner()), - model::Term::Apply { symbol, args } => { + model::Term::Apply(symbol, args) => { let mut builder = builder.init_apply(); builder.set_symbol(symbol.0); let _ = builder.set_args(model::TermId::unwrap_slice(args)); } - model::Term::ApplyFull { symbol, args } => { - let mut builder = builder.init_apply_full(); - builder.set_symbol(symbol.0); - let _ = builder.set_args(model::TermId::unwrap_slice(args)); - } - - model::Term::List { parts } => { - let mut builder = builder.init_list(); - write_list!(builder, init_items, write_list_item, parts); - } - - model::Term::ExtSet { parts } => { - let mut builder = builder.init_ext_set(); - write_list!(builder, init_items, write_ext_set_item, parts); - } - - model::Term::FuncType { - inputs, - outputs, - extensions, - } => { - let mut builder = builder.init_func_type(); - builder.set_inputs(inputs.0); - builder.set_outputs(outputs.0); - builder.set_extensions(extensions.0); + model::Term::List(parts) => { + write_list!(builder, init_list, write_list_part, parts); } - model::Term::NonLinearConstraint { term } => { - builder.set_non_linear_constraint(term.0); + model::Term::ExtSet(parts) => { + write_list!(builder, init_ext_set, write_ext_set_part, parts); } - model::Term::ConstFunc { region } => { - builder.set_const_func(region.0); + model::Term::Tuple(parts) => { + write_list!(builder, init_tuple, write_tuple_part, parts); } - - model::Term::ConstAdt { tag, values } => { - let mut builder = builder.init_const_adt(); - builder.set_tag(*tag); - builder.set_values(values.0); - } - - model::Term::Bytes { data } => { - builder.set_bytes(data); - } - - model::Term::BytesType => { - builder.set_bytes_type(()); - } - - model::Term::Meta => { - builder.set_meta(()); - } - model::Term::Float { value } => builder.set_float(value.into_inner()), - model::Term::FloatType => builder.set_float_type(()), } } -fn write_list_item(mut builder: hugr_capnp::term::list_part::Builder, item: &model::ListPart) { - match item { +fn write_list_part(mut builder: hugr_capnp::term::list_part::Builder, part: &model::ListPart) { + match part { model::ListPart::Item(term_id) => builder.set_item(term_id.0), model::ListPart::Splice(term_id) => builder.set_splice(term_id.0), } } -fn write_ext_set_item( +fn write_tuple_part(mut builder: hugr_capnp::term::tuple_part::Builder, item: &model::TuplePart) { + match item { + model::TuplePart::Item(term_id) => builder.set_item(term_id.0), + model::TuplePart::Splice(term_id) => builder.set_splice(term_id.0), + } +} + +fn write_ext_set_part( mut builder: hugr_capnp::term::ext_set_part::Builder, - item: &model::ExtSetPart, + part: &model::ExtSetPart, ) { - match item { + match part { model::ExtSetPart::Extension(ext) => builder.set_extension(ext), model::ExtSetPart::Splice(term_id) => builder.set_splice(term_id.0), } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index c7748aaf2..99a101738 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -91,10 +91,167 @@ use ordered_float::OrderedFloat; use smol_str::SmolStr; use thiserror::Error; +/// Core function types. +/// +/// - **Parameter:** `?inputs : (core.list core.type)` +/// - **Parameter:** `?outputs : (core.list core.type)` +/// - **Parameter:** `?ext : core.ext-set` +/// - **Result:** `core.type` +pub const CORE_FN: &str = "core.fn"; + +/// The type of runtime types. +/// +/// Runtime types are the types of values that can flow between nodes at runtime. +/// +/// - **Result:** `?type : core.static` +pub const CORE_TYPE: &str = "core.type"; + +/// The type of static types. +/// +/// Static types are the types of statically known parameters. +/// +/// This is the only term that is its own type. +/// +/// - **Result:** `?type : core.static` +pub const CORE_STATIC: &str = "core.static"; + +/// The type of constraints. +/// +/// - **Result:** `?type : core.static` +pub const CORE_CONSTRAINT: &str = "core.constraint"; + +/// The constraint for non-linear runtime data. +/// +/// Runtime values are copied implicitly by connecting an output port to more +/// than one input port. Similarly runtime values can be deleted implicitly when +/// an output port is not connected to any input port. In either of these cases +/// the type of the runtime value must satisfy this constraint. +/// +/// - **Parameter:** `?type : core.type` +/// - **Result:** `core.constraint` +pub const CORE_NON_LINEAR: &str = "core.nonlinear"; + +/// The type of metadata. +/// +/// - **Result:** `?type : core.static` +pub const CORE_META: &str = "core.meta"; + +/// Runtime algebraic data types. +/// +/// Algebraic data types are sums of products of other runtime types. +/// +/// - **Parameter:** `?variants : (core.list (core.list core.type))` +/// - **Result:** `core.type` +pub const CORE_ADT: &str = "core.adt"; + +/// Type of string literals. +/// +/// - **Result:** `core.static` +pub const CORE_STR_TYPE: &str = "core.str"; + +/// Type of natural number literals. +/// +/// - **Result:** `core.static` +pub const CORE_NAT_TYPE: &str = "core.nat"; + +/// Type of bytes literals. +/// +/// - **Result:** `core.static` +pub const CORE_BYTES_TYPE: &str = "core.bytes"; + +/// Type of float literals. +/// +/// - **Result:** `core.static` +pub const CORE_FLOAT_TYPE: &str = "core.float"; + +/// Type of a control flow edge. +/// +/// - **Parameter:** `?types : (core.list core.type)` +/// - **Result:** `core.ctrl_type` +pub const CORE_CTRL: &str = "core.ctrl"; + +/// The type of the types for control flow edges. +/// +/// - **Result:** `?type : core.static` +pub const CORE_CTRL_TYPE: &str = "core.ctrl_type"; + +/// The type of extension sets. +/// +/// - **Result:** `?type : core.static` +pub const CORE_EXT_SET: &str = "core.ext_set"; + +/// The type for runtime constants. +/// +/// - **Parameter:** `?type : core.type` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Result:** `core.static` +pub const CORE_CONST: &str = "core.const"; + +/// Constants for runtime algebraic data types. +/// +/// - **Parameter:** `?variants : (core.list core.type)` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Parameter:** `?types : (core.list core.static)` +/// - **Parameter:** `?tag : core.nat` +/// - **Parameter:** `?values : (core.tuple ?types)` +/// - **Result:** `(core.const (core.adt ?variants) ?ext)` +pub const CORE_CONST_ADT: &str = "core.const.adt"; + +/// The type for lists of static data. +/// +/// Lists are finite sequences such that all elements have the same type. +/// For heterogeneous sequences, see [`CORE_TUPLE_TYPE`]. +/// +/// - **Parameter:** `?type : core.static` +/// - **Result:** `core.static` +pub const CORE_LIST_TYPE: &str = "core.list"; + +/// The type for tuples of static data. +/// +/// Tuples are finite sequences that allow elements to have different types. +/// For homogeneous sequences, see [`CORE_LIST_TYPE`]. +/// +/// - **Parameter:** `?types : (core.list core.static)` +/// - **Result:** `core.static` +pub const CORE_TUPLE_TYPE: &str = "core.tuple"; + +/// Operation to call a statically known function. +/// +/// - **Parameter:** `?inputs : (core.list core.type)` +/// - **Parameter:** `?outputs : (core.list core.type)` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Parameter:** `?func : (core.const (core.fn ?inputs ?outputs ?ext) ?ext)` +/// - **Result:** `(core.fn ?inputs ?outputs ?ext)` +pub const CORE_CALL: &str = "core.call"; + +/// Operation to call a functiion known at runtime. +/// +/// - **Parameter:** `?inputs : (core.list core.type)` +/// - **Parameter:** `?outputs : (core.list core.type)` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Result:** `(core.fn [(core.fn ?inputs ?outputs ?ext) ?inputs ...] ?outputs ?ext)` +pub const CORE_CALL_INDIRECT: &str = "core.call_indirect"; + +/// Operation to load a constant value. +/// +/// - **Parameter:** `?type : core.type` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Parameter:** `?value : (core.const ?type ?ext)` +/// - **Result:** `(core.fn [] [?type] ?ext)` +pub const CORE_LOAD_CONST: &str = "core.load_const"; + +/// Operation to create a value of an algebraic data type. +/// +/// - **Parameter:** `?variants : (core.list (core.list core.type))` +/// - **Parameter:** `?types : (core.list core.type)` +/// - **Parameter:** `?tag : core.nat` +/// - **Result:** `(core.fn ?types [(core.adt ?variants)] (ext))` +pub const CORE_MAKE_ADT: &str = "core.make_adt"; + /// Constructor for documentation metadata. /// -/// - **Parameter:** `?description : str` -/// - **Result:** `meta` +/// - **Parameter:** `?description : core.str` +/// - **Result:** `core.meta` pub const CORE_META_DESCRIPTION: &str = "core.meta.description"; /// Constructor for JSON encoded metadata. @@ -103,10 +260,10 @@ pub const CORE_META_DESCRIPTION: &str = "core.meta.description"; /// The intention is to deprecate this in the future in favor of metadata /// expressed with custom constructors. /// -/// - **Parameter:** `?name : str` -/// - **Parameter:** `?json : str` -/// - **Result:** `meta` -pub const COMPAT_META_JSON: &str = "compat.meta-json"; +/// - **Parameter:** `?name : core.str` +/// - **Parameter:** `?json : core.str` +/// - **Result:** `core.meta` +pub const COMPAT_META_JSON: &str = "compat.meta_json"; /// Constructor for JSON encoded constants. /// @@ -114,11 +271,11 @@ pub const COMPAT_META_JSON: &str = "compat.meta-json"; /// The intention is to deprecate this in the future in favor of constants /// expressed with custom constructors. /// -/// - **Parameter:** `?type : type` -/// - **Parameter:** `?json : str` -/// - **Parameter:** `?exts : ext-set` -/// - **Result:** `(const ?type ?exts)` -pub const COMPAT_CONST_JSON: &str = "compat.const-json"; +/// - **Parameter:** `?type : core.type` +/// - **Parameter:** `?ext : core.ext_set` +/// - **Parameter:** `?json : core.str` +/// - **Result:** `(core.const ?type ?ext)` +pub const COMPAT_CONST_JSON: &str = "compat.const_json"; pub mod binary; pub mod scope; @@ -305,58 +462,15 @@ pub enum Operation<'a> { /// Basic blocks. Block, /// Function definitions. - DefineFunc { - /// The declaration of the function to be defined. - decl: &'a FuncDecl<'a>, - }, + DefineFunc(&'a Symbol<'a>), /// Function declarations. - DeclareFunc { - /// The function to be declared. - decl: &'a FuncDecl<'a>, - }, - /// Function calls. - CallFunc { - /// The function to be called. - func: TermId, - }, - /// Function constants. - LoadFunc { - /// The function to be loaded. - func: TermId, - }, + DeclareFunc(&'a Symbol<'a>), /// Custom operation. - /// - /// The node's parameters correspond to the explicit parameter of the custom operation, - /// leaving out the implicit parameters. Once the declaration of the custom operation - /// becomes known by resolving the reference, the node can be transformed into a [`Operation::CustomFull`] - /// by inferring terms for the implicit parameters or at least filling them in with a wildcard term. - Custom { - /// The symbol of the custom operation. - operation: NodeId, - }, - /// Custom operation with full parameters. - /// - /// The node's parameters correspond to both the explicit and implicit parameters of the custom operation. - /// Since this can be tedious to write, the [`Operation::Custom`] variant can be used to indicate that - /// the implicit parameters should be inferred. - CustomFull { - /// The symbol of the custom operation. - operation: NodeId, - }, + Custom(NodeId), /// Alias definitions. - DefineAlias { - /// The declaration of the alias to be defined. - decl: &'a AliasDecl<'a>, - /// The value of the alias. - value: TermId, - }, - + DefineAlias(&'a Symbol<'a>), /// Alias declarations. - DeclareAlias { - /// The alias to be declared. - decl: &'a AliasDecl<'a>, - }, - + DeclareAlias(&'a Symbol<'a>), /// Tail controlled loop. /// Nodes with this operation contain a dataflow graph that is executed in a loop. /// The loop body is executed at least once, producing a result that indicates whether @@ -378,51 +492,33 @@ pub enum Operation<'a> { /// - **Outputs**: `outputs` Conditional, - /// Create an ADT value from a sequence of inputs. - Tag { - /// The tag of the ADT value. - tag: u16, - }, - /// Declaration for a term constructor. /// /// Nodes with this operation must be within a module region. - DeclareConstructor { - /// The declaration of the constructor. - decl: &'a ConstructorDecl<'a>, - }, + DeclareConstructor(&'a Symbol<'a>), /// Declaration for a operation. /// /// Nodes with this operation must be within a module region. - DeclareOperation { - /// The declaration of the operation. - decl: &'a OperationDecl<'a>, - }, + DeclareOperation(&'a Symbol<'a>), /// Import a symbol. Import { /// The name of the symbol to be imported. name: &'a str, }, - - /// Create a constant value. - Const { - /// The term that describes how to construct the constant value. - value: TermId, - }, } impl<'a> Operation<'a> { /// Returns the symbol introduced by the operation, if any. pub fn symbol(&self) -> Option<&'a str> { match self { - Operation::DefineFunc { decl } => Some(decl.name), - Operation::DeclareFunc { decl } => Some(decl.name), - Operation::DefineAlias { decl, .. } => Some(decl.name), - Operation::DeclareAlias { decl } => Some(decl.name), - Operation::DeclareConstructor { decl } => Some(decl.name), - Operation::DeclareOperation { decl } => Some(decl.name), + Operation::DefineFunc(symbol) => Some(symbol.name), + Operation::DeclareFunc(symbol) => Some(symbol.name), + Operation::DefineAlias(symbol) => Some(symbol.name), + Operation::DeclareAlias(symbol) => Some(symbol.name), + Operation::DeclareConstructor(symbol) => Some(symbol.name), + Operation::DeclareOperation(symbol) => Some(symbol.name), Operation::Import { name } => Some(name), _ => None, } @@ -481,56 +577,19 @@ pub enum RegionKind { Module = 2, } -/// A function declaration. +/// A symbol. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct FuncDecl<'a> { - /// The name of the function to be declared. +pub struct Symbol<'a> { + /// The name of the symbol. pub name: &'a str, - /// The static parameters of the function. + /// The static parameters. pub params: &'a [Param<'a>], /// The constraints on the static parameters. pub constraints: &'a [TermId], - /// The signature of the function. + /// The signature of the symbol. pub signature: TermId, } -/// An alias declaration. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct AliasDecl<'a> { - /// The name of the alias to be declared. - pub name: &'a str, - /// The static parameters of the alias. - pub params: &'a [Param<'a>], - /// The type of the alias. - pub r#type: TermId, -} - -/// A term constructor declaration. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ConstructorDecl<'a> { - /// The name of the constructor to be declared. - pub name: &'a str, - /// The static parameters of the constructor. - pub params: &'a [Param<'a>], - /// The constraints on the static parameters. - pub constraints: &'a [TermId], - /// The type of the constructed term. - pub r#type: TermId, -} - -/// An operation declaration. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct OperationDecl<'a> { - /// The name of the operation to be declared. - pub name: &'a str, - /// The static parameters of the operation. - pub params: &'a [Param<'a>], - /// The constraints on the static parameters. - pub constraints: &'a [TermId], - /// The type of the operation. This must be a function type. - pub r#type: TermId, -} - /// An index of a variable within a node's parameter list. pub type VarIndex = u16; @@ -541,193 +600,68 @@ pub enum Term<'a> { #[default] Wildcard, - /// The type of runtime types. - /// - /// `type : static` - Type, - - /// The type of static types. - /// - /// `static : static` - StaticType, - - /// The type of constraints. - /// - /// `constraint : static` - Constraint, - /// A local variable. Var(VarId), - /// A symbolic function application. - /// - /// The arguments of this application cover only the explicit parameters of the referenced declaration, - /// leaving out the implicit parameters. Once the type of the declaration is known, the implicit parameters - /// can be inferred and the term replaced with [`Term::ApplyFull`]. - /// - /// `(GLOBAL ARG-0 ... ARG-n)` - Apply { - /// Reference to the symbol to apply. - symbol: NodeId, - /// Arguments to the function, covering only the explicit parameters. - args: &'a [TermId], - }, - - /// A symbolic function application with all arguments applied. + /// Apply a symbol to a sequence of arguments. /// - /// The arguments to this application cover both the implicit and explicit parameters of the referenced declaration. - /// Since this can be tedious to write out, only the explicit parameters can be provided via [`Term::Apply`]. - /// - /// `(@GLOBAL ARG-0 ... ARG-n)` - ApplyFull { - /// Reference to the symbol to apply. - symbol: NodeId, - /// Arguments to the function, covering both implicit and explicit parameters. - args: &'a [TermId], - }, + /// The symbol is defined by a node in the same graph. The type of this term + /// is derived from instantiating the symbol's parameters in the symbol's + /// signature. + Apply(NodeId, &'a [TermId]), - /// Type for a constant runtime value. + /// List of static data. /// - /// `(const T) : static` where `T : type`. - Const { - /// The runtime type of the constant value. - /// - /// **Type:** `type` - r#type: TermId, - /// The extension set required to be present in order to use the constant value. - /// - /// **Type:** `ext-set` - extensions: TermId, - }, - - /// A list. May include individual items or other lists to be spliced in. - List { - /// The parts of the list. - parts: &'a [ListPart], - }, - - /// The type of lists, given a type for the items. + /// Lists can include individual items or other lists to be spliced in. /// - /// `(list T) : static` where `T : static`. - ListType { - /// The type of the items in the list. - /// - /// `item_type : static` - item_type: TermId, - }, + /// **Type:** `(core.list ?t)` + List(&'a [ListPart]), /// A literal string. /// - /// `"STRING" : str` + /// **Type:** `core.str` Str(&'a str), - /// The type of literal strings. - /// - /// `str : static` - StrType, - /// A literal natural number. /// - /// `N : nat` + /// **Type:** `core.nat` Nat(u64), - /// The type of literal natural numbers. - /// - /// `nat : static` - NatType, - /// Extension set. - ExtSet { - /// The parts of the extension set. - /// - /// Since extension sets are unordered, the parts may occur in any order. - parts: &'a [ExtSetPart<'a>], - }, - - /// The type of extension sets. /// - /// `ext-set : static` - ExtSetType, + /// **Type:** `core.ext_set` + ExtSet(&'a [ExtSetPart<'a>]), - /// An algebraic data type. + /// A constant anonymous function. /// - /// `(adt VARIANTS) : type` where `VARIANTS : (list (list type))`. - Adt { - /// List of variants in the algrebaic data type. - /// Each of the variants is itself a list of runtime types. - variants: TermId, - }, - - /// The type of functions, given lists of input and output types and an extension set. - FuncType { - /// The input types of the function, given as a list of runtime types. - /// - /// `inputs : (list type)` - inputs: TermId, - /// The output types of the function, given as a list of runtime types. - /// - /// `outputs : (list type)` - outputs: TermId, - /// The set of extensions that the function requires to be present in - /// order to be called. - /// - /// `extensions : ext-set` - extensions: TermId, - }, + /// **Type:** `(core.const (core.fn ?ins ?outs ?ext) (ext))` + ConstFunc(RegionId), - /// Control flow. + /// A literal byte string. /// - /// `(ctrl VALUES) : ctrl` where `VALUES : (list type)`. - Control { - /// List of values. - values: TermId, - }, + /// **Type:**: `core.bytes` + Bytes(&'a [u8]), - /// Type of control flow edges. + /// A literal floating-point number. /// - /// `ctrl : static` - ControlType, - - /// Constraint that requires a runtime type to be copyable and discardable. - NonLinearConstraint { - /// The runtime type that must be copyable and discardable. - term: TermId, - }, - - /// A constant anonymous function. - ConstFunc { - /// The body of the constant anonymous function. - region: RegionId, - }, - - /// A constant value for an algebraic data type. - ConstAdt { - /// The tag of the variant. - tag: u16, - /// The values of the variant. - values: TermId, - }, + /// **Type:** `core.float` + Float(OrderedFloat), - /// A literal byte string. - Bytes { - /// The data of the byte string. - data: &'a [u8], - }, - - /// The type of byte strings. - BytesType, - - /// The type of metadata. - Meta, - - /// A literal floating-point number. - Float { - /// The value of the floating-point number. - value: OrderedFloat, - }, + /// Tuple of static data. + /// + /// Tuples can include individual items or other tuples to be spliced in. + /// + /// **Type:** `(core.tuple ?types)` + Tuple(&'a [TuplePart]), +} - /// The type of floating-point numbers. - FloatType, +/// A part of a tuple term. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum TuplePart { + /// A single item. + Item(TermId), + /// A tuple to be spliced into the parent tuple. + Splice(TermId), } /// A part of a list term. @@ -751,25 +685,12 @@ pub enum ExtSetPart<'a> { /// A parameter to a function or alias. /// /// Parameter names must be unique within a parameter list. -/// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Param<'a> { /// The name of the parameter. pub name: &'a str, /// The type of the parameter. pub r#type: TermId, - /// The sort of the parameter (implicit or explicit). - pub sort: ParamSort, -} - -/// The sort of a parameter (implicit or explicit). -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ParamSort { - /// The parameter is implicit and should be inferred, unless a full application form is used - /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit, - /// The parameter is explicit and should always be provided. - Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 60e86d613..b9ced84a2 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -27,8 +27,6 @@ node = { | node_block | node_define_func | node_declare_func - | node_call_func - | node_load_func | node_define_alias | node_declare_alias | node_declare_ctr @@ -37,7 +35,6 @@ node = { | node_cond | node_tag | node_import - | node_const | node_custom } @@ -46,8 +43,6 @@ node_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? ~ meta* ~ regi node_block = { "(" ~ "block" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } -node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } -node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } node_declare_ctr = { "(" ~ "declare-ctr" ~ ctr_header ~ meta* ~ ")" } @@ -56,20 +51,16 @@ node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_import = { "(" ~ "import" ~ symbol ~ meta* ~ ")" } -node_const = { "(" ~ "const" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } -node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_custom = { "(" ~ term_apply ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term } alias_header = { symbol ~ param* ~ term } ctr_header = { symbol ~ param* ~ where_clause* ~ term } operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit } - -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -where_clause = { "(" ~ "where" ~ term ~ ")" } +param = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -77,63 +68,37 @@ region_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } term = { term_wildcard - | term_type - | term_static - | term_constraint | term_var | term_const | term_list | term_list_type | term_str - | term_str_type | term_float | term_nat - | term_nat_type | term_ext_set - | term_ext_set_type - | term_adt - | term_func_type - | term_ctrl - | term_ctrl_type - | term_apply_full | term_apply - | term_non_linear | term_const_func | term_const_adt - | term_bytes_type | term_bytes | term_meta | term_float - | term_float_type + | term_tuple } -term_wildcard = { "_" } -term_type = { "type" } -term_static = { "static" } -term_constraint = { "constraint" } -term_var = { "?" ~ identifier } -term_apply_full = { ("(" ~ "@" ~ symbol ~ term* ~ ")") } -term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } -term_const = { "(" ~ "const" ~ term ~ term ~ ")" } -term_list = { "[" ~ (spliced_term | term)* ~ "]" } -term_list_type = { "(" ~ "list" ~ term ~ ")" } -term_str = { string } -term_str_type = { "str" } -term_nat = @{ (ASCII_DIGIT)+ } -term_nat_type = { "nat" } -term_ext_set = { "(" ~ "ext" ~ (spliced_term | ext_name)* ~ ")" } -term_ext_set_type = { "ext-set" } -term_adt = { "(" ~ "adt" ~ term ~ ")" } -term_func_type = { "(" ~ "->" ~ term ~ term ~ term ~ ")" } -term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } -term_ctrl_type = { "ctrl" } -term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } -term_const_func = { "(" ~ "fn" ~ term ~ ")" } -term_const_adt = { "(" ~ "tag" ~ tag ~ term* ~ ")" } -term_bytes_type = { "bytes" } -term_bytes = { "(" ~ "bytes" ~ base64_string ~ ")" } -term_meta = { "meta" } -term_float_type = { "float" } -term_float = @{ ("+" | "-")? ~ (ASCII_DIGIT)+ ~ "." ~ (ASCII_DIGIT)+ } +term_wildcard = { "_" } +term_var = { "?" ~ identifier } +term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } +term_const = { "(" ~ "const" ~ term ~ term ~ ")" } +term_list = { "[" ~ (spliced_term | term)* ~ "]" } +term_tuple = { "(" ~ "tuple" ~ (spliced_term | term)* ~ ")" } +term_list_type = { "(" ~ "list" ~ term ~ ")" } +term_str = { string } +term_nat = @{ (ASCII_DIGIT)+ } +term_ext_set = { "(" ~ "ext" ~ (spliced_term | ext_name)* ~ ")" } +term_const_func = { "(" ~ "fn" ~ term ~ ")" } +term_const_adt = { "(" ~ "tag" ~ tag ~ term* ~ ")" } +term_bytes = { "(" ~ "bytes" ~ base64_string ~ ")" } +term_meta = { "meta" } +term_float = @{ ("+" | "-")? ~ (ASCII_DIGIT)+ ~ "." ~ (ASCII_DIGIT)+ } spliced_term = { term ~ "..." } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 485610d76..fb5822008 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -9,9 +9,8 @@ use thiserror::Error; use crate::v0::{ scope::{LinkTable, SymbolTable, UnknownSymbolError, VarTable}, - AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, LinkIndex, ListPart, Module, Node, NodeId, - Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, RegionScope, - ScopeClosure, Term, TermId, + ExtSetPart, LinkIndex, ListPart, Module, Node, NodeId, Operation, Param, Region, RegionId, + RegionKind, RegionScope, ScopeClosure, Symbol, Term, TermId, TuplePart, }; mod pest_parser { @@ -120,14 +119,6 @@ impl<'a> ParseContext<'a> { let term = match rule { Rule::term_wildcard => Term::Wildcard, - Rule::term_type => Term::Type, - Rule::term_static => Term::StaticType, - Rule::term_constraint => Term::Constraint, - Rule::term_str_type => Term::StrType, - Rule::term_nat_type => Term::NatType, - Rule::term_ctrl_type => Term::ControlType, - Rule::term_ext_set_type => Term::ExtSetType, - Rule::term_meta => Term::Meta, Rule::term_var => { let name_token = inner.next().unwrap(); @@ -148,54 +139,41 @@ impl<'a> ParseContext<'a> { args.push(self.parse_term(token)?); } - Term::Apply { - symbol, - args: self.bump.alloc_slice_copy(&args), - } + Term::Apply(symbol, self.bump.alloc_slice_copy(&args)) } - Rule::term_apply_full => { - let symbol = self.parse_symbol_use(&mut inner)?; - let mut args = Vec::new(); + Rule::term_list => { + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); for token in inner { - args.push(self.parse_term(token)?); - } - - Term::ApplyFull { - symbol, - args: self.bump.alloc_slice_copy(&args), + match token.as_rule() { + Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ListPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } } - } - Rule::term_const => { - let r#type = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - Term::Const { r#type, extensions } + Term::List(parts.into_bump_slice()) } - Rule::term_list => { + Rule::term_tuple => { let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); for token in inner { match token.as_rule() { - Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::term => parts.push(TuplePart::Item(self.parse_term(token)?)), Rule::spliced_term => { let term_token = token.into_inner().next().unwrap(); - parts.push(ListPart::Splice(self.parse_term(term_token)?)) + parts.push(TuplePart::Splice(self.parse_term(term_token)?)) } _ => unreachable!(), } } - Term::List { - parts: parts.into_bump_slice(), - } - } - - Rule::term_list_type => { - let item_type = self.parse_term(inner.next().unwrap())?; - Term::ListType { item_type } + Term::Tuple(parts.into_bump_slice()) } Rule::term_str => { @@ -223,50 +201,14 @@ impl<'a> ParseContext<'a> { } } - Term::ExtSet { - parts: parts.into_bump_slice(), - } - } - - Rule::term_adt => { - let variants = self.parse_term(inner.next().unwrap())?; - Term::Adt { variants } - } - - Rule::term_func_type => { - let inputs = self.parse_term(inner.next().unwrap())?; - let outputs = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - Term::FuncType { - inputs, - outputs, - extensions, - } - } - - Rule::term_ctrl => { - let values = self.parse_term(inner.next().unwrap())?; - Term::Control { values } - } - - Rule::term_non_linear => { - let term = self.parse_term(inner.next().unwrap())?; - Term::NonLinearConstraint { term } + Term::ExtSet(parts.into_bump_slice()) } Rule::term_const_func => { let region = self.parse_region(inner.next().unwrap(), ScopeClosure::Closed)?; - Term::ConstFunc { region } - } - - Rule::term_const_adt => { - let tag = inner.next().unwrap().as_str().parse().unwrap(); - let values = self.parse_term(inner.next().unwrap())?; - Term::ConstAdt { tag, values } + Term::ConstFunc(region) } - Rule::term_bytes_type => Term::BytesType, - Rule::term_bytes => { let token = inner.next().unwrap(); let slice = token.as_str(); @@ -276,15 +218,12 @@ impl<'a> ParseContext<'a> { ParseError::custom("invalid base64 encoding", token.as_span()) })?; let data = self.bump.alloc_slice_copy(&data); - Term::Bytes { data } + Term::Bytes(data) } - Rule::term_float_type => Term::FloatType, Rule::term_float => { let value: f64 = str_slice.trim().parse().unwrap(); - Term::Float { - value: value.into(), - } + Term::Float(value.into()) } r => unreachable!("term: {:?}", r), @@ -401,12 +340,12 @@ impl<'a> ParseContext<'a> { Rule::node_define_func => { self.vars.enter(node); - let decl = self.parse_func_header(inner.next().unwrap())?; + let symbol = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner, ScopeClosure::Closed)?; self.vars.exit(); Node { - operation: Operation::DefineFunc { decl }, + operation: Operation::DefineFunc(symbol), inputs: &[], outputs: &[], params: &[], @@ -418,11 +357,11 @@ impl<'a> ParseContext<'a> { Rule::node_declare_func => { self.vars.enter(node); - let decl = self.parse_func_header(inner.next().unwrap())?; + let symbol = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; self.vars.exit(); Node { - operation: Operation::DeclareFunc { decl }, + operation: Operation::DeclareFunc(symbol), inputs: &[], outputs: &[], params: &[], @@ -432,51 +371,18 @@ impl<'a> ParseContext<'a> { } } - Rule::node_call_func => { - let func = self.parse_term(inner.next().unwrap())?; - let inputs = self.parse_port_list(&mut inner)?; - let outputs = self.parse_port_list(&mut inner)?; - let signature = self.parse_signature(&mut inner)?; - let meta = self.parse_meta(&mut inner)?; - Node { - operation: Operation::CallFunc { func }, - inputs, - outputs, - params: &[], - regions: &[], - meta, - signature, - } - } - - Rule::node_load_func => { - let func = self.parse_term(inner.next().unwrap())?; - let inputs = self.parse_port_list(&mut inner)?; - let outputs = self.parse_port_list(&mut inner)?; - let signature = self.parse_signature(&mut inner)?; - let meta = self.parse_meta(&mut inner)?; - Node { - operation: Operation::LoadFunc { func }, - inputs, - outputs, - params: &[], - regions: &[], - meta, - signature, - } - } - Rule::node_define_alias => { self.vars.enter(node); - let decl = self.parse_alias_header(inner.next().unwrap())?; + let symbol = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; + let params = self.bump.alloc_slice_copy(&[value]); let meta = self.parse_meta(&mut inner)?; self.vars.exit(); Node { - operation: Operation::DefineAlias { decl, value }, + operation: Operation::DefineAlias(symbol), inputs: &[], outputs: &[], - params: &[], + params, regions: &[], meta, signature: None, @@ -485,11 +391,11 @@ impl<'a> ParseContext<'a> { Rule::node_declare_alias => { self.vars.enter(node); - let decl = self.parse_alias_header(inner.next().unwrap())?; + let symbol = self.parse_alias_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; self.vars.exit(); Node { - operation: Operation::DeclareAlias { decl }, + operation: Operation::DeclareAlias(symbol), inputs: &[], outputs: &[], params: &[], @@ -501,27 +407,18 @@ impl<'a> ParseContext<'a> { Rule::node_custom => { let op = inner.next().unwrap(); - debug_assert!(matches!( - op.as_rule(), - Rule::term_apply | Rule::term_apply_full - )); - let op_rule = op.as_rule(); + debug_assert!(matches!(op.as_rule(), Rule::term_apply)); let mut op_inner = op.into_inner(); let operation = self.parse_symbol_use(&mut op_inner)?; let mut params = Vec::new(); - for token in filter_rule(&mut inner, Rule::term) { + for token in filter_rule(&mut op_inner, Rule::term) { params.push(self.parse_term(token)?); } - let operation = match op_rule { - Rule::term_apply_full => Operation::CustomFull { operation }, - Rule::term_apply => Operation::Custom { operation }, - _ => unreachable!(), - }; - + let operation = Operation::Custom(operation); let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; @@ -572,30 +469,13 @@ impl<'a> ParseContext<'a> { } } - Rule::node_tag => { - let tag = inner.next().unwrap().as_str().parse::().unwrap(); - let inputs = self.parse_port_list(&mut inner)?; - let outputs = self.parse_port_list(&mut inner)?; - let signature = self.parse_signature(&mut inner)?; - let meta = self.parse_meta(&mut inner)?; - Node { - operation: Operation::Tag { tag }, - inputs, - outputs, - params: &[], - regions: &[], - meta, - signature, - } - } - Rule::node_declare_ctr => { self.vars.enter(node); - let decl = self.parse_ctr_header(inner.next().unwrap())?; + let symbol = self.parse_ctr_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; self.vars.exit(); Node { - operation: Operation::DeclareConstructor { decl }, + operation: Operation::DeclareConstructor(symbol), inputs: &[], outputs: &[], params: &[], @@ -607,11 +487,11 @@ impl<'a> ParseContext<'a> { Rule::node_declare_operation => { self.vars.enter(node); - let decl = self.parse_op_header(inner.next().unwrap())?; + let symbol = self.parse_op_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; self.vars.exit(); Node { - operation: Operation::DeclareOperation { decl }, + operation: Operation::DeclareOperation(symbol), inputs: &[], outputs: &[], params: &[], @@ -621,23 +501,6 @@ impl<'a> ParseContext<'a> { } } - Rule::node_const => { - let value = self.parse_term(inner.next().unwrap())?; - let inputs = self.parse_port_list(&mut inner)?; - let outputs = self.parse_port_list(&mut inner)?; - let signature = self.parse_signature(&mut inner)?; - let meta = self.parse_meta(&mut inner)?; - Node { - operation: Operation::Const { value }, - inputs, - outputs, - params: &[], - regions: &[], - meta, - signature, - } - } - _ => unreachable!(), }; @@ -729,79 +592,70 @@ impl<'a> ParseContext<'a> { Ok(nodes) } - fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> { + fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a Symbol<'a>> { debug_assert_eq!(pair.as_rule(), Rule::func_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; let constraints = self.parse_constraints(&mut inner)?; + let signature = self.parse_term(inner.next().unwrap())?; - let inputs = self.parse_term(inner.next().unwrap())?; - let outputs = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - - // Assemble the inputs, outputs and extensions into a function type. - let func = self.module.insert_term(Term::FuncType { - inputs, - outputs, - extensions, - }); - - Ok(self.bump.alloc(FuncDecl { + Ok(self.bump.alloc(Symbol { name, params, constraints, - signature: func, + signature, })) } - fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> { + fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a Symbol<'a>> { debug_assert_eq!(pair.as_rule(), Rule::alias_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; - let r#type = self.parse_term(inner.next().unwrap())?; + let signature = self.parse_term(inner.next().unwrap())?; - Ok(self.bump.alloc(AliasDecl { + Ok(self.bump.alloc(Symbol { name, params, - r#type, + constraints: &[], + signature, })) } - fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a ConstructorDecl<'a>> { + fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a Symbol<'a>> { debug_assert_eq!(pair.as_rule(), Rule::ctr_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; let constraints = self.parse_constraints(&mut inner)?; - let r#type = self.parse_term(inner.next().unwrap())?; + let signature = self.parse_term(inner.next().unwrap())?; - Ok(self.bump.alloc(ConstructorDecl { + Ok(self.bump.alloc(Symbol { name, params, constraints, - r#type, + signature, })) } - fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a OperationDecl<'a>> { + fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a Symbol<'a>> { debug_assert_eq!(pair.as_rule(), Rule::operation_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; let constraints = self.parse_constraints(&mut inner)?; - let r#type = self.parse_term(inner.next().unwrap())?; + let signature = self.parse_term(inner.next().unwrap())?; - Ok(self.bump.alloc(OperationDecl { + Ok(self.bump.alloc(Symbol { name, params, constraints, - r#type, + signature, })) } @@ -809,32 +663,11 @@ impl<'a> ParseContext<'a> { let mut params = Vec::new(); for pair in filter_rule(pairs, Rule::param) { - let param = pair.into_inner().next().unwrap(); - let param_span = param.as_span(); - - let param = match param.as_rule() { - Rule::param_implicit => { - let mut inner = param.into_inner(); - let name = &inner.next().unwrap().as_str()[1..]; - let r#type = self.parse_term(inner.next().unwrap())?; - Param { - name, - r#type, - sort: ParamSort::Implicit, - } - } - Rule::param_explicit => { - let mut inner = param.into_inner(); - let name = &inner.next().unwrap().as_str()[1..]; - let r#type = self.parse_term(inner.next().unwrap())?; - Param { - name, - r#type, - sort: ParamSort::Explicit, - } - } - _ => unreachable!(), - }; + let param_span = pair.as_span(); + let mut inner = pair.into_inner(); + let name = &inner.next().unwrap().as_str()[1..]; + let r#type = self.parse_term(inner.next().unwrap())?; + let param = Param { name, r#type }; self.vars .insert(param.name) @@ -904,6 +737,10 @@ impl<'a> ParseContext<'a> { fn parse_symbol_use(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult { let name = self.parse_symbol(pairs)?; + self.use_symbol(name) + } + + fn use_symbol(&mut self, name: &'a str) -> ParseResult { let resolved = self.symbols.resolve(name); Ok(match resolved { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 77d6ef0cb..ee3ab54f7 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -3,8 +3,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - ExtSetPart, LinkIndex, ListPart, ModelError, Module, NodeId, Operation, Param, ParamSort, - RegionId, RegionKind, Term, TermId, VarId, + ExtSetPart, LinkIndex, ListPart, ModelError, Module, NodeId, Operation, Param, RegionId, + RegionKind, Term, TermId, TuplePart, VarId, }; type PrintError = ModelError; @@ -135,6 +135,11 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_node(node_id) .ok_or(PrintError::NodeNotFound(node_id))?; + // Skip printing import nodes. + if let Operation::Import { .. } = node_data.operation { + return Ok(()); + } + self.print_parens(|this| match &node_data.operation { Operation::Invalid => Err(ModelError::InvalidOperation(node_id)), Operation::Dfg => { @@ -165,87 +170,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_regions(node_data.regions) } - Operation::DefineFunc { decl } => this.with_local_scope(decl.params, |this| { + Operation::DefineFunc(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("define-func"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - this.print_constraints(decl.constraints)?; - - match self.module.get_term(decl.signature) { - Some(Term::FuncType { - inputs, - outputs, - extensions, - }) => { - this.print_group(|this| { - this.print_term(*inputs)?; - this.print_term(*outputs)?; - this.print_term(*extensions) - })?; - } - Some(_) => return Err(PrintError::TypeError(decl.signature)), - None => return Err(PrintError::TermNotFound(decl.signature)), - } - + this.print_params(symbol.params)?; + this.print_constraints(symbol.constraints)?; + this.print_term(symbol.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) }), - Operation::DeclareFunc { decl } => this.with_local_scope(decl.params, |this| { + Operation::DeclareFunc(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("declare-func"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - this.print_constraints(decl.constraints)?; - - match self.module.get_term(decl.signature) { - Some(Term::FuncType { - inputs, - outputs, - extensions, - }) => { - this.print_group(|this| { - this.print_term(*inputs)?; - this.print_term(*outputs)?; - this.print_term(*extensions) - })?; - } - Some(_) => return Err(PrintError::TypeError(decl.signature)), - None => return Err(PrintError::TermNotFound(decl.signature)), - } - + this.print_params(symbol.params)?; + this.print_constraints(symbol.constraints)?; + this.print_term(symbol.signature)?; this.print_meta(node_data.meta)?; Ok(()) }), - Operation::CallFunc { func } => { - this.print_group(|this| { - this.print_text("call"); - this.print_term(*func)?; - this.print_port_lists(node_data.inputs, node_data.outputs) - })?; - this.print_signature(node_data.signature)?; - this.print_meta(node_data.meta)?; - Ok(()) - } - - Operation::LoadFunc { func } => { - this.print_group(|this| { - this.print_text("load-func"); - this.print_term(*func)?; - this.print_port_lists(node_data.inputs, node_data.outputs) - })?; - this.print_signature(node_data.signature)?; - this.print_meta(node_data.meta)?; - Ok(()) - } - - Operation::Custom { operation } => { + Operation::Custom(operation) => { this.print_group(|this| { if node_data.params.is_empty() { this.print_symbol(*operation)?; @@ -268,76 +219,54 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_regions(node_data.regions) } - Operation::CustomFull { operation } => { - this.print_group(|this| { - this.print_parens(|this| { - this.print_text("@"); - this.print_symbol(*operation)?; - - for param in node_data.params { - this.print_term(*param)?; - } - - Ok(()) - })?; - - this.print_port_lists(node_data.inputs, node_data.outputs) - })?; - this.print_signature(node_data.signature)?; - this.print_meta(node_data.meta)?; - this.print_regions(node_data.regions) - } - - Operation::DefineAlias { decl, value } => this.with_local_scope(decl.params, |this| { + Operation::DefineAlias(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("define-alias"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - - this.print_term(decl.r#type)?; - this.print_term(*value)?; + this.print_params(symbol.params)?; + this.print_term(symbol.signature)?; + for param in node_data.params { + this.print_term(*param)?; + } this.print_meta(node_data.meta)?; Ok(()) }), - Operation::DeclareAlias { decl } => this.with_local_scope(decl.params, |this| { + Operation::DeclareAlias(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("declare-alias"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - - this.print_term(decl.r#type)?; + this.print_params(symbol.params)?; + this.print_term(symbol.signature)?; this.print_meta(node_data.meta)?; Ok(()) }), - Operation::DeclareConstructor { decl } => this.with_local_scope(decl.params, |this| { + Operation::DeclareConstructor(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("declare-ctr"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - this.print_constraints(decl.constraints)?; - - this.print_term(decl.r#type)?; + this.print_params(symbol.params)?; + this.print_constraints(symbol.constraints)?; + this.print_term(symbol.signature)?; this.print_meta(node_data.meta)?; Ok(()) }), - Operation::DeclareOperation { decl } => this.with_local_scope(decl.params, |this| { + Operation::DeclareOperation(symbol) => this.with_local_scope(symbol.params, |this| { this.print_group(|this| { this.print_text("declare-operation"); - this.print_text(decl.name); + this.print_text(symbol.name); }); - this.print_params(decl.params)?; - this.print_constraints(decl.constraints)?; - - this.print_term(decl.r#type)?; + this.print_params(symbol.params)?; + this.print_constraints(symbol.constraints)?; + this.print_term(symbol.signature)?; this.print_meta(node_data.meta)?; Ok(()) }), @@ -358,26 +287,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_regions(node_data.regions) } - Operation::Tag { tag } => { - this.print_text("tag"); - this.print_text(format!("{}", tag)); - this.print_port_lists(node_data.inputs, node_data.outputs)?; - this.print_signature(node_data.signature)?; - this.print_meta(node_data.meta) - } - - Operation::Import { name } => { - this.print_text("import"); - this.print_text(*name); - this.print_meta(node_data.meta) - } - - Operation::Const { value } => { - this.print_text("const"); - this.print_term(*value)?; - this.print_port_lists(node_data.inputs, node_data.outputs)?; - this.print_signature(node_data.signature)?; - this.print_meta(node_data.meta) + Operation::Import { .. } => { + unreachable!() } }) } @@ -460,11 +371,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { self.print_parens(|this| { - match param.sort { - ParamSort::Implicit => this.print_text("forall"), - ParamSort::Explicit => this.print_text("param"), - }; - + this.print_text("param"); this.print_text(format!("?{}", param.name)); this.print_term(param.r#type) }) @@ -492,20 +399,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("_"); Ok(()) } - Term::Type => { - self.print_text("type"); - Ok(()) - } - Term::StaticType => { - self.print_text("static"); - Ok(()) - } - Term::Constraint => { - self.print_text("constraint"); - Ok(()) - } Term::Var(var) => self.print_var(*var), - Term::Apply { symbol, args } => { + Term::Apply(symbol, args) => { if args.is_empty() { self.print_symbol(*symbol)?; } else { @@ -520,107 +415,38 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - Term::ApplyFull { symbol, args } => self.print_parens(|this| { - this.print_text("@"); - this.print_symbol(*symbol)?; - for arg in args.iter() { - this.print_term(*arg)?; - } - - Ok(()) - }), - Term::Const { r#type, extensions } => self.print_parens(|this| { - this.print_text("const"); - this.print_term(*r#type)?; - this.print_term(*extensions) - }), Term::List { .. } => self.print_brackets(|this| this.print_list_parts(term_id)), - Term::ListType { item_type } => self.print_parens(|this| { - this.print_text("list"); - this.print_term(*item_type) + Term::Tuple { .. } => self.print_parens(|this| { + this.print_text("tuple"); + this.print_tuple_parts(term_id) }), Term::Str(str) => { self.print_string(str); Ok(()) } - Term::StrType => { - self.print_text("str"); - Ok(()) - } Term::Nat(n) => { self.print_text(n.to_string()); Ok(()) } - Term::NatType => { - self.print_text("nat"); - Ok(()) - } Term::ExtSet { .. } => self.print_parens(|this| { this.print_text("ext"); this.print_ext_set_parts(term_id)?; Ok(()) }), - Term::ExtSetType => { - self.print_text("ext-set"); - Ok(()) - } - Term::Adt { variants } => self.print_parens(|this| { - this.print_text("adt"); - this.print_term(*variants) - }), - Term::FuncType { - inputs, - outputs, - extensions, - } => self.print_parens(|this| { - this.print_text("->"); - this.print_term(*inputs)?; - this.print_term(*outputs)?; - this.print_term(*extensions) - }), - Term::Control { values } => self.print_parens(|this| { - this.print_text("ctrl"); - this.print_term(*values) - }), - Term::ControlType => { - self.print_text("ctrl"); - Ok(()) - } - Term::NonLinearConstraint { term } => self.print_parens(|this| { - this.print_text("nonlinear"); - this.print_term(*term) - }), - Term::ConstFunc { region } => self.print_parens(|this| { + Term::ConstFunc(region) => self.print_parens(|this| { this.print_text("fn"); this.print_region(*region) }), - Term::ConstAdt { tag, values } => self.print_parens(|this| { - this.print_text("tag"); - this.print_text(tag.to_string()); - this.print_term(*values) - }), - Term::BytesType => { - self.print_text("bytes"); - Ok(()) - } - Term::Bytes { data } => self.print_parens(|this| { + Term::Bytes(data) => self.print_parens(|this| { this.print_text("bytes"); this.print_byte_string(data); Ok(()) }), - Term::Meta => { - self.print_text("meta"); - Ok(()) - } - Term::Float { value } => { + Term::Float(value) => { // The debug representation of a float always includes a decimal point. self.print_text(format!("{:?}", value.into_inner())); Ok(()) } - Term::FloatType => { - self.print_text("float"); - Ok(()) - } } } @@ -633,7 +459,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_term(term_id) .ok_or(PrintError::TermNotFound(term_id))?; - if let Term::List { parts } = term_data { + if let Term::List(parts) = term_data { for part in *parts { match part { ListPart::Item(term) => self.print_term(*term)?, @@ -648,6 +474,30 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } + /// Prints the contents of a tuple. + /// + /// This is used so that spliced tuples are merged into the parent tuple. + fn print_tuple_parts(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or(PrintError::TermNotFound(term_id))?; + + if let Term::Tuple(parts) = term_data { + for part in *parts { + match part { + TuplePart::Item(term) => self.print_term(*term)?, + TuplePart::Splice(list) => self.print_tuple_parts(*list)?, + } + } + } else { + self.print_term(term_id)?; + self.print_text("..."); + } + + Ok(()) + } + /// Prints the contents of an extension set. /// /// This is used so that spliced extension sets are merged into the parent extension set. @@ -657,7 +507,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_term(term_id) .ok_or(PrintError::TermNotFound(term_id))?; - if let Term::ExtSet { parts } = term_data { + if let Term::ExtSet(parts) = term_data { for part in *parts { match part { ExtSetPart::Extension(ext) => self.print_text(*ext), diff --git a/hugr-model/tests/fixtures/model-add.edn b/hugr-model/tests/fixtures/model-add.edn index ed8476ea9..3ecc4159c 100644 --- a/hugr-model/tests/fixtures/model-add.edn +++ b/hugr-model/tests/fixtures/model-add.edn @@ -1,13 +1,14 @@ (hugr 0) (define-func example.add - [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) + (core.fn + [arithmetic.int.types.int arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext)) (dfg - [%0 %1] - [%2] - (signature (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - ((@ arithmetic.int.iadd) [%0 %1] [%2] - (signature (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - ))) + [%0 %1] + [%2] + (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + (arithmetic.int.iadd + [%0 %1] [%2] + (signature (core.fn [arithmetic.int.types.int arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))))) diff --git a/hugr-model/tests/fixtures/model-alias.edn b/hugr-model/tests/fixtures/model-alias.edn index 2998410ad..673a20f89 100644 --- a/hugr-model/tests/fixtures/model-alias.edn +++ b/hugr-model/tests/fixtures/model-alias.edn @@ -1,7 +1,7 @@ (hugr 0) -(declare-alias local.float type) +(declare-alias local.float core.type) -(define-alias local.int type (@ arithmetic.int.types.int)) +(define-alias local.int core.type arithmetic.int.types.int) -(define-alias local.endo type (-> [] [] (ext))) +(define-alias local.endo core.type (core.fn [] [] (ext))) diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index c757658fc..96ff11164 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -1,24 +1,26 @@ (hugr 0) -(declare-func example.callee - (forall ?ext ext-set) - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int ?ext ...) - (meta (compat.meta-json "title" "\"Callee\"")) - (meta (compat.meta-json "description" "\"This is a function declaration.\""))) +(declare-func + example.callee + (param ?ext core.ext_set) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext arithmetic.int ?ext ...)) + (meta (compat.meta_json "title" "\"Callee\"")) + (meta (compat.meta_json "description" "\"This is a function declaration.\""))) (define-func example.caller - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int) - (meta (compat.meta-json "title" "\"Caller\"")) - (meta (compat.meta-json "description" "\"This defines a function that calls the function which we declared earlier.\"")) + (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext arithmetic.int)) + (meta (compat.meta_json "title" "\"Caller\"")) + (meta (compat.meta_json "description" "\"This defines a function that calls the function which we declared earlier.\"")) (dfg [%3] [%4] - (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - (call (@ example.callee (ext)) [%3] [%4] - (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))) + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + ((core.call _ _ _ (example.callee (ext))) [%3] [%4] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))))) -(define-func example.load - [] [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext) +(define-func + example.load + (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext arithmetic.int))] (ext)) (dfg - [] - [%5] - (signature (-> [] [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext))) - (load-func (@ example.caller) [] [%5]))) + [] + [%5] + (signature (core.fn [] [(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext arithmetic.int))] (ext))) + ((core.load_const _ _ example.caller) [] [%5]))) diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index 8c25ad91a..747e08943 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -1,17 +1,17 @@ (hugr 0) (define-func example.cfg - (forall ?a type) - [?a] [?a] (ext) + (param ?a core.type) + (core.fn [?a] [?a] (ext)) (dfg [%0] [%1] - (signature (-> [?a] [?a] (ext))) - (cfg [%0] [%1] - (signature (-> [?a] [?a] (ext))) - (cfg [%2] [%4] - (signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext))) - (block [%2] [%4 %2] - (signature (-> [(ctrl [?a])] [(ctrl [?a]) (ctrl [?a])] (ext))) - (dfg [%5] [%6] - (signature (-> [?a] [(adt [[?a] [?a]])] (ext))) - (tag 0 [%5] [%6] - (signature (-> [?a] [(adt [[?a] [?a]])] (ext)))))))))) + (signature (core.fn [?a] [?a] (ext))) + (cfg [%0] [%1] + (signature (core.fn [?a] [?a] (ext))) + (cfg [%2] [%4] + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a])] (ext))) + (block [%2] [%4 %2] + (signature (core.fn [(core.ctrl [?a])] [(core.ctrl [?a]) (core.ctrl [?a])] (ext))) + (dfg [%5] [%6] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])] (ext))) + ((core.make_adt _ _ 0) [%5] [%6] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])] (ext)))))))))) diff --git a/hugr-model/tests/fixtures/model-cond.edn b/hugr-model/tests/fixtures/model-cond.edn index d6b84d9fa..0a141612d 100644 --- a/hugr-model/tests/fixtures/model-cond.edn +++ b/hugr-model/tests/fixtures/model-cond.edn @@ -1,15 +1,15 @@ (hugr 0) (define-func example.cond - [(adt [[] []]) (@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext) + (core.fn [(core.adt [[] []]) arithmetic.int.types.int] + [arithmetic.int.types.int] + (ext)) (dfg [%0 %1] [%2] - (signature (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - (cond [%0 %1] [%2] - (signature (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - (dfg [%3] [%3] - (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))) - (dfg [%4] [%5] - (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) - ((@ arithmetic.int.ineg) [%4] [%5] - (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))))) + (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + (cond [%0 %1] [%2] + (signature (core.fn [(core.adt [[] []]) arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + (dfg [%3] [%3] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))) + (dfg [%4] [%5] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext))) + (arithmetic.int.ineg [%4] [%5] + (signature (core.fn [arithmetic.int.types.int] [arithmetic.int.types.int] (ext)))))))) diff --git a/hugr-model/tests/fixtures/model-const.edn b/hugr-model/tests/fixtures/model-const.edn index 61748ab38..13a161b05 100644 --- a/hugr-model/tests/fixtures/model-const.edn +++ b/hugr-model/tests/fixtures/model-const.edn @@ -1,65 +1,67 @@ (hugr 0) (define-func example.bools - [] - [(adt [[] []]) (adt [[] []])] - (ext) + (core.fn [] + [(core.adt [[] []]) (core.adt [[] []])] + (ext)) (dfg [] [%false %true] - (signature (-> [] [(adt [[] []]) (adt [[] []])] (ext))) - (const (tag 0 []) [] [%false] - (signature (-> [] [(adt [[] []])] (ext)))) - (const (tag 1 []) [] [%true] - (signature (-> [] [(adt [[] []])] (ext)))))) + (signature (core.fn [] [(core.adt [[] []]) (core.adt [[] []])] (ext))) + ((core.load_const _ _ (core.const.adt _ _ _ 0 (tuple))) [] [%false] + (signature (core.fn [] [(core.adt [[] []])] (ext)))) + ((core.load_const _ _ (core.const.adt _ _ _ 1 (tuple))) [] [%true] + (signature (core.fn [] [(core.adt [[] []])] (ext)))))) (define-func example.make-pair - [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] - (ext) + (core.fn [] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] + (ext)) (dfg + [] [%0] + (signature + (core.fn + [] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] + (ext))) + ((core.load_const _ _ + (core.const.adt + _ + _ + _ + 0 + (tuple (collections.array.const + 5 + (arithmetic.int.types.int 6) + [(arithmetic.int.const 6 1) + (arithmetic.int.const 6 2) + (arithmetic.int.const 6 3) + (arithmetic.int.const 6 4) + (arithmetic.int.const 6 5)]) + (arithmetic.float.const-f64 -3.0)))) [] [%0] (signature - (-> - [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] - (ext))) - (const - (tag - 0 - [(@ - collections.array.const - 5 - (@ arithmetic.int.types.int 6) - [(@ arithmetic.int.const 6 1) - (@ arithmetic.int.const 6 2) - (@ arithmetic.int.const 6 3) - (@ arithmetic.int.const 6 4) - (@ arithmetic.int.const 6 5)]) - (@ arithmetic.float.const-f64 -3.0)]) - [] [%0] - (signature - (-> - [] - [(adt - [[(@ collections.array.array 5 (@ arithmetic.int.types.int 6)) - (@ arithmetic.float.types.float64)]])] - (ext)))))) + (core.fn + [] + [(core.adt + [[(collections.array.array 5 (arithmetic.int.types.int 6)) + arithmetic.float.types.float64]])] + (ext)))))) (define-func example.f64-json - [] - [(@ arithmetic.float.types.float64)] - (ext) + (core.fn [] + [arithmetic.float.types.float64] + (ext)) (dfg [] [%0 %1] - (signature (-> [] [(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)] (ext))) - (const - (@ compat.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}" (ext)) - [] [%0] - (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))) + (signature (core.fn [] [arithmetic.float.types.float64 arithmetic.float.types.float64] (ext))) + ((core.load_const _ _ + (compat.const_json arithmetic.float.types.float64 (ext) "{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}")) + [] [%0] + (signature (core.fn [] [arithmetic.float.types.float64] (ext)))) ; The following const is to test that import/export can deal with unknown constants. - (const - (@ compat.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}" (ext)) - [] [%1] - (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))))) + ((core.load_const _ _ + (compat.const_json arithmetic.float.types.float64 (ext) "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}")) + [] [%1] + (signature (core.fn [] [arithmetic.float.types.float64] (ext)))))) diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn index c2d58346b..f1b3147ce 100644 --- a/hugr-model/tests/fixtures/model-constraints.edn +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -1,21 +1,25 @@ (hugr 0) (declare-func array.replicate - (forall ?n nat) - (forall ?t type) - (where (nonlinear ?t)) - [?t] [(@ collections.array.array ?n ?t)] - (ext)) + (param ?n core.nat) + (param ?t core.type) + (where (core.nonlinear ?t)) + (core.fn [?t] [(collections.array.array ?n ?t)] + (ext))) -(declare-func array.copy - (forall ?n nat) - (forall ?t type) - (where (nonlinear ?t)) - [(@ collections.array.array ?n ?t)] [(@ collections.array.array ?n ?t) (@ collections.array.array ?n ?t)] (ext)) +(declare-func + array.copy + (param ?n core.nat) + (param ?t core.type) + (where (core.nonlinear ?t)) + (core.fn + [(collections.array.array ?n ?t)] + [(collections.array.array ?n ?t) + (collections.array.array ?n ?t)] (ext))) (define-func util.copy - (forall ?t type) - (where (nonlinear ?t)) - [?t] [?t ?t] (ext) + (param ?t core.type) + (where (core.nonlinear ?t)) + (core.fn [?t] [?t ?t] (ext)) (dfg [%0] [%0 %0] - (signature (-> [?t] [?t ?t] (ext))))) + (signature (core.fn [?t] [?t ?t] (ext))))) diff --git a/hugr-model/tests/fixtures/model-decl-exts.edn b/hugr-model/tests/fixtures/model-decl-exts.edn index 253f480db..b2aebb33f 100644 --- a/hugr-model/tests/fixtures/model-decl-exts.edn +++ b/hugr-model/tests/fixtures/model-decl-exts.edn @@ -1,13 +1,13 @@ (hugr 0) (declare-ctr array.Array - (param ?t type) - (param ?n nat) - type - (meta (core.meta.description "Fixed size array."))) + (param ?t core.type) + (param ?n core.nat) + core.type + (meta (core.meta.description "Fixed size array."))) (declare-operation array.Init - (param ?t type) - (param ?n nat) - (-> [?t] [(array.Array ?t ?n)] (ext array)) - (meta (core.meta.description "Initialize an array of size ?n with copies of a default value."))) + (param ?t core.type) + (param ?n core.nat) + (core.fn [?t] [(array.Array ?t ?n)] (ext array)) + (meta (core.meta.description "Initialize an array of size ?n with copies of a default value."))) diff --git a/hugr-model/tests/fixtures/model-lists.edn b/hugr-model/tests/fixtures/model-lists.edn index db84ffe72..7e39e225a 100644 --- a/hugr-model/tests/fixtures/model-lists.edn +++ b/hugr-model/tests/fixtures/model-lists.edn @@ -1,21 +1,21 @@ (hugr 0) (declare-operation core.call-indirect - (forall ?inputs (list type)) - (forall ?outputs (list type)) - (forall ?exts ext-set) - (-> [(-> ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) + (param ?inputs (core.list core.type)) + (param ?outputs (core.list core.type)) + (param ?exts core.ext-set) + (core.fn [(core.fn ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) (declare-operation core.compose-parallel - (forall ?inputs-0 (list type)) - (forall ?inputs-1 (list type)) - (forall ?outputs-0 (list type)) - (forall ?outputs-1 (list type)) - (forall ?exts ext-set) - (-> - [(-> ?inputs-0 ?outputs-0 ?exts) - (-> ?inputs-1 ?outputs-1 ?exts) - ?inputs-0 ... - ?inputs-1 ...] - [?outputs-0 ... ?outputs-1 ...] - ?exts)) + (param ?inputs-0 (core.list core.type)) + (param ?inputs-1 (core.list core.type)) + (param ?outputs-0 (core.list core.type)) + (param ?outputs-1 (core.list core.type)) + (param ?exts core.ext-set) + (core.fn + [(core.fn ?inputs-0 ?outputs-0 ?exts) + (core.fn ?inputs-1 ?outputs-1 ?exts) + ?inputs-0 ... + ?inputs-1 ...] + [?outputs-0 ... ?outputs-1 ...] + ?exts)) diff --git a/hugr-model/tests/fixtures/model-literals.edn b/hugr-model/tests/fixtures/model-literals.edn index 7e961d930..4fff1e399 100644 --- a/hugr-model/tests/fixtures/model-literals.edn +++ b/hugr-model/tests/fixtures/model-literals.edn @@ -1,5 +1,5 @@ (hugr 0) -(define-alias mod.string str "\"\n\r\t\\\u{1F44D}") -(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig==")) -(define-alias mod.float float -3.141) +(define-alias mod.string core.str "\"\n\r\t\\\u{1F44D}") +(define-alias mod.bytes core.bytes (bytes "SGVsbG8gd29ybGQg8J+Yig==")) +(define-alias mod.float core.float -3.141) diff --git a/hugr-model/tests/fixtures/model-loop.edn b/hugr-model/tests/fixtures/model-loop.edn index f2c49d9d6..29ee5870c 100644 --- a/hugr-model/tests/fixtures/model-loop.edn +++ b/hugr-model/tests/fixtures/model-loop.edn @@ -1,13 +1,13 @@ (hugr 0) (define-func example.loop - (forall ?a type) - [?a] [?a] (ext) + (param ?a core.type) + (core.fn [?a] [?a] (ext)) (dfg [%0] [%1] - (signature (-> [?a] [?a] (ext))) - (tail-loop [%0] [%1] - (signature (-> [?a] [?a] (ext))) - (dfg [%2] [%3] - (signature (-> [?a] [(adt [[?a] [?a]])] (ext))) - (tag 0 [%2] [%3] - (signature (-> [?a] [(adt [[?a] [?a]])] (ext)))))))) + (signature (core.fn [?a] [?a] (ext))) + (tail-loop [%0] [%1] + (signature (core.fn [?a] [?a] (ext))) + (dfg [%2] [%3] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])] (ext))) + ((core.make_adt _ _ 0) [%2] [%3] + (signature (core.fn [?a] [(core.adt [[?a] [?a]])] (ext)))))))) diff --git a/hugr-model/tests/fixtures/model-params.edn b/hugr-model/tests/fixtures/model-params.edn index 6f8554745..7da00ac01 100644 --- a/hugr-model/tests/fixtures/model-params.edn +++ b/hugr-model/tests/fixtures/model-params.edn @@ -2,8 +2,8 @@ (define-func example.swap ; The types of the values to be swapped are passed as implicit parameters. - (forall ?a type) - (forall ?b type) - [?a ?b] [?b ?a] (ext) + (param ?a core.type) + (param ?b core.type) + (core.fn [?a ?b] [?b ?a] (ext)) (dfg [%a %b] [%b %a] - (signature (-> [?a ?b] [?b ?a] (ext))))) + (signature (core.fn [?a ?b] [?b ?a] (ext))))) diff --git a/hugr-model/tests/snapshots/text__declarative_extensions.snap b/hugr-model/tests/snapshots/text__declarative_extensions.snap index 40f7cbeb1..7aaef20d4 100644 --- a/hugr-model/tests/snapshots/text__declarative_extensions.snap +++ b/hugr-model/tests/snapshots/text__declarative_extensions.snap @@ -5,17 +5,15 @@ expression: "roundtrip(include_str!(\"fixtures/model-decl-exts.edn\"))" (hugr 0) (declare-ctr array.Array - (param ?t type) - (param ?n nat) - type + (param ?t core.type) + (param ?n core.nat) + core.type (meta (core.meta.description "Fixed size array."))) (declare-operation array.Init - (param ?t type) - (param ?n nat) - (-> [?t] [(array.Array ?t ?n)] (ext array)) + (param ?t core.type) + (param ?n core.nat) + (core.fn [?t] [(array.Array ?t ?n)] (ext array)) (meta (core.meta.description "Initialize an array of size ?n with copies of a default value."))) - -(import core.meta.description) diff --git a/hugr-model/tests/snapshots/text__literals.snap b/hugr-model/tests/snapshots/text__literals.snap index 3aae5ad2e..66c5e8975 100644 --- a/hugr-model/tests/snapshots/text__literals.snap +++ b/hugr-model/tests/snapshots/text__literals.snap @@ -4,8 +4,8 @@ expression: "roundtrip(include_str!(\"fixtures/model-literals.edn\"))" --- (hugr 0) -(define-alias mod.string str "\"\n\r\t\\👍") +(define-alias mod.string core.str "\"\n\r\t\\👍") -(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig==")) +(define-alias mod.bytes core.bytes (bytes "SGVsbG8gd29ybGQg8J+Yig==")) -(define-alias mod.float float -3.141) +(define-alias mod.float core.float -3.141) diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 926a97205..6e5287e10 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -8,9 +8,9 @@ use hugr::builder::{ }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::float64_type; +use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; use hugr::types::Signature; -use hugr::{type_row, Extension, Hugr, Node}; +use hugr::{type_row, CircuitUnit, Extension, Hugr, Node}; use lazy_static::lazy_static; pub fn simple_dfg_hugr() -> Hugr { @@ -95,9 +95,7 @@ pub struct CircuitLayer { pub fn circuit(layers: usize) -> (Hugr, Vec) { let h_gate = QUANTUM_EXT.instantiate_extension_op("H", []).unwrap(); let cx_gate = QUANTUM_EXT.instantiate_extension_op("CX", []).unwrap(); - // let rz = QUANTUM_EXT - // .instantiate_extension_op("Rz", []) - // .unwrap(); + let rz = QUANTUM_EXT.instantiate_extension_op("Rz", []).unwrap(); let signature = Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(QUANTUM_EXT.name().clone()); let mut module_builder = ModuleBuilder::new(); @@ -116,14 +114,13 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { linear.append(cx_gate.clone(), [1, 0]).unwrap(); let cx2 = linear.tracked_wire(0).unwrap().node(); - // TODO: Currently left out because we can not represent constants in the model - // let angle = linear.add_constant(ConstF64::new(0.5)); - // linear - // .append_and_consume( - // rz.clone(), - // [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], - // ) - // .unwrap(); + let angle = linear.add_constant(ConstF64::new(0.5)); + linear + .append_and_consume( + rz.clone(), + [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], + ) + .unwrap(); layer_ids.push(CircuitLayer { h, cx1, cx2 }); }