Skip to content

Commit

Permalink
Better design for group parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
arnodb committed Feb 11, 2025
1 parent bee8b90 commit 7b52759
Show file tree
Hide file tree
Showing 21 changed files with 88 additions and 98 deletions.
2 changes: 1 addition & 1 deletion examples/quirky_binder_example_index_first_char/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ use super::tokenize;
)
- tokenize#tokenize()
- sort#sort(fields: ["first_char", "word"])
- group#group(fields: ["word"], group_field: "words")
- group#group(by_fields: ["first_char"], group_field: "words")
- function_terminate#term(
body: r#"
use itertools::Itertools;
Expand Down
2 changes: 1 addition & 1 deletion examples/quirky_binder_example_tree/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ use quirky_binder::{
( < extracted
- unwrap(fields: ["parent_id"], skip_nones: true)
- sort(fields: ["parent_id", "id"])
- group(fields: ["id"], group_field: "children")
- group(by_fields: ["parent_id"], group_field: "children")
-> children
)
}
Expand Down
114 changes: 52 additions & 62 deletions quirky_binder/src/filter/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const GROUP_TRACE_NAME: &str = "group";
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GroupParams<'a> {
fields: FieldsParam<'a>,
by_fields: FieldsParam<'a>,
group_field: &'a str,
}

Expand All @@ -26,9 +26,9 @@ pub struct Group {
inputs: [NodeStream; 1],
#[getset(get = "pub")]
outputs: [NodeStream; 1],
by_fields: Vec<ValidFieldName>,
group_field: ValidFieldName,
group_stream: NodeSubStream,
fields: Vec<ValidFieldName>,
}

