Skip to content

Commit

Permalink
Get rid of codegen (which reindents rust raw strings)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnodb committed Feb 7, 2025
1 parent 4689646 commit 6f7d56a
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 58 deletions.
1 change: 0 additions & 1 deletion quirky_binder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ repository = "https://github.com/arnodb/quirky_binder"
readme = "../README.md"

[dependencies]
codegen = "0.1"
derive-new = "0.5"
derive_more = "0.99"
getset = "0.1"
Expand Down
75 changes: 31 additions & 44 deletions quirky_binder/src/chain/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use std::{borrow::Cow, collections::HashMap, fmt::Display};

use codegen::{Module, Scope};
use itertools::Itertools;
use proc_macro2::TokenStream;
use quirky_binder_lang::location::Location;
use serde::Deserialize;
use syn::{Ident, Type};

use self::error::{ChainError, ChainErrorWithTrace};
use crate::prelude::*;
use crate::{codegen::Module, prelude::*};

pub mod error;

Expand Down Expand Up @@ -105,7 +104,7 @@ impl ChainSourceThread {
#[derive(new)]
pub struct Chain<'a> {
customizer: &'a ChainCustomizer,
scope: &'a mut Scope,
module: &'a mut Module,
#[new(default)]
threads: Vec<ChainThread>,
#[new(default)]
Expand Down Expand Up @@ -192,7 +191,7 @@ impl<'a> Chain<'a> {
output_pipes,
});
let name = format!("thread_{}", thread_id);
let module = self.scope.new_module(&name).vis("pub").scope();
let module = self.module.get_or_new_module(&name);
for (path, ty) in &self.customizer.custom_module_imports {
module.import(path, ty);
}
Expand Down Expand Up @@ -296,11 +295,7 @@ impl<'a> Chain<'a> {
let pipe = self.new_pipe();
let thread = &mut self.threads[source_thread.thread_id];
let name = format!("thread_{}", source_thread.thread_id);
let scope = self
.scope
.get_module_mut(&name)
.expect("thread module")
.scope();
let module = self.module.get_module(&name).expect("thread module");
let mut import_scope = ImportScope::default();
if thread.output_pipes.is_none() {
assert_eq!(thread.output_streams.len(), 1);
Expand All @@ -324,13 +319,13 @@ impl<'a> Chain<'a> {
}
}
};
scope.raw(&pipe_def.to_string());
module.fragment(pipe_def.to_string());
}

thread.output_pipes = Some(Box::new([pipe]));
thread.main = Some(FullyQualifiedName::new(name).sub("quirky_binder_pipe"));
}
import_scope.import(scope);
import_scope.import(module);
pipe
}

Expand Down Expand Up @@ -362,21 +357,17 @@ impl<'a> Chain<'a> {
pub fn gen_chain(&mut self) {
for thread in &self.threads {
let name = format!("thread_{}", thread.id);
let scope = self
.scope
.get_module_mut(&name)
.expect("thread module")
.scope();
scope.import("std::sync", "Arc");
scope.import(
let module = self.module.get_module(&name).expect("thread module");
module.import("std::sync", "Arc");
module.import(
"quirky_binder_support::chain::configuration",
"ChainConfiguration",
);
if thread.input_streams.len() > 0 {
scope.import("std::sync::mpsc", "Receiver");
module.import("std::sync::mpsc", "Receiver");
}
if thread.output_pipes.is_some() && thread.output_streams.len() > 0 {
scope.import("std::sync::mpsc", "SyncSender");
module.import("std::sync::mpsc", "SyncSender");
}
let inputs = (0..thread.input_streams.len()).map(|i| format_ident!("input_{}", i));
let input_types = thread.input_streams.iter().map(|input_stream| {
Expand Down Expand Up @@ -421,7 +412,7 @@ impl<'a> Chain<'a> {
}

};
scope.raw(&struct_def.to_string());
module.fragment(struct_def.to_string());
}

{
Expand Down Expand Up @@ -556,8 +547,8 @@ impl<'a> Chain<'a> {
}
});

self.scope.import("std::sync", "Arc");
self.scope.import(
self.module.import("std::sync", "Arc");
self.module.import(
"quirky_binder_support::chain::configuration",
"ChainConfiguration",
);
Expand Down Expand Up @@ -595,35 +586,34 @@ impl<'a> Chain<'a> {
Ok(())
}
};
self.scope.raw(&main_def.to_string());
self.module.fragment(main_def.to_string());
}
}

fn get_or_new_module_scope<'i>(
fn get_or_new_module<'i>(
&mut self,
path: impl IntoIterator<Item = &'i Box<str>>,
chain_customizer: &ChainCustomizer,
thread_id: usize,
) -> &mut Scope {
) -> &mut Module {
let mut iter = path.into_iter();
let customize_module = |module: &mut Module| {
for (path, ty) in &chain_customizer.custom_module_imports {
module.scope().import(path, ty);
module.import(path, ty);
}
let thread_module = format!("thread_{}", thread_id);
module.scope().import("super", &thread_module).vis("pub");
module.import("super", &thread_module);
};
if let Some(first) = iter.next() {
let module = self.scope.get_or_new_module(first);
let module = self.module.get_or_new_module(first);
(customize_module)(module);
iter.fold(module, |m, n| {
let module = m.get_or_new_module(n).vis("pub");
let module = m.get_or_new_module(n);
(customize_module)(module);
module
})
.scope()
} else {
self.scope
self.module
}
}

