Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the truncated QR LAPACK routine geqp3rk as the backend for the partial qr factorization. fixes #60 #61

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion src/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
module _LAPACK
import LinearAlgebra
import LinearAlgebra.BLAS: @blasfunc
using LinearAlgebra: BlasFloat, BlasInt, chkstride1
using LinearAlgebra: BlasFloat, BlasInt, chkstride1, require_one_based_indexing
using LinearAlgebra.LAPACK: chklapackerror
import Base: Nothing
const liblapack = LinearAlgebra.BLAS.liblapack

Expand Down Expand Up @@ -138,4 +139,112 @@ for (laqps, elty, relty) in ((:slaqps_, :Float32, :Float32),
end
end

for (geqp3rk, elty, relty) in ((:sgeqp3rk_,:Float32,:Float32),
(:dgeqp3rk_,:Float64,:Float64),
(:cgeqp3rk_,:ComplexF32,:Float32),
(:zgeqp3rk_,:ComplexF64,:Float64))
@eval begin
function geqp3rk!(
A::AbstractMatrix{$elty},
nrhs::BlasInt,
kmax::BlasInt,
abstol::BlasFloat,
reltol::BlasFloat,
jpvt::AbstractVector{BlasInt},
tau::AbstractVector{$elty}
)
require_one_based_indexing(A, jpvt, tau)
chkstride1(A,jpvt,tau)
m,n = size(A)
if length(tau) != min(m,n)
throw(DimensionMismatch(lazy"tau has length $(length(tau)), but needs length $(min(m,n))"))
end
if length(jpvt) != n
throw(DimensionMismatch(lazy"jpvt has length $(length(jpvt)), but needs length $n"))
end
lda = stride(A,2)
if lda == 0
return A, tau, jpvt
end # Early exit
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
iwork = Vector{BlasInt}(undef, n - 1)
cmplx = eltype(A)<:Complex
if cmplx
rwork = Vector{$relty}(undef, 2n)
end
info = Ref{BlasInt}()
k = Ref{BlasInt}()
maxc2nrmk = Ref{$relty}()
relmaxc2nrmk = Ref{$relty}()
for i = 1:2 # first call returns lwork as work[1]
if cmplx
ccall((@blasfunc($geqp3rk), liblapack), Cvoid,
(
Ref{BlasInt}, #m
Ref{BlasInt}, #n
Ref{BlasInt}, #nrhs
Ref{BlasInt},#kmax
Ref{$relty}, #abstol
Ref{$relty}, #reltol
Ptr{$elty}, #A
Ref{BlasInt}, #lda
Ptr{BlasInt}, #k
Ptr{$relty}, #maxc2nrmk
Ptr{$relty}, #relmaxc2nrmk
Ptr{BlasInt}, #jpvt -> jpiv
Ptr{$elty}, #tau
Ptr{$elty}, #work
Ref{BlasInt}, #lwork
Ptr{$relty}, #rwork
Ptr{BlasInt}, #iwork
Ptr{BlasInt} #info
),
m, n, nrhs,
kmax, abstol, reltol,
A, lda,
k, maxc2nrmk, relmaxc2nrmk,
jpvt, tau, work,
lwork, rwork, iwork, info)
else
#println("running ccall")
ccall((@blasfunc($geqp3rk), liblapack), Cvoid,
(
Ref{BlasInt}, #m
Ref{BlasInt}, #n
Ref{BlasInt}, #nrhs
Ref{BlasInt},#kmax
Ref{$elty}, #abstol
Ref{$elty}, #reltol
Ptr{$elty}, #A
Ref{BlasInt}, #lda
Ref{BlasInt}, #k
Ptr{$elty}, #maxc2nrmk
Ptr{$elty}, #relmaxc2nrmk
Ptr{BlasInt}, #jpvt -> jpiv
Ptr{$elty}, #tau
Ptr{$elty}, #work
Ref{BlasInt}, #lwork
Ptr{BlasInt}, #iwork
Ptr{BlasInt} #info
),
m, n, nrhs,
kmax, abstol, reltol,
A, lda,
k, maxc2nrmk, relmaxc2nrmk,
jpvt, tau, work,
lwork, iwork, info)
end
chklapackerror(info[])
if i == 1
lwork = BlasInt(real(work[1]))
ajinkya-k marked this conversation as resolved.
Show resolved Hide resolved
resize!(work, lwork)
end
end
return A, k[], tau, jpvt
end
end
end


end # module
14 changes: 12 additions & 2 deletions src/pqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,24 @@ function geqp3_adap!(A::AbstractMatrix{T}, opts::LRAOptions) where T
jpvt = collect(BlasInt, 1:n)
l = min(m, n)
k = (opts.rank < 0 || opts.rank > l) ? l : opts.rank
tau = Array{T}(undef, k)
if k > 0
k = geqp3_adap_main!(A, jpvt, tau, opts)
if LAPACK.version() < v"3.12.0"
tau = Array{T}(undef, k)
k = geqp3_adap_main!(A, jpvt, tau, opts)
else
tau = Array{T}(undef, min(size(A)...))
k = geqp3_adap_lapack!(A, k, jpvt, tau, opts)
end
end
jpvt = convert(Array{Int}, jpvt)
jpvt, tau, k
end

function geqp3_adap_lapack!(A::AbstractMatrix{T}, kmax::BlasInt, jpvt::Vector{BlasInt}, tau::Vector{T}, opts::LRAOptions) where T<:BlasFloat
A, k, tau, jpvt = _LAPACK.geqp3rk!(A, 0, kmax, opts.atol, opts.rtol, jpvt, tau)
return k
end

function geqp3_adap_main!(
A::AbstractMatrix{T}, jpvt::Vector{BlasInt}, tau::Vector{T},
opts::LRAOptions) where T<:BlasFloat
Expand Down
Loading