impl Group {
Expand All @@ -38,9 +38,9 @@ impl Group {
inputs: [NodeStream; 1],
params: GroupParams,
) -> ChainResultWithTrace<Group> {
let valid_fields =
let valid_by_fields =
params
.fields
.by_fields
.validate_on_stream(inputs.single(), graph, GROUP_TRACE_NAME)?;
let valid_group_field = ValidFieldName::try_from(params.group_field)
.map_err(|_| ChainError::InvalidFieldName {
Expand All @@ -67,13 +67,14 @@ impl Group {
let mut group_stream_def = group_stream.record_definition().borrow_mut();

let variant = &output_stream_def[variant_id];
let mut group_by_datum_ids =
Vec::with_capacity(variant.data_len() - valid_fields.len());
let mut group_data = Vec::with_capacity(valid_fields.len());
let mut group_datum_ids = Vec::with_capacity(valid_fields.len());
let mut group_by_datum_ids = Vec::with_capacity(valid_by_fields.len());
let mut group_data =
Vec::with_capacity(variant.data_len() - valid_by_fields.len());
let mut group_datum_ids =
Vec::with_capacity(variant.data_len() - valid_by_fields.len());
for datum_id in variant.data() {
let datum = &output_stream_def[datum_id];
if valid_fields
if !valid_by_fields
.iter()
.any(|field| field.name() == datum.name())
{
Expand Down Expand Up @@ -152,9 +153,9 @@ impl Group {
name: name.clone(),
inputs,
outputs,
by_fields: valid_by_fields,
group_field: valid_group_field,
group_stream,
fields: valid_fields,
})
}
}
Expand Down Expand Up @@ -184,37 +185,29 @@ impl DynNode for Group {
let group_unpacked_record = def_group.unpacked_record();

let fields = {
let names = self
.fields
.iter()
.map(ValidFieldName::ident)
.collect::<Vec<_>>();
quote!(#(#names),*)
let record_definition = &graph.record_definitions()[self.inputs.single().record_type()];
let variant = &record_definition[self.inputs.single().variant_id()];
let idents = variant.data().filter_map(|d| {
let datum = &record_definition[d];
if !self
.by_fields
.iter()
.any(|field| field.name() == datum.name())
{
Some(format_ident!("{}", datum.name()))
} else {
None
}
});
quote!(#(#idents),*)
};

let group_field = self.group_field.ident();
let mut_group_field = self.group_field.mut_ident();

let record_definition = &graph.record_definitions()[self.inputs.single().record_type()];
let variant = &record_definition[self.inputs.single().variant_id()];
let eq = {
let fields = variant
.data()
.filter_map(|d| {
let datum = &record_definition[d];
if !self.fields.iter().any(|field| field.name() == datum.name()) {
Some(datum.name())
} else {
None
}
})
.collect::<Vec<_>>();
fields_eq_ab(
&def.record(),
fields.iter(),
&def_input.record(),
fields.iter(),
)
let fields = self.by_fields.iter().map(ValidFieldName::name);
fields_eq_ab(&def.record(), fields.clone(), &def_input.record(), fields)
};

let rec_ident = self.identifier_for("rec");
Expand Down Expand Up @@ -267,7 +260,7 @@ const SUB_GROUP_TRACE_NAME: &str = "sub_group";
#[serde(deny_unknown_fields)]
pub struct SubGroupParams<'a> {
path_fields: FieldsParam<'a>,
fields: FieldsParam<'a>,
by_fields: FieldsParam<'a>,
group_field: &'a str,
}

Expand All @@ -279,9 +272,9 @@ pub struct SubGroup {
#[getset(get = "pub")]
outputs: [NodeStream; 1],
path_streams: Vec<PathUpdateElement>,
by_fields: Vec<ValidFieldName>,
group_field: ValidFieldName,
group_stream: NodeSubStream,
fields: Vec<ValidFieldName>,
}

impl SubGroup {
Expand All @@ -300,8 +293,8 @@ impl SubGroup {
name: params.group_field.to_owned(),
})
.with_trace_element(trace_element!(SUB_GROUP_TRACE_NAME))?;
let valid_fields = params
.fields
let valid_by_fields = params
.by_fields
.validate_on_record_definition(&path_def, SUB_GROUP_TRACE_NAME)?;
drop(path_def);

Expand Down Expand Up @@ -330,13 +323,14 @@ impl SubGroup {
let mut group_stream_def = group_stream.record_definition().borrow_mut();

let variant = &output_stream_def[variant_id];
let mut group_by_datum_ids =
Vec::with_capacity(variant.data_len() - valid_fields.len());
let mut group_data = Vec::with_capacity(valid_fields.len());
let mut group_datum_ids = Vec::with_capacity(valid_fields.len());
let mut group_by_datum_ids = Vec::with_capacity(valid_by_fields.len());
let mut group_data =
Vec::with_capacity(variant.data_len() - valid_by_fields.len());
let mut group_datum_ids =
Vec::with_capacity(variant.data_len() - valid_by_fields.len());
for datum_id in variant.data() {
let datum = &output_stream_def[datum_id];
if valid_fields
if !valid_by_fields
.iter()
.any(|field| field.name() == datum.name())
{
Expand Down Expand Up @@ -418,9 +412,9 @@ impl SubGroup {
inputs,
outputs,
path_streams,
by_fields: valid_by_fields,
group_field: valid_group_field,
group_stream: created_group_stream.expect("group stream"),
fields: valid_fields,
})
}
}
Expand All @@ -446,15 +440,6 @@ impl DynNode for SubGroup {
let group_record = def_group.record();
let group_unpacked_record = def_group.unpacked_record();

let fields = {
let names = self
.fields
.iter()
.map(ValidFieldName::ident)
.collect::<Vec<_>>();
quote!(#(#names),*)
};

let group_field = self.group_field.ident();
let mut_group_field = self.group_field.mut_ident();

Expand All @@ -465,19 +450,24 @@ impl DynNode for SubGroup {
&graph.record_definitions()[path_stream.sub_input_stream.record_type()];
let variant = &leaf_record_definition[path_stream.sub_input_stream.variant_id()];

let fields = {
let idents = variant.data().filter_map(|d| {
let datum = &leaf_record_definition[d];
if !self.by_fields.iter().any(|f| f.name() == datum.name()) {
Some(format_ident!("{}", datum.name()))
} else {
None
}
});
quote!(#(#idents),*)
};

let out_record_definition =
chain.sub_stream_definition_fragments(&path_stream.sub_output_stream);

let eq = fields_eq(
&out_record_definition.record(),
variant.data().filter_map(|d| {
let datum = &leaf_record_definition[d];
if !self.fields.iter().any(|f| f.name() == datum.name()) {
Some(datum.name())
} else {
None
}
}),
self.by_fields.iter().map(ValidFieldName::name),
);
let eq_ident = self.identifier_for("eq");
let eq_preamble = quote! { let #eq_ident = #eq; };
Expand Down
2 changes: 1 addition & 1 deletion quirky_binder/src/filter/hof/index/wordlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{
( < extracted
- to_lowercase#lowercase_token(fields: ["{{token_field}}"])
- sort#sort_token(fields: ["{{token_field}}", "{{reference_field}}"])
- group#group(fields: ["{{reference_field}}"], group_field: "{{refs_field}}")
- group#group(by_fields: ["{{token_field}}"], group_field: "{{refs_field}}")
-> case_insensitive
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["num"])
- group(by_fields: [], group_field: "group")
- sub_dedup(path_fields: ["group"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["num"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["num"])
- group(by_fields: [], group_field: "group")
- sub_group(path_fields: ["group"], by_fields: [], group_field: "sub_group")
- sub_dedup(path_fields: ["group", "sub_group"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use quirky_binder::{
"#,
)
- sort(fields: ["lsb2"])
- group(group_field: "group", fields: ["num"])
- group(by_fields: ["lsb2"], group_field: "group")
- debug()
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use quirky_binder::{
"#,
)
- sort(fields: ["lsb2", "lsb4", "lsb6"])
- group(group_field: "sub_sub_group", fields: ["num"])
- group(group_field: "sub_group", fields: ["lsb6", "sub_sub_group"])
- group(group_field: "group", fields: ["lsb4", "sub_group"])
- group(by_fields: ["lsb2", "lsb4", "lsb6"], group_field: "sub_sub_group")
- group(by_fields: ["lsb2", "lsb4"], group_field: "sub_group")
- group(by_fields: ["lsb2"], group_field: "group")
- debug()
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use quirky_binder::{
"#,
)
- sort(fields: ["lsb2"])
- group(group_field: "group", fields: ["num", "lsb4"])
- group(by_fields: ["lsb2"], group_field: "group")
- sub_sort(path_fields: ["group"], fields: ["lsb4"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["num"])
- sub_group(path_fields: ["group"], by_fields: ["lsb4"], group_field: "sub_group")
- debug()
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use quirky_binder::{
"#,
)
- sort(fields: ["lsb2", "lsb4", "lsb6"])
- group(group_field: "group", fields: ["num", "lsb4", "lsb6"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["num", "lsb6"])
- sub_group(path_fields: ["group", "sub_group"], group_field: "sub_sub_group", fields: ["num"])
- group(by_fields: ["lsb2"], group_field: "group")
- sub_group(path_fields: ["group"], by_fields: ["lsb4"], group_field: "sub_group")
- sub_group(path_fields: ["group", "sub_group"], by_fields: ["lsb6"], group_field: "sub_sub_group")
- debug()
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["num"])
- group(by_fields: [], group_field: "group")
- sub_sort(path_fields: ["group"], fields: ["num"])
- function_update(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["num"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["num"])
- group(by_fields: [], group_field: "group")
- sub_group(path_fields: ["group"], by_fields: [], group_field: "sub_group")
- sub_sort(path_fields: ["group", "sub_group"], fields: ["num"])
- function_update(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["value"])
- group(by_fields: [], group_field: "group")
- sub_reverse_chars(path_fields: ["group"], fields: ["value"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["value"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["value"])
- group(by_fields: [], group_field: "group")
- sub_group(path_fields: ["group"], by_fields: [], group_field: "sub_group")
- sub_reverse_chars(path_fields: ["group", "sub_group"], fields: ["value"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["value"])
- group(by_fields: [], group_field: "group")
- sub_to_lowercase(path_fields: ["group"], fields: ["value"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use quirky_binder::{
Ok(())
"#,
)
- group(group_field: "group", fields: ["value"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["value"])
- group(by_fields: [], group_field: "group")
- sub_group(path_fields: ["group"], by_fields: [], group_field: "sub_group")
- sub_to_lowercase(path_fields: ["group", "sub_group"], fields: ["value"])
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ use quirky_binder::{
"#,
)
- sort(fields: ["lsb2"])
- group(group_field: "group", fields: ["num", "lsb4"])
- group(by_fields: ["lsb2"], group_field: "group")
- sub_sort(path_fields: ["group"], fields: ["lsb4"])
- sub_group(path_fields: ["group"], group_field: "sub_group", fields: ["num"])
- sub_group(path_fields: ["group"], by_fields: ["lsb4"], group_field: "sub_group")
- sub_ungroup(path_fields: ["group"], group_field: "sub_group")
- function_terminate(
body: r#"
Expand Down
Loading

0 comments on commit 7b52759

Please sign in to comment.