From ed3d7c87d1e3c26856ff377526546e4cd985474a Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 23 Oct 2024 16:42:56 +0200 Subject: [PATCH 1/2] fix primitive_type for complex --- src/XLA.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index 556b6ff3..7f48812d 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -233,8 +233,8 @@ end @inline primitive_type(::Type{Float64}) = 12 -@inline primitive_type(::Type{Complex{Float32}}) = 24 -@inline primitive_type(::Type{Complex{Float64}}) = 25 +@inline primitive_type(::Type{Complex{Float32}}) = 15 +@inline primitive_type(::Type{Complex{Float64}}) = 18 function ArrayFromHostBuffer(client::Client, array::Array{T,N}, device) where {T,N} sizear = Int64[s for s in reverse(size(array))] From c0d5a24c5190013e1bf2f4c337a8e7bcd91b4037 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Wed, 23 Oct 2024 16:53:08 +0200 Subject: [PATCH 2/2] add simple complex runtime test --- test/basic.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index b9095f4c..2e2d7c6a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -438,3 +438,10 @@ end @test size(f(y)) == size(x) @test eltype(f(y)) == eltype(x) end + +@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64) + a = Reactant.to_rarray(ones(CT, 2)) + b = Reactant.to_rarray(ones(CT, 2)) + c = Reactant.compile(+, (a, b))(a, b) + @test c == ones(CT, 2) + ones(CT, 2) +end