Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement vcat(::AbstractBandedMatrix...) #448

Merged
merged 23 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1d9ce24
implement bandwidths for OneElement
max-vassili3v Jul 16, 2024
e173576
Merge branch 'master' into pr/447
dlfivefifty Jul 16, 2024
5b5fdc1
make improvements
max-vassili3v Jul 16, 2024
997495a
Merge branch 'bandwidths-oneelement' of https://github.com/max-vassil…
max-vassili3v Jul 16, 2024
df56798
fix sparse(::SparseMatrixCSC)
max-vassili3v Jul 17, 2024
1adba5e
fix bandwidths for SparseMatrixCSC, add for SparseVector
max-vassili3v Jul 17, 2024
63ddd98
add bandwidths(::Zeros) behaviour for empty sparse structures
max-vassili3v Jul 17, 2024
e31183f
add unit tests
max-vassili3v Jul 17, 2024
c4f6da9
overload vcat(::AbstractBandedMatrix...)
max-vassili3v Jul 18, 2024
226a909
style
max-vassili3v Jul 18, 2024
7f612ce
include tests in runtests.jl
max-vassili3v Jul 18, 2024
4b6ad72
fix issue involving LazyBandedMatrices
max-vassili3v Jul 18, 2024
e64d1ff
fixed mistake
max-vassili3v Jul 18, 2024
2679f45
make improvements
max-vassili3v Jul 19, 2024
d527ac8
Merge branch 'bandwidths-oneelement' into implement_vcat
max-vassili3v Jul 19, 2024
9297cc1
add vcat between BandedMatrices and OneElements
max-vassili3v Jul 19, 2024
0282c46
fix issue involving calculation of bandwidths. Add unit tests for One…
max-vassili3v Jul 19, 2024
174b393
fix issue involving bandwidths larger than dimensions
max-vassili3v Jul 20, 2024
73bf5a2
Merge branch 'master' into implement_vcat
dlfivefifty Jul 23, 2024
40599d4
Merge branch 'master' into pr/448
dlfivefifty Aug 7, 2024
eee5ad9
restore vcat
dlfivefifty Aug 7, 2024
6b02ee3
Merge branch 'JuliaLinearAlgebra:master' into implement_vcat
max-vassili3v Aug 22, 2024
9e8706b
v1.7.4
dlfivefifty Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions ext/BandedMatricesSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,53 @@

using BandedMatrices
using BandedMatrices: _banded_rowval, _banded_colval, _banded_nzval
using SparseArrays
using SparseArrays, FillArrays
import SparseArrays: sparse

function sparse(B::BandedMatrix)
sparse(_banded_rowval(B), _banded_colval(B), _banded_nzval(B), size(B)...)
end

function BandedMatrices.bandwidths(A::SparseMatrixCSC)
l,u = -size(A,1),-size(A,2)

m,n = size(A)
l = u = -max(size(A,1),size(A,2))
n = size(A)[2]

Check warning on line 14 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L13-L14

Added lines #L13 - L14 were not covered by tests
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
rows = rowvals(A)
vals = nonzeros(A)

if isempty(vals)
return bandwidths(Zeros(1))

Check warning on line 19 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
end

for j = 1:n
for ind in nzrange(A, j)
i = rows[ind]
# We skip non-structural zeros when computing the
# bandwidths.
iszero(vals[ind]) && continue
ij = abs(i-j)
if i ≥ j
l = max(l, ij)
u = max(u, -ij)
elseif i < j
l = max(l, -ij)
u = max(u, ij)
end
u = max(u, j-i)
l = max(l, i-j)

Check warning on line 29 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L28-L29

Added lines #L28 - L29 were not covered by tests
end
end

l,u
end

#Treat as n x 1 matrix
function BandedMatrices.bandwidths(A::SparseVector)
l = u = -size(A,1)
rows = rowvals(A)

Check warning on line 39 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L37-L39

Added lines #L37 - L39 were not covered by tests

if isempty(rows)
return bandwidths(Zeros(1))

Check warning on line 42 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end

for i in rows
iszero(i) && continue
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
u = max(u, 1-i)
l = max(l, i-1)
end

Check warning on line 49 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L45-L49

Added lines #L45 - L49 were not covered by tests

l,u

