Pullback for sparse-array vector product very inefficient #803

Open
oschulz opened this issue Jul 17, 2024 · 4 comments
@oschulz

oschulz commented Jul 17, 2024

Our current `rrule` for sparse matrix-vector products is very inefficient and causes out-of-memory errors with large sparse CPU or GPU arrays. There is no sparse-specific method: `rrule(*, sparse(A), x)` dispatches to the generic method, which is implemented like this:

```julia
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    ...
        # Ȳ * B' materializes the full outer product before projecting
        dA = @thunk(project_A(Ȳ * B'))
    ...
end
```

So we first compute a dense `Ȳ * B'` (which can easily exceed available memory if `A` is very large but very sparse) and only then project it back to a sparse tangent.
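To make the cost concrete, here is a minimal sketch using only the SparseArrays and LinearAlgebra stdlibs. It emulates the current path for `y = A * x` with a sparse `A`: first the dense outer product `ȳ * x'` is built, then only the entries at `A`'s stored positions are kept. `project_dense_to_pattern` is a hypothetical stand-in for `project_A`, not ChainRulesCore's actual `ProjectTo` code.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical stand-in for project_A on a SparseMatrixCSC: keep only the
# entries of `dense` that sit at stored positions of `A`.
function project_dense_to_pattern(A::SparseMatrixCSC, dense::AbstractMatrix)
    rows, cols, _ = findnz(A)
    vals = [dense[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(A)...)
end

n = 2_000
A = spdiagm(0 => ones(n))   # n×n with only n stored entries
x = rand(n)
ȳ = rand(n)                 # cotangent of y = A * x

dense = ȳ * x'                            # current path: allocates all n^2 entries...
dA = project_dense_to_pattern(A, dense)   # ...then keeps only n of them

println("entries allocated: ", length(dense))
println("entries kept:      ", nnz(dA))
```

The temporary grows with `size(A, 1) * size(A, 2)`, while the result only needs `nnz(A)` values.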

The best way to fix this (at least if `Ȳ` and `B` are vectors) might be to add a dedicated read-only "vector outer product" array type for vector * adjoint-vector products that computes `getindex` on the fly (this might be useful in general). Or do we already have something like that somewhere?
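A minimal sketch of such a type, assuming only the standard `AbstractArray` interface (`size` plus `getindex`). The name `LazyOuter` is hypothetical; nothing like it is claimed to exist in ChainRulesCore. Each entry of `u * v'` is computed on demand, so reading only a sparse matrix's stored positions costs `nnz(A)` multiplications and no m×n allocation.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical read-only lazy outer product representing u * v'.
# Entries are computed on demand in getindex; no m×n array is stored.
struct LazyOuter{T,U<:AbstractVector,V<:AbstractVector} <: AbstractMatrix{T}
    u::U
    v::V
end

# Element type from promoting the element types of u and v
# (sufficient for ordinary real/complex numbers).
LazyOuter(u::AbstractVector, v::AbstractVector) =
    LazyOuter{promote_type(eltype(u), eltype(v)),typeof(u),typeof(v)}(u, v)

Base.size(L::LazyOuter) = (length(L.u), length(L.v))

# (u * v')[i, j] == u[i] * conj(v[j]), because v' is the adjoint of v.
Base.getindex(L::LazyOuter, i::Int, j::Int) = L.u[i] * conj(L.v[j])

u = [1.0, 2.0, 3.0]
v = [4.0, 5.0]
L = LazyOuter(u, v)

# Reading only the stored positions of a sparse A touches nnz(A) entries of L.
A = sparse([1, 3], [2, 1], [1.0, 1.0], 3, 2)
rows, cols, _ = findnz(A)
vals = [L[i, j] for (i, j) in zip(rows, cols)]   # [12.0, 5.0]
```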

@mcabbott
Member

Alternatively, there could be a more specialised rrule(::*, ::SparseMatrix, ...) which knows about this.

Note that all things sparse are very crude! We rushed to include the semantics someone said they wanted in the 1.0 release. It would probably have been better to leave them as errors, until someone who cared could take on the task.

https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2f2c941712f9e2cd11f476666a63dc462ed6440a/ext/ChainRulesCoreSparseArraysExt.jl#L10-L12
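A specialised method along these lines could compute the `A` tangent directly on `A`'s sparsity pattern. The sketch below shows only that tangent computation, selected by ordinary dispatch on `SparseMatrixCSC`; the `rrule`/`@thunk` wiring is omitted so it runs with just the stdlibs, and `matvec_pullback_dA` is a hypothetical name. It also only covers `SparseMatrixCSC`, not GPU sparse types.

```julia
using SparseArrays, LinearAlgebra

# Generic fallback (hypothetical): the tangent for A in y = A * x is ȳ * x'.
matvec_pullback_dA(A::AbstractMatrix, ȳ::AbstractVector, x::AbstractVector) = ȳ * x'

# SparseMatrixCSC method: compute ȳ[i] * conj(x[j]) only at A's stored
# entries, reusing A's column pointers and row indices for the result.
function matvec_pullback_dA(A::SparseMatrixCSC, ȳ::AbstractVector, x::AbstractVector)
    rows = rowvals(A)
    T = promote_type(eltype(ȳ), eltype(x))
    vals = Vector{T}(undef, nnz(A))
    for j in axes(A, 2)
        xj = conj(x[j])               # B' is an adjoint, hence conj
        for k in nzrange(A, j)        # positions of column j's stored entries
            vals[k] = ȳ[rows[k]] * xj
        end
    end
    return SparseMatrixCSC(size(A, 1), size(A, 2), copy(A.colptr), rows[1:nnz(A)], vals)
end

A = sparse([1, 3, 2], [1, 2, 3], [1.0, 2.0, 3.0], 3, 3)
ȳ = [1.0, 2.0, 3.0]
x = [4.0, 5.0, 6.0]
dA = matvec_pullback_dA(A, ȳ, x)
```

In a real `rrule` this would replace `project_A(Ȳ * B')` inside the thunk; the dense fallback keeps the current behavior for non-sparse `A`.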

@oschulz
Author

oschulz commented Oct 16, 2024

> Alternatively, there could be a more specialised rrule(::*, ::SparseMatrix, ...) which knows about this.

A dispatch-based solution at the `rrule` level could get tricky, though, since there are also various GPU sparse matrix types, right?

@mcabbott
Member

Yeah, I don't know! Do any such types currently work? I'd be pretty surprised if `ProjectTo` could digest this `LazyOuterProductMatrix{CuVector, ...}`. (The code is, again, pretty crude, and was written in a rush by someone who knew little about anything sparse.) Could the `rrule` dispatch on `AbstractSparseMatrix`?

@oschulz
Author

oschulz commented Oct 16, 2024

I did a quick test: `LowRankMatrices.LowRankMatrix(Ȳ, B)` seems to be very performant, so I think this could already do the trick:

```julia
dA = @thunk(project_A(LowRankMatrix(Ȳ, B)))
```

(This would make LowRankMatrices a dependency of ChainRulesCore, but it's very lightweight, if that is acceptable.)

`ProjectTo` is specialized for `SparseMatrixCSC` and uses `getindex` to read the elements of its input, which seems to be very efficient for `LowRankMatrix`. And there shouldn't be any unnecessary memory allocation.
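A small check of why this access pattern matters, under the assumption stated above that the sparse projection reads its input via `getindex` only at stored positions. `project_via_getindex` is a hypothetical stand-in for that behavior (not ChainRulesCore's actual code), and `CountingMatrix` is a hypothetical wrapper that counts `getindex` calls. With a lazy input such as a low-rank matrix, each of those calls would compute one entry on the fly.

```julia
using SparseArrays

# Hypothetical wrapper that counts how often getindex is called on it.
mutable struct CountingMatrix{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T}
    data::M
    calls::Int
end
CountingMatrix(M::AbstractMatrix{T}) where {T} = CountingMatrix{T,typeof(M)}(M, 0)

Base.size(C::CountingMatrix) = size(C.data)
function Base.getindex(C::CountingMatrix, i::Int, j::Int)
    C.calls += 1
    return C.data[i, j]
end

# Stand-in for projecting onto a SparseMatrixCSC: read the input only at
# the stored positions of A.
function project_via_getindex(A::SparseMatrixCSC, dx::AbstractMatrix)
    rows, cols, _ = findnz(A)
    vals = [dx[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(A)...)
end

A = sprand(200, 200, 0.01)
C = CountingMatrix(rand(200, 200))
dA = project_via_getindex(A, C)
println(C.calls, " getindex calls for a ", size(C, 1), "×", size(C, 2), " input")
```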

This might be interesting in general, also for dense matrices, though I guess that would be a pretty big change in current behavior. Alternatively, the pullback implementation could perhaps dispatch on `ArrayInterface.issparse` or something similar?
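A minimal sketch of trait-based selection, assuming a Boolean sparsity trait is available. It uses `SparseArrays.issparse` from the stdlib as a stand-in for `ArrayInterface.issparse` so it runs without extra packages; `matvec_dA` and `_matvec_dA` are hypothetical names. The sparse branch assumes `findnz` works for `A`, which holds for `SparseMatrixCSC` but would need methods for other (e.g. GPU) sparse types.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical helper for the A tangent of y = A * x: choose the
# representation from a trait value rather than from A's concrete type.
matvec_dA(A, ȳ, x) = _matvec_dA(Val(issparse(A)), A, ȳ, x)

# Dense inputs: current behavior, materialize the full outer product.
_matvec_dA(::Val{false}, A, ȳ, x) = ȳ * x'

# Sparse-like inputs: compute ȳ[i] * conj(x[j]) only at A's stored entries.
function _matvec_dA(::Val{true}, A, ȳ, x)
    rows, cols, _ = findnz(A)
    vals = [ȳ[i] * conj(x[j]) for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(A)...)
end

A = sparse([1, 2], [2, 1], [1.0, 1.0], 2, 2)
ȳ = [1.0, 2.0]
x = [3.0, 4.0]
dS = matvec_dA(A, ȳ, x)          # sparse result, 2 stored entries
dD = matvec_dA(Matrix(A), ȳ, x)  # dense result, same as ȳ * x'
```

A value-based trait like this keeps the dense default untouched while letting any array type opt in by reporting itself as sparse.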
