Skip to content

Commit

Permalink
better inference for num_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 6, 2025
1 parent 2883214 commit e6d8d48
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
10 changes: 5 additions & 5 deletions GNNGraphs/src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ function GNNGraph(data::D;

# don't force the shape of the data when there is only one graph
gdata = normalize_graphdata(gdata, default_name = :u,
n = num_graphs > 1 ? num_graphs : -1)
n = num_graphs > 1 ? num_graphs : -1, glob=true)

GNNGraph(graph,
num_nodes, num_edges, num_graphs,
Expand Down Expand Up @@ -201,10 +201,10 @@ end

function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata,
graph_type = nothing)
ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes)
edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs)
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges,
duplicate_if_needed=true)
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs, glob=true)

if !isnothing(graph_type)
if graph_type == :coo
Expand Down
2 changes: 1 addition & 1 deletion GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function GNNHeteroGraph(data::EDict;
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u,
n = num_graphs > 1 ? num_graphs : -1)
n = num_graphs > 1 ? num_graphs : -1, glob=true)
end

return GNNHeteroGraph(graph,
Expand Down
6 changes: 3 additions & 3 deletions GNNGraphs/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ function normalize_graphdata(data; default_name::Symbol, kws...)
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
end

function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false)
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false, glob=false)
# This had to workaround two Zygote bugs with NamedTuples
# https://github.com/FluxML/Zygote.jl/issues/1071
# https://github.com/FluxML/Zygote.jl/issues/1072
# https://github.com/FluxML/Zygote.jl/issues/1072 # TODO this is fixed

if n > 1
@assert all(x -> x isa AbstractArray, data) "Non-array features provided."
end

if n <= 1
if n <= 1 && glob
# If last array dimension is not 1, add a new dimension.
# This is mostly useful to reshape global feature vectors
# of size D to Dx1 matrices.
Expand Down
28 changes: 28 additions & 0 deletions GNNGraphs/test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ end
end
end

@testitem "Constructor: empty" setup=[GraphsTestModule] begin
g = GNNGraph(ndata=ones(2, 1))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.x == ones(2, 1)

g = GNNGraph(num_nodes=1)
@test g.num_nodes == 1
@test g.num_edges == 0
@test isempty(g.ndata)

g = GNNGraph((Int[], Int[]); ndata=(;a=[1]))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.a == [1]

g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1)
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.a == [1]
@test g.edata.b == Int[]

g = GNNGraph(; edata=(;b=Int[]))
@test g.num_nodes == 0
@test g.num_edges == 0
@test g.edata.b == Int[]
end

@testitem "symmetric graph" setup=[GraphsTestModule] tags=[:gpu] begin
using .GraphsTestModule
dev = gpu_device()
Expand Down

0 comments on commit e6d8d48

Please sign in to comment.