diff --git a/Project.toml b/Project.toml index 12b28a3..d4c291f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HCIToolbox" uuid = "b6cd55e5-4d02-4e12-b82c-005f67e784bf" authors = ["Miles Lucas "] -version = "0.6.3" +version = "0.6.4" [deps] CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" @@ -9,6 +9,8 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -21,6 +23,8 @@ CoordinateTransformations = "0.5, 0.6" FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13" ImageTransformations = "0.9" Interpolations = "0.12, 0.13" +LossFunctions = "0.11" +Optim = "1" PaddedViews = "0.5" Rotations = "1" SpecialFunctions = "0.10, 1, 2" diff --git a/src/HCIToolbox.jl b/src/HCIToolbox.jl index 25d61ca..3a812a5 100644 --- a/src/HCIToolbox.jl +++ b/src/HCIToolbox.jl @@ -44,5 +44,6 @@ include("angles.jl") include("scaling.jl") include("profiles.jl") +include("cutout.jl") end diff --git a/src/cutout.jl b/src/cutout.jl new file mode 100644 index 0000000..2fb57cf --- /dev/null +++ b/src/cutout.jl @@ -0,0 +1,61 @@ + +struct CutoutView{T,N,M<:AbstractArray{T,N},CT,IT} <: AbstractArray{T,N} + parent::M + center::CT + indices::IT + fill::T +end + + +function CutoutView(parent::AbstractArray{T}, center, _size; fill=zero(T)) where T + + xlims = extrema(axes(parent, 1)) + ylims = extrema(axes(parent, 2)) + + half_length = ones(Int, 2) .* _size ./ 2 + minx = max(xlims[begin], round(Int, center[begin] - half_length[begin])) + miny = max(ylims[begin], round(Int, center[end] - half_length[end])) + maxx = minx + round(Int, half_length[begin] * 2) - 1 + maxy = miny + round(Int, half_length[end] * 2) - 1 + inds = (minx:maxx, miny:maxy,) + + return CutoutView(parent, center, inds, T(fill)) +end + +function CutoutView(parent::AbstractArray{T}, _size; fill=zero(T)) where T + center = HCIToolbox.center(parent) + xlims = extrema(axes(parent, 1)) + ylims = extrema(axes(parent, 2)) + + half_length = ones(Int, 2) .* _size ./ 2 + minx = max(xlims[begin], round(Int, center[begin] - half_length[begin])) + miny = max(ylims[begin], round(Int, center[end] - half_length[end])) + maxx = minx + round(Int, half_length[begin] * 2) + maxy = miny + round(Int, half_length[end] * 2) + inds = (minx:maxx, miny:maxy,) + + return CutoutView(parent, center, inds, T(fill)) +end + +Base.parent(view::CutoutView) = view.parent +Base.size(view::CutoutView) = map(length, view.indices) +Base.copy(view::CutoutView) = CutoutView(copy(parent(view)), view.center, view.indices, view.fill) +Base.axes(view::CutoutView) = view.indices + + +@propagate_inbounds function Base.getindex(view::CutoutView{T,N}, idx::Vararg{<:Integer,N}) where {T,N} + @boundscheck checkbounds(parent(view), idx...) + + value = convert(T, parent(view)[idx...]) + return value + # ifelse(inside_annulus(view, idx...), , view.fill) +end + +# @propagate_inbounds function Base.setindex!(view::CutoutView{T,N}, val, idx::Vararg{<:Integer,N}) where {T,N} +# @boundscheck checkbounds(parent(view), idx...) +# if inside_annulus(view, idx...) +# parent(view)[idx...] = val +# else +# view.fill +# end +# end diff --git a/src/inject.jl b/src/inject.jl index 755463a..e3ab754 100644 --- a/src/inject.jl +++ b/src/inject.jl @@ -124,6 +124,16 @@ function inject!(cube::AbstractArray{T,3}, kernel, angles=Zeros(size(cube, 3)); return cube end +function inject!(cube::AbstractArray{T,4}, kernel::AbstractArray{V,3}, angles=Zeros(size(cube, 4)); kwargs...) where {T,V} + @inbounds for idx in axes(cube, 3) + frame = @view cube[:, :, idx, :] + kern = @view kernel[:, :, idx] + inject!(frame, kern, angles; kwargs...) + end + return cube +end + + function inject!(cube::AnnulusView{T}, kernel, angles=Zeros(size(cube, 3)); kwargs...) where T # All zeros position angles is actually 90° parallactic angle @inbounds for idx in cube.indices[2] diff --git a/src/scaling.jl b/src/scaling.jl index 624d76e..ad31aa9 100644 --- a/src/scaling.jl +++ b/src/scaling.jl @@ -1,4 +1,6 @@ +using LossFunctions +using Optim using PaddedViews """ @@ -100,3 +102,33 @@ julia> scale_list([0.5, 2, 4]) """ scale_list(wavelengths) = maximum(wavelengths) ./ wavelengths +function optimize_scale_list(spcube::AbstractArray{T,4}, scales, amps=ones(size(spcube, 3)); kwargs...) where T + # get temporal median first + spframe = median(spcube, dims=4)[:, :, :, 1] + return optimize_scale_list(spframe, scales; kwargs...) +end + +function optimize_scale_list(spframe::AbstractArray{T,3}, scales, amps=ones(size(spframe, 3)); mask=trues(size(spframe, 1), size(spframe, 2))) where T + reference_frame = @view spframe[:, :, end] + N_wl = size(spframe, 3) + best_scales = ones(N_wl) + best_flux = ones(N_wl) + for wl_idx in axes(spframe, 3)[begin:end - 1] + current_frame = @view spframe[:, :, wl_idx] + func(X) = _scale_opt_func(current_frame, reference_frame, X[begin], X[end], mask) + P0 = T[scales[wl_idx], amps[wl_idx]] + result = optimize(func, P0, NelderMead(); autodiff=:forward) + @info "Finished optimizing (wl=$wl_idx/$N_wl)" result + X = Optim.minimizer(result) + best_scales[wl_idx] = X[begin] + best_flux[wl_idx] = X[end] + end + return (;scales=best_scales, fluxes=best_flux) +end + +function _scale_opt_func(frame, reference, scale, amp=1, mask=trues(frame)) + scale < 1 && return Inf + scaled_frame = amp .* HCIToolbox.scale(frame, scale) + # get loss + return sum(L2DistLoss(), mask .* (scaled_frame .- reference)) +end \ No newline at end of file