Skip to content

Commit

Permalink
CUDA as extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Nov 15, 2024
1 parent 0268403 commit efaf915
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ NCDatasets = "0.11, 0.12, 0.13, 0.14"
ThreadsX = "0.1"
julia = "1"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CUDAExt = "CUDA"


[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
20 changes: 20 additions & 0 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module CUDAExt

# Package extension for DINCAE: GPU-specific methods loaded automatically when
# CUDA.jl is present (declared under [weakdeps]/[extensions] in Project.toml).
# It extends three functions imported from the parent package.

import DINCAE: interpnd!, interp_adjn!, _to_device
using CUDA
using Flux


# GPU method of `interpnd!`: launch the generic (CPU) `interpnd!` body as a
# CUDA kernel on the device arrays.
function interpnd!(pos::AbstractVector{<:NTuple{N}},cuA::CuArray,cuvec) where N
# NOTE(review): `@cuda` without `threads=`/`blocks=` launches a single GPU
# thread — confirm the kernel body iterates over all of `pos` on the device,
# otherwise a launch configuration is needed here.
@cuda interpnd!(pos,cuA,cuvec)
end

# GPU method of the adjoint interpolation `interp_adjn!`; same single-launch
# pattern as `interpnd!` above.
function interp_adjn!(pos::AbstractVector{<:NTuple{N}},cuvalues::CuArray,cuA2) where N
# NOTE(review): default launch configuration (one thread) — verify coverage
# of all positions, as for `interpnd!`.
@cuda interp_adjn!(pos,cuvalues,cuA2)
end

# Map the `CuArray` array type to Flux's GPU transfer function, so callers can
# move models/data to the device selected by the user-supplied array type.
@inline function _to_device(::Type{Atype}) where Atype <: CuArray
return Flux.gpu
end

end
1 change: 0 additions & 1 deletion src/DINCAE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ The code is available at:
"""
module DINCAE
using Base.Threads
using CUDA
using Dates
using Random
using NCDatasets
Expand Down
6 changes: 0 additions & 6 deletions src/flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ weights(m::Chain) = Flux.params(m)

weights(c::Conv) = [c.weight]



@inline function _to_device(::Type{Atype}) where Atype <: CuArray
return Flux.gpu
end

# CPU fallback of `_to_device`: for any non-CUDA array type, return Flux's
# CPU transfer function (the CuArray method lives in the CUDAExt extension).
@inline function _to_device(::Type{Atype}) where Atype <: AbstractArray
return Flux.cpu
end
Expand Down
12 changes: 2 additions & 10 deletions src/points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ function interpnd!(pos::AbstractVector{<:NTuple{N}},A,vec) where N
return nothing
end

function interpnd!(pos::AbstractVector{<:NTuple{N}},cuA::CuArray,cuvec) where N
@cuda DINCAE.interpnd!(pos,cuA,cuvec)
end

#interpnd(pos,A) = interpnd!(pos,A,zeros(eltype(A),length(pos)))
function interpnd(pos,A)
vec = similar(A,length(pos))
Expand Down Expand Up @@ -75,10 +71,6 @@ function interp_adjn!(pos::AbstractVector{<:NTuple{N}},values,A2) where N
return nothing
end

function interp_adjn!(pos::AbstractVector{<:NTuple{N}},cuvalues::CuArray,cuA2) where N
@cuda interp_adjn!(pos,cuvalues,cuA2)
end

function interp_adjn(pos::AbstractVector{<:NTuple{N}},values,sz::NTuple{N,Int}) where N
A2 = similar(values,sz)
interp_adjn!(pos,values,A2)
Expand Down Expand Up @@ -364,7 +356,7 @@ function costfun(
xrec,xtrue::Vector{NamedTuple{(:pos, :x),Tuple{Tpos,TA}}},truth_uncertain;
laplacian_penalty = 0,
laplacian_error_penalty = laplacian_penalty,
) where TA <: Union{Array{T,N},CuArray{T,N}} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T}
) where TA <: Union{AbstractArray{T,N}} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T}

#@show typeof(xin)
#@show typeof(xrec)
Expand Down Expand Up @@ -463,7 +455,7 @@ end
Mandatory parameters:
* `T`: `Float32` or `Float64`: float-type used by the neural network
* `Array{T}` or `CuArray{T}`: array-type used by the neural network.
* `Array{T}`, `CuArray{T}`,...: array-type used by the neural network.
* `filename`: NetCDF file in the format described below.
* `varname`: name of the primary variable in the NetCDF file.
* `grid`: tuple of ranges with the grid in the longitude and latitude direction e.g. `(-180:1:180,-90:1:90)`.
Expand Down

0 comments on commit efaf915

Please sign in to comment.