Skip to content

Commit

Permalink
[GNNLux] Add pooling layers (#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi authored Jan 10, 2025
1 parent 24eae1f commit 2dd14fd
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 13 deletions.
4 changes: 3 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ export TGCN,
EvolveGCNO

include("layers/pool.jl")
export GlobalPool
export GlobalPool,
GlobalAttentionPool,
TopKPool

end #module

93 changes: 92 additions & 1 deletion GNNLux/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,95 @@ end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

@doc raw"""
GlobalAttentionPool(fgate, ffeat=identity)
Global soft attention layer from the [Gated Graph Sequence Neural
Networks](https://arxiv.org/abs/1511.05493) paper
```math
\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
```
where the coefficients ``\alpha_i`` are given by a [`GNNLib.softmax_nodes`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNlib.jl/stable/api/utils/#GNNlib.softmax_nodes)
operation:
```math
\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
{\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
```
# Arguments
- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
It is typically expressed by a neural network.
- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
It is typically expressed by a neural network.
# Examples
```julia
using Graphs, LuxCore, Lux, GNNLux, Random
rng = Random.default_rng()
chin = 6
chout = 5
fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
pool = GlobalAttentionPool(fgate, ffeat)
g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
ndata=rand(Float32, chin, 10))
for i=1:3])
ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))
u, st = pool(g, g.ndata.x, ps, st)
@assert size(u) == (chout, g.num_graphs)
```
"""
@concrete struct GlobalAttentionPool <: GNNContainerLayer{(:fgate, :ffeat)}
fgate
ffeat
end

GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

function (l::GlobalAttentionPool)(g, x, ps, st)
fgate = StatefulLuxLayer{true}(l.fgate, ps.fgate, _getstate(st, :fgate))
ffeat = StatefulLuxLayer{true}(l.ffeat, ps.ffeat, _getstate(st, :ffeat))
m = (; fgate, ffeat)
return GNNlib.global_attention_pool(m, g, x), st
end

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

"""
TopKPool(adj, k, in_channel)
Top-k pooling layer.
# Arguments
- `adj`: Adjacency matrix of a graph.
- `k`: Top-k nodes are selected to pool together.
- `in_channel`: The dimension of input channel.
"""
struct TopKPool{T, S}
A::AbstractMatrix{T}
k::Int
p::AbstractVector{S}
::AbstractMatrix{T}
end

function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init = glorot_uniform)
TopKPool(adj, k, init(in_channel), similar(adj, k, k))
end

(t::TopKPool)(x::AbstractArray, ps, st) = GNNlib.topk_pool(t, x), st
43 changes: 37 additions & 6 deletions GNNLux/test/layers/pool.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
@testitem "Pooling" setup=[TestModuleLux] begin
using .TestModuleLux
@testset "GlobalPool" begin
@testset "Pooling" begin

rng = StableRNG(1234)
g = rand_graph(rng, 10, 40)
in_dims = 3
x = randn(rng, Float32, in_dims, 10)

@testset "GCNConv" begin
@testset "GlobalPool" begin
g = rand_graph(rng, 10, 40)
in_dims = 3
x = randn(rng, Float32, in_dims, 10)
l = GlobalPool(mean)
test_lux_layer(rng, l, g, x, sizey=(in_dims,1))
end
@testset "GlobalAttentionPool" begin
n = 10
chin = 6
chout = 5
ng = 3
g = batch([GNNGraph(rand_graph(rng, 10, 40),
ndata = rand(Float32, chin, n)) for i in 1:ng])

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
l = GlobalAttentionPool(fgate, ffeat)

test_lux_layer(rng, l, g, g.ndata.x, sizey=(chout,ng), container=true)
end

@testset "TopKPool" begin
N = 10
k, in_channel = 4, 7
X = rand(in_channel, N)
ps = (;)
st = (;)
for T in [Bool, Float64]
adj = rand(T, N, N)
p = GNNLux.TopKPool(adj, k, in_channel)
@test eltype(p.p) === Float32
@test size(p.p) == (in_channel,)
@test eltype(p.Ã) === T
@test size(p.Ã) == (k, k)
y, st = p(X, ps, st)
@test size(y) == (in_channel, k)
end
end
end
end
4 changes: 2 additions & 2 deletions GNNlib/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
n_in = size(x, 1)
qstar = zeros_like(x, (2*n_in, g.num_graphs))
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
h = zeros_like(l.Wh, size(l.Wh, 2))
c = zeros_like(l.Wh, size(l.Wh, 2))
state = (h, c)
for t in 1:l.num_iters
q, state = l.lstm(qstar, state) # [n_in, n_graphs]
Expand Down
7 changes: 4 additions & 3 deletions GraphNeuralNetworks/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ operation:
# Arguments
- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
It is tipically expressed by a neural network.
It is typically expressed by a neural network.
- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
It is tipically expressed by a neural network.
It is typically expressed by a neural network.
# Examples
Expand Down Expand Up @@ -156,7 +156,8 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
end

function (l::Set2Set)(g, x)
return GNNlib.set2set_pool(l, g, x)
m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
return GNNlib.set2set_pool(m, g, x)
end

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))

0 comments on commit 2dd14fd

Please sign in to comment.