Skip to content

Commit

Permalink
clean up some things
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebeggs committed Aug 7, 2024
1 parent 8955122 commit a59546d
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 48 deletions.
5 changes: 3 additions & 2 deletions src/RadialBasisFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export AbstractPHS, PHS, PHS1, PHS3, PHS5, PHS7
export IMQ
export Gaussian
export MonomialBasis
export degree, dim

include("utils.jl")
export find_neighbors, reorder_points!
Expand All @@ -42,7 +43,7 @@ export ∂virtual

include("operators/monomial/monomial.jl")

include("operators/operator_combinations.jl")
include("operators/operator_algebra.jl")

include("interpolation.jl")
export Interpolator
Expand All @@ -52,7 +53,7 @@ export Regrid, regrid

# Some consts and aliases
const Δ = ∇² # some people like this notation for the Laplacian
const AVOID_NAN = 1e-16
const AVOID_INF = 1e-16

using PrecompileTools
@setup_workload begin
Expand Down
3 changes: 3 additions & 0 deletions src/basis/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ function (ℒmon::ℒMonomialBasis{Dim,Deg})(x) where {Dim,Deg}
end
(m::ℒMonomialBasis)(b, x) = m.f(b, x)

degree(::ℒMonomialBasis{Dim,Deg}) where {Dim,Deg} = Deg
dim(::ℒMonomialBasis{Dim,Deg}) where {Dim,Deg} = Dim

include("polyharmonic_spline.jl")
include("inverse_multiquadric.jl")
include("gaussian.jl")
Expand Down
10 changes: 5 additions & 5 deletions src/basis/polyharmonic_spline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

(phs::PHS1)(x, xᵢ) = euclidean(x, xᵢ)
function (::PHS1, dim::Int)
∂ℒ(x, xᵢ) = (x[dim] - xᵢ[dim]) / (euclidean(x, xᵢ) + AVOID_NAN)
∂ℒ(x, xᵢ) = (x[dim] - xᵢ[dim]) / (euclidean(x, xᵢ) + AVOID_INF)
return ℒRadialBasisFunction(∂ℒ)
end
function (::PHS1)
Expand All @@ -49,14 +49,14 @@ end
function ∂²(::PHS1, dim::Int)
function ∂²ℒ(x, xᵢ)
return (-(x[dim] - xᵢ[dim])^2 + sqeuclidean(x, xᵢ)) /
(euclidean(x, xᵢ)^3 + AVOID_NAN)
(euclidean(x, xᵢ)^3 + AVOID_INF)
end
return ℒRadialBasisFunction(∂²ℒ)
end
function ∇²(::PHS1)
function ∇²ℒ(x, xᵢ)
return sum(
(-(x .- xᵢ) .^ 2 .+ sqeuclidean(x, xᵢ)) / (euclidean(x, xᵢ)^3 + AVOID_NAN)
(-(x .- xᵢ) .^ 2 .+ sqeuclidean(x, xᵢ)) / (euclidean(x, xᵢ)^3 + AVOID_INF)
)
end
return ℒRadialBasisFunction(∇²ℒ)
Expand Down Expand Up @@ -87,14 +87,14 @@ end
function ∂²(::PHS3, dim::Int)
function ∂²ℒ(x, xᵢ)
return 3 * (sqeuclidean(x, xᵢ) + (x[dim] - xᵢ[dim])^2) /
(euclidean(x, xᵢ) + AVOID_NAN)
(euclidean(x, xᵢ) + AVOID_INF)
end
return ℒRadialBasisFunction(∂²ℒ)
end
function ∇²(::PHS3)
function ∇²ℒ(x, xᵢ)
return sum(
3 * (sqeuclidean(x, xᵢ) .+ (x .- xᵢ) .^ 2) / (euclidean(x, xᵢ) + AVOID_NAN)
3 * (sqeuclidean(x, xᵢ) .+ (x .- xᵢ) .^ 2) / (euclidean(x, xᵢ) + AVOID_INF)
)
end
return ℒRadialBasisFunction(∇²ℒ)
Expand Down
9 changes: 4 additions & 5 deletions src/operators/monomial/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ end
function ∇²(m::MonomialBasis{Dim,Deg}) where {Dim,Deg}
∂² = ntuple(dim -> (m, 2, dim), Dim)
function basis!(b, x)
cache = ones(size(b))
b .= 0
cache = ones(eltype(x), size(b))
b .= zero(eltype(x))
for ∂²! in ∂²
# use mapreduce here instead?
∂²!(cache, x)
b .+= cache
end
Expand All @@ -34,8 +33,8 @@ end

