Skip to content

Commit

Permalink
Fix empty stream to cuda graph -> result<lazy cuda graph> in oncecell
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Feb 7, 2024
1 parent e12887f commit 8f53b2a
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cuda", "graph"]
default = ["cuda", "graph", "lazy"]
# default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]

std = []
Expand Down
1 change: 1 addition & 0 deletions src/devices/cuda/api/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub type CudaResult<T> = std::result::Result<T, CudaErrorKind>;

#[derive(Clone, Copy)]
pub enum CudaErrorKind {
InvalidAllocSize,
InvalidDeviceIdx,
Expand Down
5 changes: 3 additions & 2 deletions src/devices/cuda/cuda_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
api::{
create_context, create_stream, cuInit, cuStreamDestroy,
cublas::{create_handle, cublasDestroy_v2, cublasSetStream_v2, CublasHandle},
device, Context, CudaIntDevice, FnHandle, Module, Stream,
device, Context, CudaErrorKind, CudaIntDevice, FnHandle, Module, Stream,
},
AsCudaCvoidPtr, CudaSource, KernelCache,
};
Expand All @@ -23,7 +23,8 @@ pub struct CudaDevice {
pub mem_transfer_stream: Stream,
pub handle: CublasHandle,
#[cfg(feature = "lazy")]
pub graph: core::cell::OnceCell<super::lazy::LazyCudaGraph>,
// TODO: remove result when get_or_try_init becomes stable
pub graph: core::cell::OnceCell<Result<super::lazy::LazyCudaGraph, CudaErrorKind>>,
}

impl CudaDevice {
Expand Down
14 changes: 10 additions & 4 deletions src/devices/cuda/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@ impl<Mods> crate::LazyRun for CUDA<Mods> {
let graph = self
.graph
// TODO: change to get_or_try_init when stable
.get_or_init(|| LazyCudaGraph::new(self.stream()).unwrap());

graph.launch(self.stream.0)?;
self.stream().sync()?;
// an error may occur if the stream was empty ig
.get_or_init(|| LazyCudaGraph::new(self.stream()));

match graph {
Ok(graph) => {
graph.launch(self.stream.0)?;
self.stream().sync()?;
}
Err(e) => return Err((*e).into()),
}
Ok(())
}
}
Expand Down
1 change: 0 additions & 1 deletion src/devices/opencl/cl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ impl<SimpleMods> OpenCL<SimpleMods> {
NewMods::setup(&mut opencl)?;
Ok(opencl)
}

}

impl<Mods> OpenCL<Mods> {
Expand Down
4 changes: 1 addition & 3 deletions src/modules/graph/opt_graph/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ mod tests {
let out: Buffer<f32, _> = device.retrieve::<2>(1000, (&mul, &mul_b));

device.optimize_mem_graph(device, None).unwrap();
unsafe { device.run().unwrap() };
let _err = unsafe { device.run() };

assert_eq!(squared.replace().id(), mul.replace().id());
assert_eq!(squared.replace().id(), out.replace().id());
Expand All @@ -538,7 +538,6 @@ mod tests {

#[cfg(feature = "opencl")]
#[cfg(feature = "lazy")]
#[cfg_attr(miri, ignore)]
#[test]
fn test_lazy_from_retrieve_sliced_chained_perf_example_optimize_cl() {
use crate::{Base, Graph, Lazy, OpenCL};
Expand All @@ -548,7 +547,6 @@ mod tests {
}
#[cfg(feature = "cuda")]
#[cfg(feature = "lazy")]
#[cfg_attr(miri, ignore)]
#[test]
fn test_lazy_from_retrieve_sliced_chained_perf_example_optimize_cu() {
use crate::{Base, Graph, Lazy, CUDA};
Expand Down
2 changes: 1 addition & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ mod tests {
// device.add_op(buf, |buf| {
// Ok(())
// });

// device.add_op(buf.no_id(), |buf| {
// Ok(())
// });
Expand Down
2 changes: 1 addition & 1 deletion tests/graph/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn test_graph_cl() -> custos::Result<()> {
#[cfg(feature = "cuda")]
#[test]
fn test_graph_cu() -> custos::Result<()> {
use custos::{Cached, CUDA, Cursor};
use custos::{Cached, Cursor, CUDA};

let device = CUDA::<Graph<Cached<Base>>>::new(0)?;

Expand Down

0 comments on commit 8f53b2a

Please sign in to comment.