# Patch summary: move the CUDA-specific methods of DINCAE into a package
# extension module `CUDAExt` (new file ext/CUDAExt.jl).
#
# Per-file changes visible in this patch:
#   * Project.toml   - adds a `[weakdeps]` section listing CUDA and an
#                      `[extensions]` section mapping `CUDAExt = "CUDA"`.
#   * ext/CUDAExt.jl - new module; imports `interpnd!`, `interp_adjn!` and
#                      `_to_device` from DINCAE and adds `CuArray` methods:
#                      the two `!` functions launch the same-named function
#                      through `@cuda`, and `_to_device` returns `Flux.gpu`.
#   * src/DINCAE.jl  - removes `using CUDA`.
#   * src/flux.jl    - removes the `CuArray` method of `_to_device`; the
#                      `AbstractArray` method returning `Flux.cpu` is kept.
#   * src/points.jl  - removes the `CuArray` methods of `interpnd!` and
#                      `interp_adjn!`, changes the `TA` bound of `costfun`
#                      from `Union{Array{T,N},CuArray{T,N}}` to
#                      `Union{AbstractArray{T,N}}`, and rewords the
#                      documentation bullet describing the array type.
#
# NOTE(review): the Project.toml context shows `julia = "1"`. Package
#   extensions are presumably only loaded by Julia versions that support
#   `[extensions]` (1.9 and later, to my knowledge - TODO confirm). On older
#   versions the `CuArray` methods would then be absent and `_to_device` for a
#   `CuArray` type would dispatch to the `AbstractArray` method, i.e. return
#   `Flux.cpu`. Confirm whether a compat bump or a fallback `include` of the
#   extension file is intended.
# NOTE(review): `Union{AbstractArray{T,N}}` has a single member, so it looks
#   equivalent to writing `AbstractArray{T,N}` directly.
# NOTE(review): the `@cuda` calls pass no `threads`/`blocks` options (same as
#   the removed code) - confirm the launch configuration is intentional.
# NOTE(review): the extension does `using Flux`; this assumes Flux is a
#   regular dependency of DINCAE (src/flux.jl refers to `Flux.cpu`) - verify.
# NOTE(review): the patch text below has its line breaks collapsed; the
#   original newlines would have to be restored before it can be applied.
diff --git a/Project.toml b/Project.toml index 8d49e2e..da096db 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,13 @@ NCDatasets = "0.11, 0.12, 0.13, 0.14" ThreadsX = "0.1" julia = "1" +[weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[extensions] +CUDAExt = "CUDA" + + [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl new file mode 100644 index 0000000..90beaf6 --- /dev/null +++ b/ext/CUDAExt.jl @@ -0,0 +1,20 @@ +module CUDAExt + +import DINCAE: interpnd!, interp_adjn!, _to_device +using CUDA +using Flux + + +function interpnd!(pos::AbstractVector{<:NTuple{N}},cuA::CuArray,cuvec) where N + @cuda interpnd!(pos,cuA,cuvec) +end + +function interp_adjn!(pos::AbstractVector{<:NTuple{N}},cuvalues::CuArray,cuA2) where N + @cuda interp_adjn!(pos,cuvalues,cuA2) +end + +@inline function _to_device(::Type{Atype}) where Atype <: CuArray + return Flux.gpu +end + +end diff --git a/src/DINCAE.jl b/src/DINCAE.jl index a066136..bee79de 100644 --- a/src/DINCAE.jl +++ b/src/DINCAE.jl @@ -29,7 +29,6 @@ The code is available at: """ module DINCAE using Base.Threads -using CUDA using Dates using Random using NCDatasets diff --git a/src/flux.jl b/src/flux.jl index df307b4..ad542e1 100644 --- a/src/flux.jl +++ b/src/flux.jl @@ -42,12 +42,6 @@ weights(m::Chain) = Flux.params(m) weights(c::Conv) = [c.weight] - - -@inline function _to_device(::Type{Atype}) where Atype <: CuArray - return Flux.gpu -end - @inline function _to_device(::Type{Atype}) where Atype <: AbstractArray return Flux.cpu end diff --git a/src/points.jl b/src/points.jl index a199111..bb49a86 100644 --- a/src/points.jl +++ b/src/points.jl @@ -25,10 +25,6 @@ function interpnd!(pos::AbstractVector{<:NTuple{N}},A,vec) where N return nothing end -function interpnd!(pos::AbstractVector{<:NTuple{N}},cuA::CuArray,cuvec) where N - @cuda DINCAE.interpnd!(pos,cuA,cuvec) -end - #interpnd(pos,A) = 
interpnd!(pos,A,zeros(eltype(A),length(pos))) function interpnd(pos,A) vec = similar(A,length(pos)) @@ -75,10 +71,6 @@ function interp_adjn!(pos::AbstractVector{<:NTuple{N}},values,A2) where N return nothing end -function interp_adjn!(pos::AbstractVector{<:NTuple{N}},cuvalues::CuArray,cuA2) where N - @cuda interp_adjn!(pos,cuvalues,cuA2) -end - function interp_adjn(pos::AbstractVector{<:NTuple{N}},values,sz::NTuple{N,Int}) where N A2 = similar(values,sz) interp_adjn!(pos,values,A2) @@ -364,7 +356,7 @@ function costfun( xrec,xtrue::Vector{NamedTuple{(:pos, :x),Tuple{Tpos,TA}}},truth_uncertain; laplacian_penalty = 0, laplacian_error_penalty = laplacian_penalty, - ) where TA <: Union{Array{T,N},CuArray{T,N}} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T} + ) where TA <: Union{AbstractArray{T,N}} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T} #@show typeof(xin) #@show typeof(xrec) @@ -463,7 +455,7 @@ end Mandatory parameters: * `T`: `Float32` or `Float64`: float-type used by the neural network -* `Array{T}` or `CuArray{T}`: array-type used by the neural network. +* `Array{T}`, `CuArray{T}`,...: array-type used by the neural network. * `filename`: NetCDF file in the format described below. * `varname`: name of the primary variable in the NetCDF file. * `grid`: tuple of ranges with the grid in the longitude and latitude direction e.g. `(-180:1:180,-90:1:90)`.