Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improved representation for metadata in hugr-model #1849

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ use std::fmt::Write;

pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect";
const TERM_PARAM_TUPLE: &str = "param.tuple";
const TERM_JSON: &str = "prelude.json";
const META_DESCRIPTION: &str = "docs.description";
const TERM_JSON_CONST: &str = "prelude.const-json";

/// Export a [`Hugr`] graph to its representation in the model.
pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> {
Expand Down Expand Up @@ -558,15 +555,12 @@ impl<'a> Context<'a> {
let mut meta = BumpVec::with_capacity_in(meta_len, self.bump);

if let Some(description) = description {
let name = META_DESCRIPTION;
let value = self.make_term(model::Term::Str(self.bump.alloc_str(description)));
meta.push(model::MetaItem { name, value })
meta.push(self.make_term_apply(model::CORE_META_DESCRIPTION, &[value]));
}

for (name, value) in opdef.iter_misc() {
let name = self.bump.alloc_str(name);
let value = self.export_json_meta(value);
meta.push(model::MetaItem { name, value });
meta.push(self.export_json_meta(name, value));
}

self.bump.alloc_slice_copy(&meta)
Expand Down Expand Up @@ -1036,7 +1030,7 @@ impl<'a> Context<'a> {
let args = self
.bump
.alloc_slice_copy(&[runtime_type, json, extensions]);
let symbol = self.resolve_symbol(TERM_JSON_CONST);
let symbol = self.resolve_symbol(model::COMPAT_CONST_JSON);
self.make_term(model::Term::ApplyFull { symbol, args })
}

Expand Down Expand Up @@ -1075,30 +1069,21 @@ impl<'a> Context<'a> {
}
}

pub fn export_node_metadata(
&mut self,
metadata_map: &NodeMetadataMap,
) -> &'a [model::MetaItem<'a>] {
pub fn export_node_metadata(&mut self, metadata_map: &NodeMetadataMap) -> &'a [model::TermId] {
let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump);

for (name, value) in metadata_map {
let name = self.bump.alloc_str(name);
let value = self.export_json_meta(value);
meta.push(model::MetaItem { name, value });
meta.push(self.export_json_meta(name, value));
}

meta.into_bump_slice()
}

pub fn export_json_meta(&mut self, value: &serde_json::Value) -> model::TermId {
pub fn export_json_meta(&mut self, name: &str, value: &serde_json::Value) -> model::TermId {
let value = serde_json::to_string(value).expect("json values are always serializable");
let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value)));
let value = self.bump.alloc_slice_copy(&[value]);
let symbol = self.resolve_symbol(TERM_JSON);
self.make_term(model::Term::ApplyFull {
symbol,
args: value,
})
let name = self.make_term(model::Term::Str(self.bump.alloc_str(name)));
self.make_term_apply(model::COMPAT_META_JSON, &[name, value])
}

fn resolve_symbol(&mut self, name: &'a str) -> model::NodeId {
Expand All @@ -1114,6 +1099,12 @@ impl<'a> Context<'a> {
}),
}
}

fn make_term_apply(&mut self, name: &'a str, args: &[model::TermId]) -> model::TermId {
let symbol = self.resolve_symbol(name);
let args = self.bump.alloc_slice_copy(args);
self.make_term(model::Term::ApplyFull { symbol, args })
}
}

#[cfg(test)]
Expand Down
28 changes: 17 additions & 11 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ use itertools::Either;
use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;

const TERM_JSON: &str = "prelude.json";
const TERM_JSON_CONST: &str = "prelude.const-json";

