Skip to content

Commit

Permalink
Merge branch 'main' of github.com:JuliaStats/MixedModels.jl into pa/p…
Browse files Browse the repository at this point in the history
…ca_precomp
  • Loading branch information
palday committed Mar 5, 2024
2 parents 4b6e940 + 510dcc3 commit f2acdcf
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 23 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
MixedModels v4.22.5 Release Notes
==============================
* Use `muladd` where possible to enable fused multiply-add (FMA) on architectures with hardware support. FMA will generally improve computational speed and gives more accurate rounding. [#740]
* Replace broadcasted lambda with explicit loop and use `one`. This may result in a small performance improvement. [#738]

MixedModels v4.22.4 Release Notes
Expand Down Expand Up @@ -500,5 +501,6 @@ Package dependencies
[#717]: https://github.com/JuliaStats/MixedModels.jl/issues/717
[#733]: https://github.com/JuliaStats/MixedModels.jl/issues/733
[#738]: https://github.com/JuliaStats/MixedModels.jl/issues/738
[#740]: https://github.com/JuliaStats/MixedModels.jl/issues/740
[#744]: https://github.com/JuliaStats/MixedModels.jl/issues/744
[#748]: https://github.com/JuliaStats/MixedModels.jl/issues/748
2 changes: 1 addition & 1 deletion src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function LinearAlgebra.mul!(
αbnz = α * bnz[ib]
jj = brv[ib]
for ia in nzrange(A, j)
C[arv[ia], jj] += anz[ia] * αbnz
C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj])
end
end
end
Expand Down
18 changes: 9 additions & 9 deletions src/linalg/rankUpdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function MixedModels.rankUpdate!(
Cdiag = C.data.diag
Adiag = A.diag
@inbounds for idx in eachindex(Cdiag, Adiag)
Cdiag[idx] = β * Cdiag[idx] + α * abs2(Adiag[idx])
Cdiag[idx] = muladd(β, Cdiag[idx], α * abs2(Adiag[idx]))
end
return C
end
Expand Down Expand Up @@ -52,7 +52,7 @@ function _columndot(rv, nz, rngi, rngj)
while i ni && j nj
@inbounds ri, rj = rv[rngi[i]], rv[rngj[j]]
if ri == rj
@inbounds accum += nz[rngi[i]] * nz[rngj[j]]
@inbounds accum = muladd(nz[rngi[i]], nz[rngj[j]], accum)
i += 1
j += 1
elseif ri < rj
Expand Down Expand Up @@ -80,17 +80,17 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T,
rvj = rv[j]
for i in k:lenrngjj
kk = rangejj[i]
Cd[rv[kk], rvj] += nz[kk] * anzj
Cd[rv[kk], rvj] = muladd(nz[kk], anzj, Cd[rv[kk], rvj])
end
end
end
else
@inbounds for j in axes(C, 2)
rngj = nzrange(A, j)
for i in 1:(j - 1)
Cd[i, j] += α * _columndot(rv, nz, nzrange(A, i), rngj)
Cd[i, j] = muladd(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j])
end
Cd[j, j] += α * sum(i -> abs2(nz[i]), rngj)
Cd[j, j] = muladd(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j])
end
end
return C
Expand All @@ -109,7 +109,7 @@ function rankUpdate!(
isone(β) || rmul!(Cdiag, β)

@inbounds for i in eachindex(Cdiag)
Cdiag[i] += α * sum(abs2, view(A, i, :))
Cdiag[i] = muladd(α, sum(abs2, view(A, i, :)), Cdiag[i])
end

return C
Expand All @@ -132,9 +132,9 @@ function rankUpdate!(
AtAij = 0
for idx in axes(A, 2)
# because the second multiplicant is from A', swap index order
AtAij += A[iind, idx] * A[jind, idx]
AtAij = muladd(A[iind, idx], A[jind, idx], AtAij)
end
Cdat[i, j, k] += α * AtAij
Cdat[i, j, k] = muladd(α, AtAij, Cdat[i, j, k])
end
end

Expand All @@ -152,7 +152,7 @@ function rankUpdate!(
throw(ArgumentError("Columns of A must have exactly 1 nonzero"))

for (r, nz) in zip(rowvals(A), nonzeros(A))
dd[r] += α * abs2(nz)
dd[r] = muladd(α, abs2(nz), dd[r])
end

return C
Expand Down
6 changes: 4 additions & 2 deletions src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,9 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T}
z = trm.z
stride = size(z, 1)
mul!(
view(rhs2, (rhsoffset + (trm.refs[i] - 1) * stride) .+ Base.OneTo(stride)),
view(
rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)
),
adjoint(trm.λ),
view(z, :, i),
)
Expand Down Expand Up @@ -816,7 +818,7 @@ function objective(m::LinearMixedModel{T}) where {T}
val = if isnothing(σ)
logdet(m) + denomdf * (one(T) + log2π + log(pwrss(m) / denomdf))
else
denomdf * (log2π + 2 * log(σ)) + logdet(m) + pwrss(m) / σ^2
muladd(denomdf, muladd(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2))
end
return isempty(wts) ? val : val - T(2.0) * sum(log, wts)
end
Expand Down
23 changes: 13 additions & 10 deletions src/remat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ function LinearAlgebra.mul!(
@inbounds for (j, rrj) in enumerate(B.refs)
αzj = α * zz[j]
for i in 1:p
C[i, rrj] += αzj * Awt[j, i]
C[i, rrj] = muladd(αzj, Awt[j, i], C[i, rrj])
end
end
return C
Expand All @@ -310,7 +310,7 @@ function LinearAlgebra.mul!(
aki = α * Awt[k, i]
kk = Int(rr[k])
for ii in 1:S
scr[ii, kk] += aki * Bwt[ii, k]
scr[ii, kk] = muladd(aki, Bwt[ii, k], scr[ii, kk])
end
end
for j in 1:q
Expand Down Expand Up @@ -340,7 +340,7 @@ function LinearAlgebra.mul!(
coljlast = Int(C.colptr[j + 1] - 1)
K = searchsortedfirst(rv, i, Int(C.colptr[j]), coljlast, Base.Order.Forward)
if K coljlast && rv[K] == i
nz[K] += Az[k] * Bz[k]
nz[K] = muladd(Az[k], Bz[k], nz[K])
else
throw(ArgumentError("C does not have the nonzero pattern of A'B"))
end
Expand All @@ -361,7 +361,7 @@ function LinearAlgebra.mul!(
@inbounds for i in 1:S
zij = Awtz[i, j]
for k in 1:S
Cd[k, i, r] += zij * Awtz[k, j]
Cd[k, i, r] = muladd(zij, Awtz[k, j], Cd[k, i, r])
end
end
end
Expand Down Expand Up @@ -397,7 +397,7 @@ function LinearAlgebra.mul!(
jjo = jj + joffset
Bzijj = Bz[jj, i]
for ii in 1:S
C[ii + ioffset, jjo] += Az[ii, i] * Bzijj
C[ii + ioffset, jjo] = muladd(Az[ii, i], Bzijj, C[ii + ioffset, jjo])
end
end
end
Expand All @@ -416,7 +416,8 @@ function LinearAlgebra.mul!(
isone(beta) || rmul!(y, beta)
z = A.z
@inbounds for (i, r) in enumerate(A.refs)
y[i] += alpha * b[r] * z[i]
# must be muladd and not fma because of potential missings
y[i] = muladd(alpha * b[r], z[i], y[i])
end
return y
end
Expand Down Expand Up @@ -446,7 +447,8 @@ function LinearAlgebra.mul!(
@inbounds for (i, ii) in enumerate(A.refs)
offset = (ii - 1) * k
for j in 1:k
y[i] += alpha * Z[j, i] * b[offset + j]
# must be muladd and not fma because of potential missings
y[i] = muladd(alpha * Z[j, i], b[offset + j], y[i])
end
end
return y
Expand All @@ -466,7 +468,8 @@ function LinearAlgebra.mul!(
isone(beta) || rmul!(y, beta)
@inbounds for (i, ii) in enumerate(refarray(A))
for j in 1:k
y[i] += alpha * Z[j, i] * B[j, ii]
# must be muladd and not fma because of potential missings
y[i] = muladd(alpha * Z[j, i], B[j, ii], y[i])
end
end
return y
Expand Down Expand Up @@ -566,7 +569,7 @@ function copyscaleinflate!(Ljj::Diagonal{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1})
Ldiag, Adiag = Ljj.diag, Ajj.diag
lambsq = abs2(only(Λj.λ.data))
@inbounds for i in eachindex(Ldiag, Adiag)
Ldiag[i] = lambsq * Adiag[i] + one(T)
Ldiag[i] = muladd(lambsq, Adiag[i], one(T))
end
return Ljj
end
Expand All @@ -575,7 +578,7 @@ function copyscaleinflate!(Ljj::Matrix{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) wh
fill!(Ljj, zero(T))
lambsq = abs2(only(Λj.λ.data))
@inbounds for (i, a) in enumerate(Ajj.diag)
Ljj[i, i] = lambsq * a + one(T)
Ljj[i, i] = muladd(lambsq, a, one(T))
end
return Ljj
end
Expand Down
2 changes: 1 addition & 1 deletion test/pls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ end

vc = fm1.vcov
@test isa(vc, Matrix{Float64})
@test only(vc) 375.7167775 rtol=1.e-6
@test only(vc) 375.7167775 rtol=1.e-3
# since we're caching the fits, we should get it back to being correctly fitted
# we also take this opportunity to test fitlog
@testset "fitlog" begin
Expand Down

0 comments on commit f2acdcf

Please sign in to comment.