Skip to content

Commit

Permalink
Merge pull request #71 from swfsql/cuda_gather_failure
Browse files Browse the repository at this point in the history
Cuda `GatherCompiler` fails on low dimensionality (failing test)
  • Loading branch information
jafioti authored Jul 3, 2024
2 parents 8d36e70 + 82c9833 commit 7136f54
Showing 1 changed file with 62 additions and 1 deletion.
63 changes: 62 additions & 1 deletion crates/luminal_cuda/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ impl<T: CudaFloat> Compiler for GatherCompiler<T> {
.as_data()
.unwrap()
.2;
let embed_dim = emb_shape.shape()[2].to_usize().unwrap();
let embed_dim = emb_shape.shape().last().unwrap().to_usize().unwrap();
let index_shape = graph
.edges_connecting(s.get(&indexes), s.get(&ind_copy))
.next()
Expand All @@ -396,3 +396,64 @@ impl<T: CudaFloat> Compiler for GatherCompiler<T> {
}
}
}

#[cfg(test)]
mod tests {
    //! Compile-only smoke tests for the CUDA compiler on low-rank
    //! "one-hot times embedding matrix" graphs.
    //!
    //! Neither test executes the graph or checks numeric output: each one
    //! passes as long as `Graph::compile` returns without panicking.
    //! NOTE(review): judging by the test names, the one-hot/multiply/sum
    //! pattern is meant to be picked up by `GatherCompiler` — confirm.

    use super::*;
    luminal::test_imports!();

    /// Rank-0 (scalar) graph tensor.
    type TR0 = GraphTensor<R0>;
    /// Rank-1 graph tensor with `A` elements.
    type TR1<const A: usize> = GraphTensor<R1<A>>;
    /// Rank-2 graph tensor of shape `A x B`.
    type TR2<const A: usize, const B: usize> = GraphTensor<R2<A, B>>;

    /// Builds the lookup from a scalar (rank-0) index and compiles it.
    #[test]
    fn test_gather_compiler_r0() {
        const CLASSES: usize = 2;
        const TARGET: usize = 1;

        let mut cx = Graph::new();
        let mut input: TR0 = cx.tensor();
        let embedder: TR2<CLASSES, TARGET> = cx.tensor();

        // One-hot encode the scalar index: compare `0..CLASSES` with the
        // index broadcast to the same length.
        let class_ids: TR1<CLASSES> = input.graph().arange::<LConst<CLASSES>>();
        let broadcast_input: TR1<CLASSES> = input.expand();
        let input_one_hot: TR1<CLASSES> = class_ids.equals(broadcast_input);

        // Mask the embedding table with the one-hot vector, then collapse
        // the class axis.
        let one_hot_wide: TR2<CLASSES, TARGET> =
            input_one_hot.expand::<R2<CLASSES, TARGET>, _>();
        let masked: TR2<CLASSES, TARGET> = one_hot_wide * embedder;
        let input_embedding: TR1<TARGET> = masked.sum_reduce::<_, LAxis<0>>();

        let mut loss: TR0 = input_embedding.sum_reduce();
        let mut weights = vec![embedder.id];

        // The test only requires that compilation completes.
        let compiler = crate::CudaCompiler::<f32>::default();
        cx.compile(compiler, (&mut input, &mut loss, &mut weights));
    }

    /// Builds the lookup from a one-element rank-1 index and compiles it.
    #[test]
    fn test_gather_compiler_r1() {
        const CLASSES: usize = 2;
        const TARGET: usize = 1;

        let mut cx = Graph::new();
        let mut input: TR1<1> = cx.tensor();
        let embedder: TR2<CLASSES, TARGET> = cx.tensor();

        // One-hot encode the index: compare `0..CLASSES`, broadcast to
        // `1 x CLASSES`, with the index broadcast to the same shape.
        let class_ids: TR2<1, CLASSES> = input
            .graph()
            .arange::<LConst<CLASSES>>()
            .expand::<R2<1, CLASSES>, _>();
        let broadcast_input: TR2<1, CLASSES> = input.expand();
        let input_one_hot: TR2<1, CLASSES> = class_ids.equals(broadcast_input);

        // Broadcast both operands to `1 x CLASSES x TARGET`, multiply, then
        // collapse the class axis.
        let one_hot_wide: GraphTensor<R3<1, CLASSES, TARGET>> =
            input_one_hot.expand::<R3<1, CLASSES, TARGET>, _>();
        let embedder_wide: GraphTensor<R3<1, CLASSES, TARGET>> = embedder.expand();
        let masked = one_hot_wide * embedder_wide;
        let input_embedding: TR2<1, TARGET> = masked.sum_reduce::<_, LAxis<1>>();

        let mut loss: TR0 = input_embedding.sum_reduce();
        let mut weights = vec![embedder.id];

        // The test only requires that compilation completes.
        let compiler = crate::CudaCompiler::<f32>::default();
        cx.compile(compiler, (&mut input, &mut loss, &mut weights));
    }
}

0 comments on commit 7136f54

Please sign in to comment.