/// Error during import.
#[derive(Debug, Clone, Error)]
pub enum ImportError {
Expand Down Expand Up @@ -174,8 +171,8 @@ impl<'a> Context<'a> {
for meta_item in node_data.meta {
// TODO: For now we expect all metadata to be JSON since this is how
// it is handled in `hugr-core`.
let value = self.import_json_meta(meta_item.value)?;
self.hugr.set_metadata(node, meta_item.name, value);
let (name, value) = self.import_json_meta(*meta_item)?;
self.hugr.set_metadata(node, name, value);
}

Ok(node)
Expand Down Expand Up @@ -949,6 +946,7 @@ impl<'a> Context<'a> {
| model::Term::NonLinearConstraint { .. }
| model::Term::ConstFunc { .. }
| model::Term::Bytes { .. }
| model::Term::Meta
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),

model::Term::ControlType => {
Expand Down Expand Up @@ -1016,7 +1014,9 @@ impl<'a> Context<'a> {
Ok(TypeArg::Type { ty })
}

model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } => {
model::Term::Control { .. }
| model::Term::Meta
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
Expand Down Expand Up @@ -1137,6 +1137,7 @@ impl<'a> Context<'a> {
| model::Term::Bytes { .. }
| model::Term::BytesType
| model::Term::ConstFunc { .. }
| model::Term::Meta
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),
}
}
Expand Down Expand Up @@ -1275,7 +1276,7 @@ impl<'a> Context<'a> {
fn import_json_meta(
&mut self,
term_id: model::TermId,
) -> Result<serde_json::Value, ImportError> {
) -> Result<(&'a str, serde_json::Value), ImportError> {
let (global, args) = match self.get_term(term_id)? {
model::Term::Apply { symbol, args } | model::Term::ApplyFull { symbol, args } => {
(symbol, args)
Expand All @@ -1284,11 +1285,15 @@ impl<'a> Context<'a> {
};

let global = self.get_symbol_name(*global)?;
if global != TERM_JSON {
if global != model::COMPAT_META_JSON {
return Err(model::ModelError::TypeError(term_id).into());
}

let [json_arg] = args else {
let [name_arg, json_arg] = args else {
return Err(model::ModelError::TypeError(term_id).into());
};

let model::Term::Str(name) = self.get_term(*name_arg)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

Expand All @@ -1299,7 +1304,7 @@ impl<'a> Context<'a> {
let json_value =
serde_json::from_str(json_str).map_err(|_| model::ModelError::TypeError(term_id))?;

Ok(json_value)
Ok((name, json_value))
}

fn import_value(
Expand All @@ -1319,7 +1324,7 @@ impl<'a> Context<'a> {
model::Term::ApplyFull { symbol, args } => {
let symbol_name = self.get_symbol_name(*symbol)?;

if symbol_name == TERM_JSON_CONST {
if symbol_name == model::COMPAT_CONST_JSON {
let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?;

let model::Term::Str(json) = self.get_term(*value)? else {
Expand Down Expand Up @@ -1375,6 +1380,7 @@ impl<'a> Context<'a> {
| model::Term::Type
| model::Term::Bytes { .. }
| model::Term::BytesType
| model::Term::Meta
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
Expand Down
14 changes: 8 additions & 6 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call
---
(hugr 0)

(import prelude.json)
(import compat.meta-json)

(import arithmetic.int.types.int)

Expand All @@ -13,18 +13,20 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext ?0 ... arithmetic.int)
(meta doc.description (@ prelude.json "\"This is a function declaration.\""))
(meta doc.title (@ prelude.json "\"Callee\"")))
(meta
(@ compat.meta-json "description" "\"This is a function declaration.\""))
(meta (@ compat.meta-json "title" "\"Callee\"")))

(define-func example.caller
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)
(meta doc.description
(meta
(@
prelude.json
compat.meta-json
"description"
"\"This defines a function that calls the function which we declared earlier.\""))
(meta doc.title (@ prelude.json "\"Caller\""))
(meta (@ compat.meta-json "title" "\"Caller\""))
(dfg
[%0] [%1]
(signature
Expand Down
10 changes: 5 additions & 5 deletions hugr-core/tests/snapshots/model__roundtrip_const.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
---
(hugr 0)

(import prelude.const-json)
(import compat.const-json)

(import arithmetic.float.types.float64)

Expand Down Expand Up @@ -34,12 +34,12 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
(tag
0
[(@
prelude.const-json
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":2.0}}"
(ext arithmetic.float.types))
(@
prelude.const-json
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":3.0}}"
(ext arithmetic.float.types))])
Expand All @@ -63,15 +63,15 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
(ext)))
(const
(@
prelude.const-json
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}"
(ext arithmetic.float.types))
[] [%0]
(signature (-> [] [(@ arithmetic.float.types.float64)] (ext))))
(const
(@
prelude.const-json
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}"
(ext))
Expand Down
10 changes: 3 additions & 7 deletions hugr-model/capnp/hugr-v0.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct Node {
outputs @2 :List(LinkIndex);
params @3 :List(TermId);
regions @4 :List(RegionId);
meta @5 :List(MetaItem);
meta @5 :List(TermId);
signature @6 :OptionalTermId;
}

Expand Down Expand Up @@ -105,7 +105,7 @@ struct Region {
sources @1 :List(LinkIndex);
targets @2 :List(LinkIndex);
children @3 :List(NodeId);
meta @4 :List(MetaItem);
meta @4 :List(TermId);
signature @5 :OptionalTermId;
scope @6 :RegionScope;
}
Expand All @@ -124,11 +124,6 @@ enum RegionKind {
module @2;
}

struct MetaItem {
name @0 :Text;
value @1 :UInt32;
}

struct Term {
union {
wildcard @0 :Void;
Expand Down Expand Up @@ -159,6 +154,7 @@ struct Term {
constAdt @23 :ConstAdt;
bytes @24 :Data;
bytesType @25 :Void;
meta @26 :Void;
}

struct Apply {
Expand Down
14 changes: 3 additions & 11 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult
let outputs = read_scalar_list!(bump, reader, get_outputs, model::LinkIndex);
let params = read_scalar_list!(bump, reader, get_params, model::TermId);
let regions = read_scalar_list!(bump, reader, get_regions, model::RegionId);
let meta = read_list!(bump, reader, get_meta, read_meta_item);
let meta = read_scalar_list!(bump, reader, get_meta, model::TermId);
let signature = reader.get_signature().checked_sub(1).map(model::TermId);

Ok(model::Node {
Expand Down Expand Up @@ -217,7 +217,7 @@ fn read_region<'a>(
let sources = read_scalar_list!(bump, reader, get_sources, model::LinkIndex);
let targets = read_scalar_list!(bump, reader, get_targets, model::LinkIndex);
let children = read_scalar_list!(bump, reader, get_children, model::NodeId);
let meta = read_list!(bump, reader, get_meta, read_meta_item);
let meta = read_scalar_list!(bump, reader, get_meta, model::TermId);
let signature = reader.get_signature().checked_sub(1).map(model::TermId);

let scope = if reader.has_scope() {
Expand Down Expand Up @@ -256,6 +256,7 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
Which::NatType(()) => model::Term::NatType,
Which::ExtSetType(()) => model::Term::ExtSetType,
Which::ControlType(()) => model::Term::ControlType,
Which::Meta(()) => model::Term::Meta,

Which::Variable(reader) => {
let node = model::NodeId(reader.get_variable_node());
Expand Down Expand Up @@ -343,15 +344,6 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
})
}

fn read_meta_item<'a>(
bump: &'a Bump,
reader: hugr_capnp::meta_item::Reader,
) -> ReadResult<model::MetaItem<'a>> {
let name = bump.alloc_str(reader.get_name()?.to_str()?);
let value = model::TermId(reader.get_value());
Ok(model::MetaItem { name, value })
}

fn read_list_part(
_: &Bump,
reader: hugr_capnp::term::list_part::Reader,
Expand Down
13 changes: 6 additions & 7 deletions hugr-model/src/v0/binary/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn write_node(mut builder: hugr_capnp::node::Builder, node: &model::Node) {
write_operation(builder.reborrow().init_operation(), &node.operation);
let _ = builder.set_inputs(model::LinkIndex::unwrap_slice(node.inputs));
let _ = builder.set_outputs(model::LinkIndex::unwrap_slice(node.outputs));
write_list!(builder, init_meta, write_meta_item, node.meta);
let _ = builder.set_meta(model::TermId::unwrap_slice(node.meta));
let _ = builder.set_params(model::TermId::unwrap_slice(node.params));
let _ = builder.set_regions(model::RegionId::unwrap_slice(node.regions));
builder.set_signature(node.signature.map_or(0, |t| t.0 + 1));
Expand Down Expand Up @@ -117,11 +117,6 @@ fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) {
});
}

fn write_meta_item(mut builder: hugr_capnp::meta_item::Builder, meta_item: &model::MetaItem) {
builder.set_name(meta_item.name);
builder.set_value(meta_item.value.0)
}

fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region) {
builder.set_kind(match region.kind {
model::RegionKind::DataFlow => hugr_capnp::RegionKind::DataFlow,
Expand All @@ -132,7 +127,7 @@ fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region
let _ = builder.set_sources(model::LinkIndex::unwrap_slice(region.sources));
let _ = builder.set_targets(model::LinkIndex::unwrap_slice(region.targets));
let _ = builder.set_children(model::NodeId::unwrap_slice(region.children));
write_list!(builder, init_meta, write_meta_item, region.meta);
let _ = builder.set_meta(model::TermId::unwrap_slice(region.meta));
builder.set_signature(region.signature.map_or(0, |t| t.0 + 1));

if let Some(scope) = &region.scope {
Expand Down Expand Up @@ -225,6 +220,10 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) {
model::Term::BytesType => {
builder.set_bytes_type(());
}

model::Term::Meta => {
builder.set_meta(());
}
}
}

Expand Down
Loading
Loading