From 4ff8ff00dc8b19943a531debaedb9d2650b41b81 Mon Sep 17 00:00:00 2001 From: Maxim Vassiliev <76599693+max-vassili3v@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:06:56 +0100 Subject: [PATCH] Implement vcat(::AbstractBandedMatrix...) (#448) * implement bandwidths for OneElement * make improvements * fix sparse(::SparseMatrixCSC) * fix bandwidths for SparseMatrixCSC, add for SparseVector * add bandwidths(::Zeros) behaviour for empty sparse structures * add unit tests * overload vcat(::AbstractBandedMatrix...) * style * include tests in runtests.jl * fix issue involving LazyBandedMatrices * fixed mistake * make improvements * add vcat between BandedMatrices and OneElements * fix issue involving calculation of bandwidths. Add unit tests for OneElement * fix issue involving bandwidths larger than dimensions * restore vcat * v1.7.4 --------- Co-authored-by: Sheehan Olver --- src/BandedMatrices.jl | 2 +- src/generic/AbstractBandedMatrix.jl | 54 +++++++++++++++++++++++++++++ src/interfaceimpl.jl | 3 ++ test/runtests.jl | 1 + test/test_cat.jl | 37 ++++++++++++++++++++ 5 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 test/test_cat.jl diff --git a/src/BandedMatrices.jl b/src/BandedMatrices.jl index 33900c67..8009d326 100644 --- a/src/BandedMatrices.jl +++ b/src/BandedMatrices.jl @@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!, QRPackedQLayout, AdjQRPackedQLayout -import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector +import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector, ZerosMatrix, ZerosVector const libblas = LinearAlgebra.BLAS.libblas const liblapack = LinearAlgebra.BLAS.liblapack diff --git a/src/generic/AbstractBandedMatrix.jl b/src/generic/AbstractBandedMatrix.jl index c5653c26..7cd5247d 100644 --- a/src/generic/AbstractBandedMatrix.jl +++ b/src/generic/AbstractBandedMatrix.jl @@ -401,3 +401,57 @@ function sum(A::AbstractBandedMatrix; dims=:) throw(ArgumentError("dimension must be ≥ 1, got $dims")) end end + +### +# vcat +### + +function LinearAlgebra.vcat(x::AbstractBandedMatrix...) + #avoid unnecessary steps for singleton + if length(x) == 1 + return x[1] + end + + #instantiate the returned banded matrix with zeros and required bandwidths/dimensions + m = size(x[1], 2) + l,u = -m, typemin(Int64) + n = 0 + isempty = true + + #Check for dimension error and calculate bandwidths + for A in x + if size(A, 2) != m + sizes = Tuple(size(b, 2) for b in x) + throw(DimensionMismatch("number of columns of each matrix must match (got $sizes)")) + end + + l_A, u_A = bandwidths(A) + if l_A + u_A >= 0 + isempty = false + u = max(u, min(m - 1, u_A) - n) + l = max(l, min(size(A, 1) - 1, l_A) + n) + end + + n += size(A, 1) + end + + type = promote_type(eltype.(x)...) + if isempty + return BandedMatrix{type}(undef, (n, m), bandwidths(Zeros(1))) + end + ret = BandedMatrix(Zeros{type}(n, m), (l, u)) + + #Populate the banded matrix + row_offset = 0 + for A in x + n_A = size(A, 1) + + for i = 1:n_A, j = rowrange(A, i) + ret[row_offset + i, j] = A[i, j] + end + + row_offset += n_A + end + + ret +end diff --git a/src/interfaceimpl.jl b/src/interfaceimpl.jl index 789588a9..e8dfec5a 100644 --- a/src/interfaceimpl.jl +++ b/src/interfaceimpl.jl @@ -116,3 +116,6 @@ function getindex(D::Bidiagonal{T,V}, b::Band) where {T,V} D.uplo == 'U' && b.i == 1 && return copy(D.ev) convert(V, Zeros{T}(size(D,1)-abs(b.i))) end + + +Base.vcat(x::Union{OneElement, ZerosMatrix, AdjOrTrans{<:Any,<:ZerosVector}, AbstractBandedMatrix}...) = vcat(BandedMatrix.(x)...) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7778453b..230c4835 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,3 +27,4 @@ include("test_tribanded.jl") include("test_interface.jl") include("test_miscs.jl") include("test_sum.jl") +include("test_cat.jl") diff --git a/test/test_cat.jl b/test/test_cat.jl new file mode 100644 index 00000000..50a42b73 --- /dev/null +++ b/test/test_cat.jl @@ -0,0 +1,37 @@ +module TestCat + +using BandedMatrices, LinearAlgebra, Test, Random, FillArrays, SparseArrays + +@testset "vcat" begin + @testset "banded matrices" begin + a = BandedMatrix(0 => 1:2) + @test vcat(a) == a + + b = BandedMatrix(0 => 1:3,-1 => 1:2, -2 => 1:1) + @test_throws DimensionMismatch vcat(a,b) + + c = BandedMatrix(0 => [1.0, 2.0, 3.0], 1 => [1.0, 2.0], 2 => [1.0]) + @test eltype(vcat(b, c)) == Float64 + @test vcat(b, c) == vcat(Matrix(b), Matrix(c)) + + for i in ((1,2), (-3,4), (0,-1)) + a = BandedMatrix(ones(Float64, rand(1:10), 5), i) + b = BandedMatrix(ones(Int64, rand(1:10), 5), i) + c = BandedMatrix(ones(Int32, rand(1:10), 5), i) + d = vcat(a, b, c) + sd = vcat(sparse(a), sparse(b), sparse(c)) + @test eltype(d) == Float64 + @test d == sd + @test bandwidths(d) == bandwidths(sd) + end + end + + @testset "one element" begin + n = rand(3:20) + x,y = OneElement(1, (1,1), (1,n)), OneElement(1, (1,n), (1,n)) + b = BandedMatrix((0 => ones(n-2), 1 => -2ones(n - 2), 2 => ones(n - 2)), (n-2, n)) + @test vcat(x,b,y) == Tridiagonal([ones(n - 2); 0], [1 ; -2ones(n - 2); 1], [0; ones(n - 2)]) + end +end + +end \ No newline at end of file