diff --git a/Cargo.toml b/Cargo.toml index 841314f9..850681e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ nnapi = { version = "0.2", optional = true } ndk-sys = {version = "0.4", features=["test"], optional = true} serde = { version = "1.0", features = ["derive"], optional = true } +serde_json = { version = "1", optional = true } [build-dependencies] #min-cl = { path="../min-cl", optional=true } @@ -49,7 +50,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "autograd", "macro"] -default = ["cpu", "lazy", "static-api", "opencl", "graph", "autograd", "fork"] +default = ["cpu", "lazy", "static-api", "opencl", "graph", "autograd", "fork", "serde", "json"] cpu = [] opencl = ["dep:min-cl", "cpu", "cached"] @@ -76,6 +77,7 @@ vulkan = ["dep:ash", "dep:naga", "wgsl"] wgsl = [] serde = ["dep:serde"] +json = ["dep:serde_json"] [dev-dependencies] custos-macro = {git = "https://github.com/elftausend/custos-macro"} diff --git a/src/cache/location_hasher.rs b/src/cache/location_hasher.rs index e59130bd..db2e0cdd 100644 --- a/src/cache/location_hasher.rs +++ b/src/cache/location_hasher.rs @@ -35,6 +35,7 @@ impl core::hash::Hasher for LocationHasher { } #[derive(Debug, Clone, Copy, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct HashLocation<'a> { pub file: &'a str, pub line: u32, diff --git a/src/modules/fork.rs b/src/modules/fork.rs index d4d14f4c..cfca302c 100644 --- a/src/modules/fork.rs +++ b/src/modules/fork.rs @@ -13,7 +13,10 @@ mod use_gpu_or_cpu; pub use analyzation::Analyzation; pub use use_gpu_or_cpu::*; + +#[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct Fork { + #[cfg_attr(feature = "serde", serde(skip))] pub modules: Mods, pub gpu_or_cpu: RefCell< HashMap, BinaryHeap, BuildHasherDefault>, @@ -52,6 +55,38 @@ impl, D: Device> Module for Fork { } } } + +#[cfg(all(feature = "serde", not(feature = "no-std")))] +impl Fork { + #[inline] 
+ pub fn load_from_deserializer>(&mut self, deserializer: D) -> Result<(), D::Error> { + use serde::Deserialize; + + self.gpu_or_cpu = RefCell::new(HashMap::deserialize(deserializer)?); + Ok(()) + } + + #[cfg(feature = "json")] + #[inline] + pub fn save_as_json(&self, path: impl AsRef) -> crate::Result<()> { + serde_json::to_writer(std::fs::File::create(path)?, self)?; + Ok(()) + } + + #[cfg(feature = "json")] + #[inline] + pub fn load_from_json_read(&mut self, reader: impl std::io::Read) -> serde_json::Result<()> { + self.load_from_deserializer(&mut serde_json::Deserializer::from_reader(reader)) + } + + #[cfg(feature = "json")] + #[inline] + pub fn load_from_json(&mut self, path: impl AsRef) -> crate::Result<()> { + self.load_from_json_read(std::fs::File::open(path)?)?; + Ok(()) + } +} + pub trait ForkSetup { #[inline] fn fork_setup(&mut self) {} @@ -394,4 +429,36 @@ mod tests { assert_eq!(&analyzations[0].input_lengths, &[6]); } } + + #[cfg(feature = "json")] + #[cfg(feature = "serde")] + #[cfg(feature = "cpu")] + #[test] + fn test_fork_deserialize() { + use serde::Serialize; + + let device = OpenCL::>>::new(0).unwrap(); + if !device.unified_mem() { + return; + } + let buf = device.buffer([1, 2, 4, 5, 6, 7]); + let out = device.apply_fn(&buf, |x| x.add(3)); + assert_eq!(out.read(), [4, 5, 7, 8, 9, 10]); + + for _ in 0..100 { + let _out = device.apply_fn(&buf, |x| x.add(3)); + let gpu_or_cpu = device.modules.gpu_or_cpu.borrow(); + let (_, operations) = gpu_or_cpu.iter().next().unwrap(); + assert_eq!(operations.len(), 2); + let analyzations = operations.iter().cloned().collect::>(); + assert_eq!(&analyzations[0].input_lengths, &[6]); + } + + let mut json = vec![0,]; + let mut serializer = serde_json::Serializer::new(&mut json); + device.modules.serialize(&mut serializer).unwrap(); + println!("json: {json:?}"); + // device.modules.save_as_json(".") + } + } diff --git a/src/modules/fork/analyzation.rs b/src/modules/fork/analyzation.rs index 076b1eaa..563886de 100644 
--- a/src/modules/fork/analyzation.rs +++ b/src/modules/fork/analyzation.rs @@ -1,6 +1,7 @@ use core::time::Duration; #[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Analyzation { pub input_lengths: Vec, pub output_lengths: Vec,