Check warning on line 51 in ext/BandedMatricesSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BandedMatricesSparseArraysExt.jl#L51

Added line #L51 was not covered by tests
end

end
54 changes: 54 additions & 0 deletions src/generic/AbstractBandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,57 @@
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
end
end

###
# vcat
###

function LinearAlgebra.vcat(x::AbstractBandedMatrix...)

Check warning on line 409 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L409

Added line #L409 was not covered by tests
#avoid unnecessary steps for singleton
if length(x) == 1
return x[1]

Check warning on line 412 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L411-L412

Added lines #L411 - L412 were not covered by tests
end

#instantiate the returned banded matrix with zeros and required bandwidths/dimensions
m = size(x[1], 2)
l,u = -m, typemin(Int64)
n = 0
isempty = true

Check warning on line 419 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L416-L419

Added lines #L416 - L419 were not covered by tests

#Check for dimension error and calculate bandwidths
for A in x
if size(A, 2) != m
sizes = Tuple(size(b, 2) for b in x)
throw(DimensionMismatch("number of columns of each matrix must match (got $sizes)"))

Check warning on line 425 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L422-L425

Added lines #L422 - L425 were not covered by tests
end

l_A, u_A = bandwidths(A)
if l_A + u_A >= 0
isempty = false
u = max(u, min(m - 1, u_A) - n)
l = max(l, min(size(A, 1) - 1, l_A) + n)

Check warning on line 432 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L428-L432

Added lines #L428 - L432 were not covered by tests
end

n += size(A, 1)
end

Check warning on line 436 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L435-L436

Added lines #L435 - L436 were not covered by tests

type = promote_type(eltype.(x)...)
if isempty
return BandedMatrix{type}(undef, (n, m), bandwidths(Zeros(1)))

Check warning on line 440 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L438-L440

Added lines #L438 - L440 were not covered by tests
end
ret = BandedMatrix(Zeros{type}(n, m), (l, u))

Check warning on line 442 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L442

Added line #L442 was not covered by tests

#Populate the banded matrix
row_offset = 0
for A in x
n_A = size(A, 1)

Check warning on line 447 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L445-L447

Added lines #L445 - L447 were not covered by tests

for i = 1:n_A, j = rowrange(A, i)
ret[row_offset + i, j] = A[i, j]
end

Check warning on line 451 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L449-L451

Added lines #L449 - L451 were not covered by tests

row_offset += n_A
end

Check warning on line 454 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L453-L454

Added lines #L453 - L454 were not covered by tests

ret

Check warning on line 456 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L456

Added line #L456 was not covered by tests
end
22 changes: 22 additions & 0 deletions src/interfaceimpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@
sublayout(::AbstractTridiagonalLayout, ::Type{<:Tuple{AbstractUnitRange{Int},AbstractUnitRange{Int}}}) =
BandedLayout()

#Implement bandwidths for OneElement structure
function bandwidths(o::OneElement)
pos = FillArrays.nzind(o)
if length(pos) == 1
n = length(o)
if pos[1] > n
bandwidths(Zeros(o))

Check warning on line 65 in src/interfaceimpl.jl

View check run for this annotation

Codecov / codecov/patch

src/interfaceimpl.jl#L60-L65

Added lines #L60 - L65 were not covered by tests
else
(pos[1] - 1, -pos[1] + 1)

Check warning on line 67 in src/interfaceimpl.jl

View check run for this annotation

Codecov / codecov/patch

src/interfaceimpl.jl#L67

Added line #L67 was not covered by tests
end
elseif length(pos) == 2
n,m = size(o)
if pos[1] > n || pos[2] > m
bandwidths(Zeros(o))

Check warning on line 72 in src/interfaceimpl.jl

View check run for this annotation

Codecov / codecov/patch

src/interfaceimpl.jl#L69-L72

Added lines #L69 - L72 were not covered by tests
else
(pos[1]-pos[2],pos[2]-pos[1])

Check warning on line 74 in src/interfaceimpl.jl

View check run for this annotation

Codecov / codecov/patch

src/interfaceimpl.jl#L74

Added line #L74 was not covered by tests
end
end
end

LinearAlgebra.vcat(x::Union{OneElement, AbstractBandedMatrix}...) = vcat(BandedMatrix.(x)...)