function build_monomial_basis(ids::Vector{Vector{Vector{T}}}, c::Vector{T}) where {T<:Int}
function basis!(db::AbstractVector{B}, x::AbstractVector) where {B}
db .= 1
# TODO flatten loop - why does it allocate here
db .= one(eltype(x))
# TODO optimize - allocations
@views @inbounds for i in eachindex(ids), j in eachindex(ids[i])
db[ids[i][j]] *= x[i]
end
Expand Down
45 changes: 45 additions & 0 deletions src/operators/operator_algebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
for op in (:+, :-)
@eval function Base.$op(a::ℒRadialBasisFunction, b::ℒRadialBasisFunction)
additive_ℒRBF(x, xᵢ) = Base.$op(a(x, xᵢ), b(x, xᵢ))
return ℒRadialBasisFunction(additive_ℒRBF)
end
end

for op in (:+, :-)
@eval function Base.$op(
a::ℒMonomialBasis{Dim,Deg}, b::ℒMonomialBasis{Dim,Deg}
) where {Dim,Deg}
function additive_ℒMon(m, x)
m .= Base.$op.(a(x), b(x))
return nothing
end
return ℒMonomialBasis(Dim, Deg, additive_ℒMon)
end
end

for op in (:+, :-)
@eval function Base.$op(op1::RadialBasisOperator, op2::RadialBasisOperator)
_check_compatible(op1, op2)
k = _update_stencil(op1, op2)
(x) = Base.$op(op1.(x), op2.(x))
return RadialBasisOperator(ℒ, op1.data, op1.basis; k=k, adjl=op1.adjl)
end
end

function _check_compatible(op1::RadialBasisOperator, op2::RadialBasisOperator)
if !all(op1.data .≈ op2.data)
throw(
ArgumentError("Can not add operators that were not built with the same data.")
)
end
if !all(op1.adjl .≈ op2.adjl)
throw(ArgumentError("Can not add operators that do not have the same stencils."))
end
end

function _update_stencil(op1::RadialBasisOperator, op2::RadialBasisOperator)
k1 = length(first((op1.adjl)))
k2 = length(first((op2.adjl)))
k = k1 > k2 ? k1 : k2
return k
end
36 changes: 0 additions & 36 deletions src/operators/operator_combinations.jl

This file was deleted.

24 changes: 24 additions & 0 deletions test/operators/operator_algebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using RadialBasisFunctions
using StaticArrays
using LinearAlgebra
using Statistics
using HaltonSequences

mean_percent_error(test, correct) = mean(abs.((test .- correct) ./ correct)) * 100

f(x) = 2 * x[1] + 3 * x[2]
df_dx(x) = 2
df_dy(x) = 3

N = 1000
x = SVector{2}.(HaltonPoint(2)[1:N])
y = f.(x)

dx = partial(x, 1, 1)
dy = partial(x, 1, 2)

dxdy = dx + dy
@test mean_percent_error(dxdy(y), df_dx.(x) .+ df_dy.(x)) < 1e-6

dxdy = dx - dy
@test mean_percent_error(dxdy(y), df_dx.(x) .- df_dy.(x)) < 1e-6
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ end
include("operators/virtual.jl")
end

@safetestset "Operator Algebra" begin
include("operators/operator_algebra.jl")
end

@safetestset "Stencil" begin
include("solve.jl")
end
Expand Down

0 comments on commit a59546d

Please sign in to comment.