Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Constant values in hugr-model #1838

Merged
merged 8 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 162 additions & 139 deletions hugr-core/src/export.rs

Large diffs are not rendered by default.

190 changes: 164 additions & 26 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ use crate::{
extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError},
hugr::{HugrMut, IdentList},
ops::{
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, DataflowBlock, ExitBlock,
FuncDecl, FuncDefn, Input, LoadFunction, Module, OpType, OpaqueOp, Output, Tag, TailLoop,
CFG, DFG,
constant::{CustomConst, CustomSerialized, OpaqueValue},
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock,
ExitBlock, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, Module, OpType, OpaqueOp,
Output, Tag, TailLoop, Value, CFG, DFG,
},
types::{
type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV,
Expand All @@ -28,6 +29,7 @@ use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;

const TERM_JSON: &str = "prelude.json";
const TERM_JSON_CONST: &str = "prelude.const-json";

/// Error during import.
#[derive(Debug, Clone, Error)]
Expand Down Expand Up @@ -172,7 +174,7 @@ impl<'a> Context<'a> {
for meta_item in node_data.meta {
// TODO: For now we expect all metadata to be JSON since this is how
// it is handled in `hugr-core`.
let value = self.import_json_value(meta_item.value)?;
let value = self.import_json_meta(meta_item.value)?;
self.hugr.set_metadata(node, meta_item.name, value);
}

Expand Down Expand Up @@ -442,12 +444,6 @@ impl<'a> Context<'a> {

let node = self.make_node(node_id, optype, parent)?;

match node_data.regions {
[] => {}
[region] => self.import_dfg_region(node_id, *region, node)?,
_ => return Err(error_unsupported!("multiple regions in custom operation")),
}

Ok(Some(node))
}

Expand Down Expand Up @@ -508,6 +504,36 @@ impl<'a> Context<'a> {

model::Operation::DeclareConstructor { .. } => Ok(None),
model::Operation::DeclareOperation { .. } => Ok(None),

model::Operation::Const { value } => {
let signature = node_data
.signature
.ok_or_else(|| error_uninferred!("node signature"))?;
let (_, outputs, _) = self.get_func_type(signature)?;
let outputs = self.import_closed_list(outputs)?;
let output = outputs
.first()
.ok_or(model::ModelError::TypeError(signature))?;
let datatype = self.import_type(*output)?;

let imported_value = self.import_value(value, *output)?;

let load_const_node = self.make_node(
node_id,
OpType::LoadConstant(LoadConstant {
datatype: datatype.clone(),
}),
parent,
)?;

let const_node = self
.hugr
.add_node_with_parent(parent, OpType::Const(Const::new(imported_value)));

self.hugr.connect(const_node, 0, load_const_node, 0);

Ok(Some(load_const_node))
}
}
}

Expand Down Expand Up @@ -897,7 +923,7 @@ impl<'a> Context<'a> {
model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")),
model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")),

model::Term::Quote { .. } => Err(error_unsupported!("`(quote ...)` as `TypeParam`")),
model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")),
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

model::Term::ListType { item_type } => {
Expand All @@ -918,9 +944,9 @@ impl<'a> Context<'a> {
| model::Term::ExtSet { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
| model::Term::NonLinearConstraint { .. }
| model::Term::ConstFunc { .. }
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),

model::Term::ControlType => {
Err(error_unsupported!("type of control types as `TypeParam`"))
Expand Down Expand Up @@ -959,9 +985,6 @@ impl<'a> Context<'a> {
arg: value.to_string(),
}),

model::Term::Quote { .. } => Ok(TypeArg::Type {
ty: self.import_type(term_id)?,
}),
model::Term::Nat(value) => Ok(TypeArg::BoundedNat { n: *value }),
model::Term::ExtSet { .. } => Ok(TypeArg::Extensions {
es: self.import_extension_set(term_id)?,
Expand All @@ -976,6 +999,11 @@ impl<'a> Context<'a> {
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")),
model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")),
model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")),
model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")),
model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")),
model::Term::ConstFunc { .. } => {
Err(error_unsupported!("function constant as `TypeArg`"))
}

model::Term::FuncType { .. }
| model::Term::Adt { .. }
Expand Down Expand Up @@ -1045,12 +1073,12 @@ impl<'a> Context<'a> {
let (extension, id) = self.import_custom_name(name)?;

let extension_ref =
self.extensions.get(&extension.to_string()).ok_or_else(|| {
ImportError::Extension {
self.extensions
.get(&extension)
.ok_or_else(|| ImportError::Extension {
missing_ext: extension.clone(),
available: self.extensions.ids().cloned().collect(),
}
})?;
})?;

Ok(TypeBase::new_extension(CustomType::new(
id,
Expand Down Expand Up @@ -1090,16 +1118,16 @@ impl<'a> Context<'a> {
| model::Term::StaticType
| model::Term::Type
| model::Term::Constraint
| model::Term::Quote { .. }
| model::Term::Const { .. }
| model::Term::Str(_)
| model::Term::ExtSet { .. }
| model::Term::List { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Nat(_)
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
| model::Term::NonLinearConstraint { .. }
| model::Term::ConstFunc { .. }
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),
}
}

Expand Down Expand Up @@ -1234,7 +1262,7 @@ impl<'a> Context<'a> {
}
}

fn import_json_value(
fn import_json_meta(
&mut self,
term_id: model::TermId,
) -> Result<serde_json::Value, ImportError> {
Expand Down Expand Up @@ -1263,6 +1291,116 @@ impl<'a> Context<'a> {

Ok(json_value)
}

fn import_value(
&mut self,
term_id: model::TermId,
type_id: model::TermId,
) -> Result<Value, ImportError> {
let term_data = self.get_term(term_id)?;

match term_data {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
model::Term::Apply { .. } => {
Err(error_uninferred!("application with implicit parameters"))
}
model::Term::Var(_) => Err(error_unsupported!("constant value containing a variable")),

model::Term::ApplyFull { symbol, args } => {
let symbol_name = self.get_symbol_name(*symbol)?;

if symbol_name == TERM_JSON_CONST {
let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?;

let model::Term::Str(json) = self.get_term(*value)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

// We attempt to deserialize as the custom const directly.
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
// This might fail due to the custom const struct not being included when
// this code was compiled; in that case, we fall back to the serialized form.
let value: Option<Box<dyn CustomConst>> = serde_json::from_str(json).ok();

if let Some(value) = value {
let opaque_value = OpaqueValue::from(value);
return Ok(Value::Extension { e: opaque_value });
} else {
let runtime_type =
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
args.first().ok_or(model::ModelError::TypeError(term_id))?;
let runtime_type = self.import_type(*runtime_type)?;

let extensions =
args.get(2).ok_or(model::ModelError::TypeError(term_id))?;
let extensions = self.import_extension_set(*extensions)?;

let value: serde_json::Value = serde_json::from_str(json)
.map_err(|_| model::ModelError::TypeError(term_id))?;
let custom_const = CustomSerialized::new(runtime_type, value, extensions);
let opaque_value = OpaqueValue::new(custom_const);
return Ok(Value::Extension { e: opaque_value });
}
}

Err(error_unsupported!("constant value that is not JSON data"))
// TODO: This should ultimately include the following cases:
// - function definitions
// - custom constructors for values
}

model::Term::StaticType
| model::Term::Constraint
| model::Term::Const { .. }
| model::Term::List { .. }
| model::Term::ListType { .. }
| model::Term::Str(_)
| model::Term::StrType
| model::Term::Nat(_)
| model::Term::NatType
| model::Term::ExtSet { .. }
| model::Term::ExtSetType
| model::Term::Adt { .. }
| model::Term::FuncType { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Type
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}

model::Term::ConstFunc { .. } => Err(error_unsupported!("constant function value")),

model::Term::ConstAdt { tag, values } => {
let model::Term::Adt { variants } = self.get_term(type_id)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

let values = self.import_closed_list(*values)?;
let variants = self.import_closed_list(*variants)?;

let variant = variants
.get(*tag as usize)
.ok_or(model::ModelError::TypeError(term_id))?;
let variant = self.import_closed_list(*variant)?;

let items = values
.iter()
.zip(variant.iter())
.map(|(value, typ)| self.import_value(*value, *typ))
.collect::<Result<Vec<_>, _>>()?;

let typ = {
// TODO: Import as a `SumType` directly and avoid the copy.
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
let typ: Type = self.import_type(type_id)?;
match typ.as_type_enum() {
TypeEnum::Sum(sum) => sum.clone(),
_ => unreachable!(),
}
};

Ok(Value::sum(*tag as _, items, typ).unwrap())
}
}
}
}

/// Information about a local variable.
Expand Down
6 changes: 6 additions & 0 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ impl<CC: CustomConst> From<CC> for OpaqueValue {
}
}

impl From<Box<dyn CustomConst>> for OpaqueValue {
fn from(value: Box<dyn CustomConst>) -> Self {
Self { v: value }
}
}

impl PartialEq for OpaqueValue {
fn eq(&self, other: &Self) -> bool {
self.value().equal_consts(other.value())
Expand Down
7 changes: 7 additions & 0 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ pub fn test_roundtrip_constraints() {
"../../hugr-model/tests/fixtures/model-constraints.edn"
)));
}

#[test]
pub fn test_roundtrip_const() {
insta::assert_snapshot!(roundtrip(include_str!(
"../../hugr-model/tests/fixtures/model-const.edn"
)));
}
4 changes: 2 additions & 2 deletions hugr-core/tests/snapshots/model__roundtrip_add.snap
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.
(dfg
[%0 %1] [%2]
(signature
(fn
(->
[(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext)))
((@ arithmetic.int.iadd) [%0 %1] [%2]
(signature
(fn
(->
[(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))))))
2 changes: 1 addition & 1 deletion hugr-core/tests/snapshots/model__roundtrip_alias.snap
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alia

(define-alias local.int type (@ arithmetic.int.types.int))

(define-alias local.endo type (fn [] [] (ext)))
(define-alias local.endo type (-> [] [] (ext)))
17 changes: 9 additions & 8 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,39 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call
(dfg
[%0] [%1]
(signature
(fn
(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)))
(call (@ example.callee (ext)) [%0] [%1]
(signature
(fn
(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))))))

(define-func example.load
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
(ext)
(dfg
[] [%0]
(signature
(fn
(->
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
(ext)))
(load-func (@ example.caller)
(load-func (@ example.caller) [] [%0]
(signature
(fn
(->
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
Expand Down
Loading
Loading