Skip to content

Commit

Permalink
Support CSV headers
Browse files Browse the repository at this point in the history
  • Loading branch information
arnodb committed Nov 17, 2024
1 parent ce0ecc9 commit b2dea46
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 2 deletions.
50 changes: 49 additions & 1 deletion contrib/quirky_binder_csv/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub struct ReadCsvParams<'a> {
fields: TypedFieldsParam<'a>,
order_fields: Option<DirectedFieldsParam<'a>>,
distinct_fields: Option<FieldsParam<'a>>,
#[serde(default)]
has_headers: bool,
}

#[derive(Getters)]
Expand All @@ -21,6 +23,8 @@ pub struct ReadCsv {
#[getset(get = "pub")]
outputs: [NodeStream; 1],
input_file: String,
fields: Vec<(ValidFieldName, ValidFieldType)>,
has_headers: bool,
}

impl ReadCsv {
Expand Down Expand Up @@ -88,6 +92,8 @@ impl ReadCsv {
inputs,
outputs,
input_file: params.input_file.to_owned(),
fields: valid_fields,
has_headers: params.has_headers,
})
}
}
Expand Down Expand Up @@ -115,6 +121,43 @@ impl DynNode for ReadCsv {

let input_file = &self.input_file;

let has_headers = self.has_headers;

let read_headers = if has_headers {
let header_checks = self
.fields
.iter()
.enumerate()
.map(|(index, (field_name, _))| {
let field_name = field_name.name();
quote! {
let header = iter.next();
if let Some(header) = header {
if header != #field_name {
return Err(QuirkyBinderError::Custom(format!(
"Header mismatch at position {}, expected {} but got {}",
#index, #field_name, header)));
}
} else {
return Err(QuirkyBinderError::Custom(format!(
"Missing header at position {}, expected {}",
#index, #field_name)));
}
}
});
Some(quote! {{
let headers = reader.headers()
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;
let mut iter = headers.into_iter();
#(#header_checks)*
if let Some(header) = iter.next() {
return Err(QuirkyBinderError::Custom(format!("Unexpected extra header {}", header)));
};
}})
} else {
None
};

let thread_body = quote! {
let output = thread_control.output_0.take().expect("output 0");
move || {
Expand All @@ -123,15 +166,20 @@ impl DynNode for ReadCsv {

let file = File::open(#input_file)
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;

let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.has_headers(#has_headers)
.from_reader(BufReader::new(file));

#read_headers

for result in reader.deserialize() {
let record = result
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;
output.send(Some(record))?;
}
output.send(None)?;

Ok(())
}
};
Expand Down
27 changes: 26 additions & 1 deletion contrib/quirky_binder_csv/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use truc::record::type_resolver::TypeResolver;
#[serde(deny_unknown_fields)]
pub struct WriteCsvParams<'a> {
output_file: &'a str,
#[serde(default)]
has_headers: bool,
}

#[derive(Getters)]
Expand All @@ -16,6 +18,7 @@ pub struct WriteCsv {
#[getset(get = "pub")]
outputs: [NodeStream; 0],
output_file: String,
has_headers: bool,
}

impl WriteCsv {
Expand All @@ -31,6 +34,7 @@ impl WriteCsv {
inputs,
outputs: [],
output_file: params.output_file.to_owned(),
has_headers: params.has_headers,
})
}
}
Expand Down Expand Up @@ -73,6 +77,20 @@ impl DynNode for WriteCsv {

let output_file = &self.output_file;

let has_headers = self.has_headers;

let write_headers = if has_headers {
let record_definition = &graph.record_definitions()[self.inputs.single().record_type()];
let variant = &record_definition[self.inputs.single().variant_id()];
let headers = variant.data().map(|d| record_definition[d].name());
Some(quote! {{
writer.write_record([#(#headers),*])
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;
}})
} else {
None
};

let thread_body = quote! {
#(
#inputs
Expand All @@ -91,11 +109,18 @@ impl DynNode for WriteCsv {
}
let file = File::create(file_path)
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;
let mut writer = csv::Writer::from_writer(file);

let mut writer = csv::WriterBuilder::new()
.has_headers(false)
.from_writer(file);

#write_headers

while let Some(record) = input.next()? {
writer.serialize(record)
.map_err(|err| QuirkyBinderError::Custom(err.to_string()))?;
}

Ok(())
}
};
Expand Down
1 change: 1 addition & 0 deletions tests/quirky_binder_tests/input/hello_universe.csv
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
hello,universe
world,42
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use quirky_binder_csv::read_csv;
read_csv(
input_file: "input/hello_universe.csv",
fields: [("hello", "String"), ("universe", "usize")],
has_headers: true,
)
- function_terminate(
body: r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use quirky_binder_csv::write_csv;
)
- write_csv(
output_file: "output/hello_universe.csv",
has_headers: true,
)
)
}

0 comments on commit b2dea46

Please sign in to comment.