diff --git a/src/lapack.jl b/src/lapack.jl index 3689301..d860df5 100644 --- a/src/lapack.jl +++ b/src/lapack.jl @@ -4,7 +4,8 @@ module _LAPACK import LinearAlgebra import LinearAlgebra.BLAS: @blasfunc -using LinearAlgebra: BlasFloat, BlasInt, chkstride1 +using LinearAlgebra: BlasFloat, BlasInt, chkstride1, require_one_based_indexing +using LinearAlgebra.LAPACK: chklapackerror import Base: Nothing const liblapack = LinearAlgebra.BLAS.liblapack @@ -138,4 +139,112 @@ for (laqps, elty, relty) in ((:slaqps_, :Float32, :Float32), end end +for (geqp3rk, elty, relty) in ((:sgeqp3rk_,:Float32,:Float32), + (:dgeqp3rk_,:Float64,:Float64), + (:cgeqp3rk_,:ComplexF32,:Float32), + (:zgeqp3rk_,:ComplexF64,:Float64)) + @eval begin + function geqp3rk!( + A::AbstractMatrix{$elty}, + nrhs::BlasInt, + kmax::BlasInt, + abstol::BlasFloat, + reltol::BlasFloat, + jpvt::AbstractVector{BlasInt}, + tau::AbstractVector{$elty} + ) + require_one_based_indexing(A, jpvt, tau) + chkstride1(A,jpvt,tau) + m,n = size(A) + if length(tau) != min(m,n) + throw(DimensionMismatch(lazy"tau has length $(length(tau)), but needs length $(min(m,n))")) + end + if length(jpvt) != n + throw(DimensionMismatch(lazy"jpvt has length $(length(jpvt)), but needs length $n")) + end + lda = stride(A,2) + if lda == 0 + return A, tau, jpvt + end # Early exit + work = Vector{$elty}(undef, 1) + lwork = BlasInt(-1) + iwork = Vector{BlasInt}(undef, n - 1) + cmplx = eltype(A)<:Complex + if cmplx + rwork = Vector{$relty}(undef, 2n) + end + info = Ref{BlasInt}() + k = Ref{BlasInt}() + maxc2nrmk = Ref{$relty}() + relmaxc2nrmk = Ref{$relty}() + for i = 1:2 # first call returns lwork as work[1] + if cmplx + ccall((@blasfunc($geqp3rk), liblapack), Cvoid, + ( + Ref{BlasInt}, #m + Ref{BlasInt}, #n + Ref{BlasInt}, #nrhs + Ref{BlasInt},#kmax + Ref{$relty}, #abstol + Ref{$relty}, #reltol + Ptr{$elty}, #A + Ref{BlasInt}, #lda + Ptr{BlasInt}, #k + Ptr{$relty}, #maxc2nrmk + Ptr{$relty}, #relmaxc2nrmk + Ptr{BlasInt}, #jpvt -> jpiv + Ptr{$elty}, #tau + Ptr{$elty}, #work + Ref{BlasInt}, #lwork + Ptr{$relty}, #rwork + Ptr{BlasInt}, #iwork + Ptr{BlasInt} #info + ), + m, n, nrhs, + kmax, abstol, reltol, + A, lda, + k, maxc2nrmk, relmaxc2nrmk, + jpvt, tau, work, + lwork, rwork, iwork, info) + else + #println("running ccall") + ccall((@blasfunc($geqp3rk), liblapack), Cvoid, + ( + Ref{BlasInt}, #m + Ref{BlasInt}, #n + Ref{BlasInt}, #nrhs + Ref{BlasInt},#kmax + Ref{$elty}, #abstol + Ref{$elty}, #reltol + Ptr{$elty}, #A + Ref{BlasInt}, #lda + Ref{BlasInt}, #k + Ptr{$elty}, #maxc2nrmk + Ptr{$elty}, #relmaxc2nrmk + Ptr{BlasInt}, #jpvt -> jpiv + Ptr{$elty}, #tau + Ptr{$elty}, #work + Ref{BlasInt}, #lwork + Ptr{BlasInt}, #iwork + Ptr{BlasInt} #info + ), + m, n, nrhs, + kmax, abstol, reltol, + A, lda, + k, maxc2nrmk, relmaxc2nrmk, + jpvt, tau, work, + lwork, iwork, info) + end + chklapackerror(info[]) + if i == 1 + lwork = BlasInt(real(work[1])) + resize!(work, lwork) + end + end + return A, k[], tau, jpvt + end + end +end + + end # module diff --git a/src/pqr.jl b/src/pqr.jl index c5c9673..6d440c4 100644 --- a/src/pqr.jl +++ b/src/pqr.jl @@ -350,14 +350,24 @@ function geqp3_adap!(A::AbstractMatrix{T}, opts::LRAOptions) where T jpvt = collect(BlasInt, 1:n) l = min(m, n) k = (opts.rank < 0 || opts.rank > l) ? l : opts.rank - tau = Array{T}(undef, k) if k > 0 - k = geqp3_adap_main!(A, jpvt, tau, opts) + if LAPACK.version() < v"3.12.0" + tau = Array{T}(undef, k) + k = geqp3_adap_main!(A, jpvt, tau, opts) + else + tau = Array{T}(undef, min(size(A)...)) + k = geqp3_adap_lapack!(A, k, jpvt, tau, opts) + end end jpvt = convert(Array{Int}, jpvt) jpvt, tau, k end +function geqp3_adap_lapack!(A::AbstractMatrix{T}, kmax::BlasInt, jpvt::Vector{BlasInt}, tau::Vector{T}, opts::LRAOptions) where T<:BlasFloat + A, k, tau, jpvt = _LAPACK.geqp3rk!(A, 0, kmax, opts.atol, opts.rtol, jpvt, tau) + return k +end + function geqp3_adap_main!( A::AbstractMatrix{T}, jpvt::Vector{BlasInt}, tau::Vector{T}, opts::LRAOptions) where T<:BlasFloat