Skip to content

Commit

Permalink
Merge pull request #2 from JuliaGaussianProcesses/matern
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace authored Feb 26, 2024
2 parents 29f9e52 + a3d67db commit 8af7789
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ uuid = "802a20c4-2b75-4561-af12-9016d659aa4b"
version = "0.1.0"

[deps]
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ Omitting the `gpu` conversion will of course also work, but will be quite a bit
- [x] `ScalarSEKernel`
- [x] `ScalarLinearKernel`
- [x] `ScalarPeriodicKernel`
- [ ] `Matern12Kernel`
- [ ] `Matern32Kernel`
- [ ] `Matern52Kernel`
- [x] `ScalarMatern12Kernel` === `ScalarExponentialKernel`
- [x] `ScalarMatern32Kernel`
- [x] `ScalarMatern52Kernel`
- [z] `ScalarMaternKernel`

### Composite kernels
- [x] `ScalarKernelSum`, when doing `k1 + k2`, where `k1` and `k2` are `ScalarKernel`s
Expand Down
42 changes: 42 additions & 0 deletions src/ScalarKernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ module ScalarKernelFunctions
using Reexport
@reexport using KernelFunctions

using SpecialFunctions: loggamma, besselk
using IrrationalConstants: logtwo

import KernelFunctions: Kernel
import KernelFunctions: kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag!
import KernelFunctions: Transform, IdentityTransform, with_lengthscale

export ScalarKernel, ScalarSEKernel, ScalarLinearKernel, ScalarPeriodicKernel
export ScalarExponentialKernel
export ScalarMatern12Kernel, ScalarMatern32Kernel, ScalarMatern52Kernel, ScalarMaternKernel
export ScalarKernelSum, ScalarScaledKernel, with_lengthscale
export TransformedScalarKernel, ScalarScaleTransform

Expand Down Expand Up @@ -76,6 +81,43 @@ ScalarPeriodicKernel() = ScalarPeriodicKernel(1.)
(k::ScalarPeriodicKernel)(x, y) = exp(-abs2(sinpi(x - y) / k.r) / 2)
gpu(k::ScalarPeriodicKernel) = ScalarPeriodicKernel(gpu(k.r))

struct ScalarExponentialKernel <: ScalarKernel end
(k::ScalarExponentialKernel)(x, y) = exp(-abs(x - y))

const ScalarMatern12Kernel = ScalarExponentialKernel

struct ScalarMatern32Kernel <: ScalarKernel end
function (k::ScalarMatern32Kernel)(x::T, y::T) where T<:Real
sqrt3 = sqrt(T(3))
d = abs(x - y)
return (1 + sqrt3 * d) * exp(-sqrt3 * d)
end

struct ScalarMatern52Kernel <: ScalarKernel end
function (k::ScalarMatern52Kernel)(x::T, y::T) where T<:Real
sqrt5 = sqrt(T(5))
d = abs(x - y)
return (1 + sqrt5 * d + 5 * d^2 / 3) * exp(-sqrt5 * d)
end

struct ScalarMaternKernel{T<:Real} <: ScalarKernel
ν::T
end
ScalarMaternKernel() = ScalarMaternKernel(1.5)
function (k::ScalarMaternKernel)(x::T, y::T) where T<:Real
d = abs(x - y)
ν = k.ν
if iszero(d)
c = -ν /- 1)
return one(d) + c * d^2 / 2
else
y = sqrt(2ν) * d
b = log(besselk(ν, y))
return exp((one(d) - ν) * oftype(y, logtwo) - loggamma(ν) + ν * log(y) + b)
end
end
gpu(k::ScalarMaternKernel) = ScalarMaternKernel(gpu(k.ν))



struct ScalarKernelSum{T1<:Kernel, T2<:Kernel} <: ScalarKernel
Expand Down
24 changes: 19 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,26 @@ using Test
v1 = zeros(100)
v2 = zeros(100)
@testset "$kernel2" for (kernel1, kernel2) in (
# Simple kernels
SEKernel() => ScalarSEKernel(),
LinearKernel() => ScalarLinearKernel(),
PeriodicKernel() => ScalarPeriodicKernel(),
PeriodicKernel(; r = [2.]) => ScalarPeriodicKernel(2.),
2. * SEKernel() + 3. * LinearKernel() => 2. * ScalarSEKernel() + 3. * ScalarLinearKernel(),
PeriodicKernel(; r = [2.0]) => ScalarPeriodicKernel(2.0),
Matern12Kernel() => ScalarMatern12Kernel(),
Matern32Kernel() => ScalarMatern32Kernel(),
Matern52Kernel() => ScalarMatern52Kernel(),
MaternKernel(; ν = 1.2) => ScalarMaternKernel(1.2),
MaternKernel(; ν = 3.0) => ScalarMaternKernel(3.0),

# Composite kernels
2.0 * SEKernel() + 3.0 * LinearKernel() =>
2.0 * ScalarSEKernel() + 3.0 * ScalarLinearKernel(),
SEKernel() * PeriodicKernel() => ScalarSEKernel() * ScalarPeriodicKernel()
)
@testset for (k1, k2) in (
(kernel1, kernel2),
with_lengthscale.((kernel1, kernel2), 2.),
2 .* (kernel1, kernel2)
with_lengthscale.((kernel1, kernel2), 2.0),
2.0 .* (kernel1, kernel2)
)
@test k1(1., 4.) k2(1., 4.)
@test kernelmatrix(k1, x0) kernelmatrix(k2, x0)
Expand Down Expand Up @@ -58,8 +67,13 @@ using Test
x0 = rand(Float32, 10) |> jl
x1 = rand(Float32, 10) |> jl
@testset "$k" for k in (
# Simple kernels
ScalarSEKernel(), ScalarLinearKernel(), ScalarPeriodicKernel(),
with_lengthscale(ScalarSEKernel(), 2.), 2. * ScalarLinearKernel(),
ScalarMatern12Kernel(), ScalarMatern32Kernel(), ScalarMatern52Kernel(),
ScalarMaternKernel(1.2),

# Composite kernels
with_lengthscale(ScalarSEKernel(), 2.0), 2.0 * ScalarLinearKernel(),
ScalarSEKernel() + ScalarPeriodicKernel(),
# ScalarSEKernel() * ScalarPeriodicKernel()
)
Expand Down

0 comments on commit 8af7789

Please sign in to comment.