Skip to content

Commit

Permalink
extended GatedGraphConv and NNConv to use AbstactGNNGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghaithq committed Feb 11, 2024
1 parent 983669e commit 9d741a8
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ end
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@non_differentiable fill!(x...)

function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
function (l::GatedGraphConv)(g::AbstractGNNGraph, H::AbstractMatrix{S}) where {S <: Real}
check_num_nodes(g, H)
m, n = size(H)
@assert (m<=l.out_ch) "number of input features must less or equals to output features."
Expand Down Expand Up @@ -739,11 +739,11 @@ function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true,
return NNConv(W, b, nn, σ, aggr)
end

function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
function (l::NNConv)(g::AbstractGNNGraph, x, e)
check_num_nodes(g, x)

m = propagate(message, g, l.aggr, l, xj = x, e = e)
return l.σ.(l.weight * x .+ m .+ l.bias)
xj, xi = expand_srcdst(g, x)
m = propagate(message, g, l.aggr, l, xj = xj, e = e)
return l.σ.(l.weight * xi .+ m .+ l.bias)
end

function message(l::NNConv, xi, xj, e)
Expand Down

0 comments on commit 9d741a8

Please sign in to comment.