Skip to content

Commit

Permalink
Add serde_json dep, impl Serialize,Deserialize for Analyization, Hash…
Browse files Browse the repository at this point in the history
…Location, impl load,save fns for fork
  • Loading branch information
elftausend committed Jan 6, 2024
1 parent d293992 commit 159570f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ nnapi = { version = "0.2", optional = true }
ndk-sys = {version = "0.4", features=["test"], optional = true}

serde = { version = "1.0", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true }

[build-dependencies]
#min-cl = { path="../min-cl", optional=true }
Expand All @@ -49,7 +50,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }

[features]
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "lazy", "static-api", "opencl", "graph", "autograd", "fork"]
default = ["cpu", "lazy", "static-api", "opencl", "graph", "autograd", "fork", "serde", "json"]

cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
Expand All @@ -76,6 +77,7 @@ vulkan = ["dep:ash", "dep:naga", "wgsl"]
wgsl = []

serde = ["dep:serde"]
json = ["dep:serde_json"]

[dev-dependencies]
custos-macro = {git = "https://github.com/elftausend/custos-macro"}
Expand Down
1 change: 1 addition & 0 deletions src/cache/location_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl core::hash::Hasher for LocationHasher {
}

#[derive(Debug, Clone, Copy, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HashLocation<'a> {
pub file: &'a str,
pub line: u32,
Expand Down
67 changes: 67 additions & 0 deletions src/modules/fork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ mod use_gpu_or_cpu;
pub use analyzation::Analyzation;
pub use use_gpu_or_cpu::*;


#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Fork<Mods> {
#[cfg_attr(feature = "serde", serde(skip))]
pub modules: Mods,
pub gpu_or_cpu: RefCell<
HashMap<HashLocation<'static>, BinaryHeap<Analyzation>, BuildHasherDefault<LocationHasher>>,
Expand Down Expand Up @@ -52,6 +55,38 @@ impl<Mods: Module<D>, D: Device> Module<D> for Fork<Mods> {
}
}
}

#[cfg(all(feature = "serde", not(feature = "no-std")))]
impl<Mods> Fork<Mods> {
#[inline]
pub fn load_from_deserializer<D: serde::Deserializer<'static>>(&mut self, deserializer: D) -> Result<(), D::Error> {
use serde::Deserialize;

self.gpu_or_cpu = RefCell::new(HashMap::deserialize(deserializer)?);
Ok(())
}

#[cfg(feature = "json")]
#[inline]
pub fn save_as_json(&self, path: impl AsRef<std::path::Path>) -> crate::Result<()> {
serde_json::to_writer(std::fs::File::open(path)?, self)?;
Ok(())
}

#[cfg(feature = "json")]
#[inline]
pub fn load_from_json_read(&mut self, reader: impl std::io::Read) -> serde_json::Result<()> {
self.load_from_deserializer(&mut serde_json::Deserializer::from_reader(reader))
}

#[cfg(feature = "json")]
#[inline]
pub fn load_from_json(&mut self, path: impl AsRef<std::path::Path>) -> crate::Result<()> {
self.load_from_json_read(std::fs::File::open(path)?)?;
Ok(())
}
}

pub trait ForkSetup {
#[inline]
fn fork_setup(&mut self) {}
Expand Down Expand Up @@ -394,4 +429,36 @@ mod tests {
assert_eq!(&analyzations[0].input_lengths, &[6]);
}
}

#[cfg(feature = "json")]
#[cfg(feature = "serde")]
#[cfg(feature = "cpu")]
#[test]
fn test_fork_deserialize() {
use serde::Serialize;

let device = OpenCL::<Fork<Cached<Base>>>::new(0).unwrap();
if !device.unified_mem() {
return;
}
let buf = device.buffer([1, 2, 4, 5, 6, 7]);
let out = device.apply_fn(&buf, |x| x.add(3));
assert_eq!(out.read(), [4, 5, 7, 8, 9, 10]);

for _ in 0..100 {
let _out = device.apply_fn(&buf, |x| x.add(3));
let gpu_or_cpu = device.modules.gpu_or_cpu.borrow();
let (_, operations) = gpu_or_cpu.iter().next().unwrap();
assert_eq!(operations.len(), 2);
let analyzations = operations.iter().cloned().collect::<Vec<Analyzation>>();
assert_eq!(&analyzations[0].input_lengths, &[6]);
}

let mut json = vec![0,];
let mut serializer = serde_json::Serializer::new(&mut json);
device.modules.serialize(&mut serializer).unwrap();
println!("json: {json:?}");
// device.modules.save_as_json(".")
}

}
1 change: 1 addition & 0 deletions src/modules/fork/analyzation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::time::Duration;

#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Analyzation {
pub input_lengths: Vec<usize>,
pub output_lengths: Vec<usize>,
Expand Down

0 comments on commit 159570f

Please sign in to comment.