Check warning on line 79 in src/interfaceimpl.jl

View check run for this annotation

Codecov / codecov/patch

src/interfaceimpl.jl#L79

Added line #L79 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LinearAlgebra.vcat(x::Union{OneElement, AbstractBandedMatrix}...) = vcat(BandedMatrix.(x)...)
Base.vcat(x::Union{OneElement, AbstractBandedMatrix}...) = vcat(convert.(BandedMatrix, x)...)


###
# rot180
###
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
include("test_cat.jl")
37 changes: 37 additions & 0 deletions test/test_cat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module TestCat

using BandedMatrices, LinearAlgebra, Test, Random, FillArrays, SparseArrays

@testset "vcat" begin
@testset "banded matrices" begin
a = BandedMatrix(0 => 1:2)
@test vcat(a) == a

b = BandedMatrix(0 => 1:3,-1 => 1:2, -2 => 1:1)
@test_throws DimensionMismatch vcat(a,b)

c = BandedMatrix(0 => [1.0, 2.0, 3.0], 1 => [1.0, 2.0], 2 => [1.0])
@test eltype(vcat(b, c)) == Float64
@test vcat(b, c) == vcat(Matrix(b), Matrix(c))

for i in ((1,2), (-3,4), (0,-1))
a = BandedMatrix(ones(Float64, rand(1:10), 5), i)
b = BandedMatrix(ones(Int64, rand(1:10), 5), i)
c = BandedMatrix(ones(Int32, rand(1:10), 5), i)
d = vcat(a, b, c)
sd = vcat(sparse(a), sparse(b), sparse(c))
@test eltype(d) == Float64
@test d == sd
@test bandwidths(d) == bandwidths(sd)
end
end

@testset "one element" begin
n = rand(3:20)
x,y = OneElement(1, (1,1), (1,n)), OneElement(1, (1,n), (1,n))
b = BandedMatrix((0 => ones(n-2), 1 => -2ones(n - 2), 2 => ones(n - 2)), (n-2, n))
@test vcat(x,b,y) == Tridiagonal([ones(n - 2); 0], [1 ; -2ones(n - 2); 1], [0; ones(n - 2)])
end
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
end

end
16 changes: 14 additions & 2 deletions test/test_interface.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module TestInterface

using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test
using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test, Random
import BandedMatrices: isbanded, AbstractBandedLayout, BandedStyle,
BandedColumns, bandeddata
import ArrayLayouts: OnesLayout, UnknownLayout
using InfiniteArrays
using InfiniteArrays, SparseArrays

struct PseudoBandedMatrix{T} <: AbstractMatrix{T}
data::Array{T}
Expand Down Expand Up @@ -310,6 +310,18 @@ end
@test layout_getindex(T,1:10,1:10) isa BandedMatrix
end

@testset "OneElement" begin
o = OneElement(1, 3, 5)
@test bandwidths(o) == (2,-2)
n,m = rand(1:10,2)
o = OneElement(1, (rand(1:n),rand(1:m)), (n, m))
@test bandwidths(o) == bandwidths(sparse(o))
o = OneElement(1, (n+1,m+1), (n, m))
@test bandwidths(o) == bandwidths(Zeros(o))
o = OneElement(1, 6, 5)
@test bandwidths(o) == bandwidths(Zeros(o))
end

@testset "rot180" begin
A = brand(5,5,1,2)
R = rot180(A)
Expand Down
9 changes: 9 additions & 0 deletions test/test_miscs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,17 @@ import BandedMatrices: _BandedMatrix, DefaultBandedMatrix
@test bA isa BandedMatrix
@test bA == A
@test bandwidths(bA) == min.((l,u),9)
v = sparsevec(brand(10, 1, l, u))
@test bandwidths(v) == (l, min(0, u))
end

l, u = -1, 0
A = brand(10, 10, l, u)
sA = sparse(A)
@test bandwidths(sA) == bandwidths(Zeros(1))
v = sparsevec(brand(10, 1, l, u))
@test bandwidths(v) == bandwidths(Zeros(1))

for diags = [(-1 => ones(Int, 5),),
(-2 => ones(Int, 5),),
(2 => ones(Int, 5),),
Expand Down
Loading