Expand Down Expand Up @@ -658,15 +648,15 @@ impl<'a> Chain<'a> {
let mut import_scope = ImportScope::default();
import_scope.add_import("fallible_iterator", "FallibleIterator");

let scope = self.get_or_new_module_scope(
let module = self.get_or_new_module(
name.iter().take(name.len() - 1),
self.customizer,
thread.thread_id,
);

scope.raw(&fn_def.to_string());
module.fragment(fn_def.to_string());

import_scope.import(scope);
import_scope.import(module);
}

pub fn implement_node_thread(
Expand All @@ -692,15 +682,12 @@ impl<'a> Chain<'a> {
import_scope.add_import("fallible_iterator", "FallibleIterator");
}

let scope = self.get_or_new_module_scope(
name.iter().take(name.len() - 1),
self.customizer,
thread_id,
);
let module =
self.get_or_new_module(name.iter().take(name.len() - 1), self.customizer, thread_id);

scope.raw(&fn_def.to_string());
module.fragment(fn_def.to_string());

import_scope.import(scope);
import_scope.import(module);
}

pub fn implement_path_update(
Expand Down Expand Up @@ -841,9 +828,9 @@ impl ImportScope {
self.fixed.push((path.to_string(), ty.to_string()));
}

pub fn import(mut self, scope: &mut Scope) {
pub fn import(mut self, module: &mut Module) {
for (path, ty) in &self.fixed {
scope.import(path, ty);
module.import(path, ty);
}
self.used = true;
}
Expand Down
57 changes: 57 additions & 0 deletions quirky_binder/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::{collections::BTreeSet, fmt::Display};

#[derive(Default)]
pub struct Module {
imports: BTreeSet<String>,
modules: Vec<(String, Module)>,
fragments: Vec<String>,
}

impl Module {
pub fn import(&mut self, path: impl AsRef<str>, ty: impl AsRef<str>) {
self.imports
.insert(format!("{}::{}", path.as_ref(), ty.as_ref()));
}

pub fn get_module(&mut self, name: &str) -> Option<&mut Module> {
self.modules
.iter_mut()
.find_map(|(n, m)| (n == name).then_some(m))
}

pub fn get_or_new_module(&mut self, name: &str) -> &mut Module {
let pos = self.modules.iter_mut().position(|(n, _m)| n == name);
if let Some(pos) = pos {
&mut self.modules[pos].1
} else {
self.modules.push((name.to_owned(), Module::default()));
&mut self.modules.last_mut().unwrap().1
}
}

pub fn fragment(&mut self, fragment: impl Into<String>) {
self.fragments.push(fragment.into());
}
}

impl Display for Module {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if !self.imports.is_empty() {
for import in &self.imports {
writeln!(f, "use {import};")?;
}
f.write_str("\n")?;
}
for (name, module) in &self.modules {
writeln!(f, "pub mod {name} {{")?;
module.fmt(f)?;
writeln!(f, "}}")?;
f.write_str("\n")?;
}
for fragment in &self.fragments {
f.write_str(fragment)?;
f.write_str("\n")?;
}
Ok(())
}
}
25 changes: 12 additions & 13 deletions quirky_binder/src/graph/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{collections::BTreeMap, fs::File, path::Path};

use codegen::Scope;
use truc::{
generator::{
config::GeneratorConfig,
Expand All @@ -9,7 +8,7 @@ use truc::{
record::definition::RecordDefinition,
};

use crate::prelude::*;
use crate::{codegen::Module, prelude::*};

pub mod builder;
pub mod error;
Expand Down Expand Up @@ -64,35 +63,35 @@ impl Graph {

{
let mut file = File::create(output.join("streams.rs")).unwrap();
let mut scope = Scope::new();
let mut root_module = Module::default();
for (record_type, definition) in &self.record_definitions {
let module = scope.get_or_new_module(&record_type[0]).vis("pub");
let module = record_type
.iter()
.skip(1)
.fold(module, |m, n| m.get_or_new_module(n).vis("pub"))
.scope();
module.raw(&truc::generator::generate(
.fold(root_module.get_or_new_module(&record_type[0]), |m, n| {
m.get_or_new_module(n)
});
module.fragment(truc::generator::generate(
definition,
&GeneratorConfig::default_with_custom_generators([
Box::new(CloneImplGenerator) as Box<dyn FragmentGenerator>,
Box::new(SerdeImplGenerator) as Box<dyn FragmentGenerator>,
]),
));
}
write!(file, "{}", scope.to_string()).unwrap();
write!(file, "{}", root_module).unwrap();
}
rustfmt_generated_file(output.join("streams.rs").as_path());

{
let mut scope = Scope::new();
let mut root_module = Module::default();
for (path, ty) in &self.chain_customizer.custom_module_imports {
scope.import(path, ty);
root_module.import(path, ty);
}

scope.raw("mod streams;");
root_module.fragment("mod streams;");

let mut chain = Chain::new(&self.chain_customizer, &mut scope);
let mut chain = Chain::new(&self.chain_customizer, &mut root_module);

for node in &self.entry_nodes {
node.gen_chain(self, &mut chain);
Expand All @@ -101,7 +100,7 @@ impl Graph {
chain.gen_chain();

let mut file = File::create(output.join("chain.rs")).unwrap();
write!(file, "{}", scope.to_string()).unwrap();
write!(file, "{}", root_module).unwrap();
}
rustfmt_generated_file(output.join("chain.rs").as_path());

Expand Down
1 change: 1 addition & 0 deletions quirky_binder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern crate quote;
extern crate thiserror;

pub mod chain;
pub mod codegen;
pub mod drawing;
pub mod filter;
pub mod graph;
Expand Down

0 comments on commit 6f7d56a

Please sign in to comment.