Skip to content

Commit

Permalink
delete references to Knet
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Nov 15, 2024
1 parent 3b4422e commit 0268403
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 32 deletions.
3 changes: 0 additions & 3 deletions src/DINCAE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ import Base: eltype
import Random: shuffle!

using Profile
#import Knet: KnetArray, AutoGrad
#import Knet

#include("knet.jl")
include("flux.jl")
include("types.jl")
include("data.jl")
Expand Down
1 change: 0 additions & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ CatSkip(x) = SkipConnection(x,(mx,y) -> cat(mx, y, dims=Val(3)))
#using JLD2

# save inversion
#function sinv(x::Union{AbstractArray{T},KnetArray{T},AutoGrad.Result{<:AbstractArray{T}}}; minx = T(1e-3)) where T
function sinv(x, ; minx = eltype(x)(1e-3))
T = eltype(x)
return one(T) ./ max.(x,minx)
Expand Down
22 changes: 3 additions & 19 deletions src/points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@ function interpnd!(pos::AbstractVector{<:NTuple{N}},A,vec) where N
#return vec
return nothing
end
#=
function interpnd!(pos::AbstractVector{<:NTuple{N}},A::KnetArray,vec) where N
cuA = CuArray(A)
cuvec = CuArray(vec)
@cuda DINCAE.interpnd!(pos,cuA,cuvec)
end
=#

function interpnd!(pos::AbstractVector{<:NTuple{N}},cuA::CuArray,cuvec) where N
@cuda DINCAE.interpnd!(pos,cuA,cuvec)
end
Expand All @@ -43,8 +37,6 @@ function interpnd(pos,A)
return vec
end

#Knet.AutoGrad.@primitive interpnd(pos,A),dy,y 0 interp_adjn(pos,dy,size(A))

function ChainRulesCore.rrule(::typeof(interpnd), pos::AbstractVector{<:NTuple{N}}, A) where N
function interpnd_pullback(dy)
dpos = similar(pos)
Expand Down Expand Up @@ -83,14 +75,6 @@ function interp_adjn!(pos::AbstractVector{<:NTuple{N}},values,A2) where N
return nothing
end

#=
function interp_adjn!(pos::AbstractVector{<:NTuple{N}},values::KnetArray,A2) where N
cuvalues = CuArray(values)
cuA2 = CuArray(A2)
@cuda interp_adjn!(pos,cuvalues,cuA2)
end
=#

function interp_adjn!(pos::AbstractVector{<:NTuple{N}},cuvalues::CuArray,cuA2) where N
@cuda interp_adjn!(pos,cuvalues,cuA2)
end
Expand Down Expand Up @@ -380,7 +364,7 @@ function costfun(
xrec,xtrue::Vector{NamedTuple{(:pos, :x),Tuple{Tpos,TA}}},truth_uncertain;
laplacian_penalty = 0,
laplacian_error_penalty = laplacian_penalty,
) where TA <: Union{Array{T,N},CuArray{T,N}#=,KnetArray{T,N}=#} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T}
) where TA <: Union{Array{T,N},CuArray{T,N}} where Tpos <: AbstractVector{NTuple{N,T}} where {N,T}

#@show typeof(xin)
#@show typeof(xrec)
Expand Down Expand Up @@ -479,7 +463,7 @@ end
Mandatory parameters:
* `T`: `Float32` or `Float64`: float-type used by the neural network
* `Array{T}` or `KnetArray{T}`: array-type used by the neural network.
* `Array{T}` or `CuArray{T}`: array-type used by the neural network.
* `filename`: NetCDF file in the format described below.
* `varname`: name of the primary variable in the NetCDF file.
* `grid`: tuple of ranges with the grid in the longitude and latitude direction e.g. `(-180:1:180,-90:1:90)`.
Expand Down
9 changes: 0 additions & 9 deletions src/vector2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,9 @@ function vector2_costfun(xrec,xtrue,truth_uncertain,directionobs)

# interpolate to location of observations
# output should be the same shape as m_true
#obsoper(x::Union{KnetArray{T,2},AbstractArray{T,2}}) where T = repeat(x,(1,1,nsites))
#obsoper(x::AbstractArray{T,3}) where T = mapslices(obsoper,x,dims=(1,2))
#obsoper(x::AbstractArray{T,3}) where T = x
#function obsoper(x::Union{KnetArray{T,3},AbstractArray{T,3},CuArray{T,3},
# AutoGrad.Result{KnetArray{T,3}},AutoGrad.Result{CuArray{T,3}},AutoGrad.Result{Array{T,3}}}) where T
function obsoper(x)
tmp = reshape(x,(size(x)[1:2]...,1,size(x,3)))
#repeat(tmp,inner=(1,1,nsites,1))
# https://github.com/denizyuret/Knet.jl/issues/635
#tmp[:,:,ones(Int,nsites),:]
#cat(tmp,tmp,tmp,dims=3)
#cat((tmp for i = 1:nsites)...,dims=3)
cat([tmp for i = 1:nsites]...,dims=3)
end

Expand Down

0 comments on commit 0268403

Please sign in to comment.