fix zygote error

CarloLucibello committed Jan 13, 2025
1 parent 2dd14fd commit 306da07
Showing 14 changed files with 104 additions and 69 deletions.
39 changes: 38 additions & 1 deletion GNNGraphs/docs/src/guides/datasets.md
@@ -1,7 +1,44 @@
# Datasets

GNNGraphs.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem. In particular, the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/).
GNNGraphs.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem.

## MLDatasets.jl

Some of the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/).
For graphs with static structures and temporal features, datasets such as METRLA, PEMSBAY, ChickenPox, and WindMillEnergy are available. For graphs featuring both temporal structures and temporal features, the TemporalBrains dataset is suitable.

GraphNeuralNetworks.jl provides the [`mldataset2gnngraph`](@ref) method for interfacing with MLDatasets.jl.
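
As a minimal sketch of that interface (using the Cora citation dataset shipped with MLDatasets.jl):

```julia
using MLDatasets, GNNGraphs

dataset = Cora()                 # citation-network dataset from MLDatasets.jl
g = mldataset2gnngraph(dataset)  # convert to a GNNGraph with node features in g.ndata
```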

## PyGDatasets.jl

The [PyGDatasets.jl](https://github.com/CarloLucibello/PyGDatasets.jl) package makes the datasets from the [pytorch geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html) library available to Julia users.

PyGDatasets' datasets are compatible with GNNGraphs, so no additional conversion is needed.
```julia
julia> using PyGDatasets

julia> dataset = load_dataset("TUDataset", name="MUTAG")
TUDataset(MUTAG) - InMemoryGNNDataset
num_graphs: 188
node_features: [:x]
edge_features: [:edge_attr]
graph_features: [:y]
root: /Users/carlo/.julia/scratchspaces/44f67abd-f36e-4be4-bfe5-65f468a62b3d/datasets/TUDataset

julia> g = dataset[1]
GNNGraph:
num_nodes: 17
num_edges: 38
ndata:
x = 7×17 Matrix{Float32}
edata:
edge_attr = 4×38 Matrix{Float32}
gdata:
y = 1-element Vector{Int64}

julia> using MLUtils: DataLoader

julia> data_loader = DataLoader(dataset, batch_size=32);
```
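
Individual graphs from a dataset can also be combined by hand into a single batched graph. A short sketch, assuming the `MLUtils.batch` method that GNNGraphs.jl defines for vectors of graphs:

```julia
using MLUtils: batch

gs = [dataset[i] for i in 1:32]  # a vector of GNNGraphs
gbatch = batch(gs)               # one GNNGraph storing 32 disjoint graphs
```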

PyGDatasets is based on [PythonCall.jl](https://github.com/JuliaPy/PythonCall.jl), so it brings in some heavy dependencies such as Python, PyTorch, and PyTorch Geometric.
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
@@ -8,7 +8,7 @@ import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
import ChainRulesCore as CRC
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
11 changes: 6 additions & 5 deletions GNNGraphs/src/chainrules.jl
@@ -1,15 +1,16 @@
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
# Remove when merged

function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
ks = map(first, ps)
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps)
project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo∘last, ps)
function Dict_pullback(ȳ)
dy = CRC.unthunk(ȳ)
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent()))
CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
end
return (NoTangent(), dps...)
return (CRC.NoTangent(), dps...)
end
return T(ps...), Dict_pullback
end
2 changes: 1 addition & 1 deletion GNNGraphs/src/convert.jl
@@ -78,7 +78,7 @@ function _findnz_idx(A)
return s, t, nz
end

@non_differentiable _findnz_idx(A)
CRC.@non_differentiable _findnz_idx(A)

function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true)
s, t, nz = _findnz_idx(A)
2 changes: 1 addition & 1 deletion GNNGraphs/src/gnnheterograph/generate.jl
@@ -95,7 +95,7 @@ See [`rand_heterograph`](@ref) for a more general version.
# Examples
```julia-repl
```julia
julia> g = rand_bipartite_heterograph((10, 15), 20)
GNNHeteroGraph:
num_nodes: (:A => 10, :B => 15)
4 changes: 2 additions & 2 deletions GNNGraphs/src/gnnheterograph/gnnheterograph.jl
@@ -273,7 +273,7 @@ end

# TODO this is not correct but Zygote cannot differentiate
# through dictionary generation
# @non_differentiable edge_type_subgraph(::Any...)
# CRC.@non_differentiable edge_type_subgraph(::Any...)

function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
ntypes = Symbol[]
@@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
return ntypes
end

@non_differentiable _ntypes_from_edges(::Any...)
CRC.@non_differentiable _ntypes_from_edges(::Any...)

function Base.getindex(g::GNNHeteroGraph, node_t::NType)
return g.ndata[node_t]
62 changes: 36 additions & 26 deletions GNNGraphs/src/query.jl
@@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
return dir == :out ? A : A'
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
A = adjacency_matrix(g, T; dir, weighted)
if !weighted
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dg = Tangent{G}(; graph = Δ .* binarize(A))
return (NoTangent(), dg, NoTangent())
dy = CRC.unthunk(Δ)
dg = CRC.Tangent{G}(; graph = dy .* binarize(A))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
A = adjacency_matrix(g, T; dir, weighted)
w = get_edge_weight(g)
if !weighted || w === nothing
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dy = CRC.unthunk(Δ)
s, t = edge_index(g)
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
return (NoTangent(), dg, NoTangent())
@show dy s t
#TODO using CRC.@thunk gives an error
#TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
dw = zeros_like(w)
idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
NNlib.gather!(dw, dy, idx)
@show dw
dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
@@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
end

function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
degs = _degree(graph, T, dir, edge_weight, num_nodes)
function _degree_pullback(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> CRC.NoTangent(), 6)
end
return degs, _degree_pullback
end

function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
degs = _degree(A, T, dir, edge_weight, num_nodes)
if edge_weight === false
function _degree_pullback_noweights(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> CRC.NoTangent(), 6)
end
return degs, _degree_pullback_noweights
else
function _degree_pullback_weights(Δ)
dy = CRC.unthunk(Δ)
# We propagate the gradient only to the non-zero elements
# of the adjacency matrix.
bA = binarize(A)
if dir == :in
dA = bA .* Δ'
dA = bA .* dy'
elseif dir == :out
dA = Δ .* bA
dA = dy .* bA
else # dir == :both
dA = Δ .* bA + Δ' .* bA
dA = dy .* bA + dy' .* bA
end
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
end
return degs, _degree_pullback_weights
end
@@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
A = A + I
end
degs = vec(sum(A; dims = 2))
ChainRulesCore.ignore_derivatives() do
CRC.ignore_derivatives() do
@assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`."
end
inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
@@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
end
end

@non_differentiable edge_index(x...)
@non_differentiable adjacency_list(x...)
@non_differentiable graph_indicator(x...)
@non_differentiable has_multi_edges(x...)
@non_differentiable Graphs.has_self_loops(x...)
@non_differentiable is_bidirected(x...)
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable edge_index(x...)
CRC.@non_differentiable adjacency_list(x...)
CRC.@non_differentiable graph_indicator(x...)
CRC.@non_differentiable has_multi_edges(x...)
CRC.@non_differentiable Graphs.has_self_loops(x...)
CRC.@non_differentiable is_bidirected(x...)
CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
12 changes: 6 additions & 6 deletions GNNGraphs/src/transform.jl
@@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
return edgemasks
end

@non_differentiable _unbatch_nodemasks(::Any...)
@non_differentiable _unbatch_edgemasks(::Any...)
CRC.@non_differentiable _unbatch_nodemasks(::Any...)
CRC.@non_differentiable _unbatch_edgemasks(::Any...)

"""
getgraph(g::GNNGraph, i; nmap=false)
@@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
# """
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

@non_differentiable negative_sample(x...)
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable dense_zeros_like(x...)
CRC.@non_differentiable negative_sample(x...)
CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable dense_zeros_like(x...)

"""
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph
12 changes: 6 additions & 6 deletions GNNGraphs/src/utils.jl
@@ -292,9 +292,9 @@ end

binarize(x) = map(>(0), x)

@non_differentiable binarize(x...)
@non_differentiable edge_encoding(x...)
@non_differentiable edge_decoding(x...)
CRC.@non_differentiable binarize(x...)
CRC.@non_differentiable edge_encoding(x...)
CRC.@non_differentiable edge_decoding(x...)

### PRINTING #####

@@ -330,11 +330,11 @@ function dims2string(d)
join(map(string, d), '×')
end

@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
@non_differentiable normalize_graphdata(::Nothing)
CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
CRC.@non_differentiable normalize_graphdata(::Nothing)

iscuarray(x::AbstractArray) = false
@non_differentiable iscuarray(::Any)
CRC.@non_differentiable iscuarray(::Any)


@doc raw"""
3 changes: 1 addition & 2 deletions GNNGraphs/test/Project.toml
@@ -1,5 +1,4 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -9,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -20,7 +20,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
GPUArraysCore = "0.1"
6 changes: 0 additions & 6 deletions GNNlib/Project.toml
@@ -7,7 +7,6 @@ version = "1.0.0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -22,17 +21,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GNNlibAMDGPUExt = "AMDGPU"
GNNlibCUDAExt = "CUDA"

# GPUArraysCore is not needed as a direct dependency
# but pinning it to 0.1 avoids problems when we do Pkg.add("CUDA") in testing
# See https://github.com/JuliaGPU/CUDA.jl/issues/2564

[compat]
AMDGPU = "1"
CUDA = "5"
ChainRulesCore = "1.24"
DataStructures = "0.18"
GNNGraphs = "1.4"
GPUArraysCore = "0.1"
LinearAlgebra = "1"
MLUtils = "0.4"
NNlib = "0.9"
4 changes: 2 additions & 2 deletions GNNlib/src/layers/pool.jl
@@ -29,8 +29,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
n_in = size(x, 1)
qstar = zeros_like(x, (2*n_in, g.num_graphs))
h = zeros_like(l.Wh, size(l.Wh, 2))
c = zeros_like(l.Wh, size(l.Wh, 2))
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
state = (h, c)
for t in 1:l.num_iters
q, state = l.lstm(qstar, state) # [n_in, n_graphs]
9 changes: 3 additions & 6 deletions GraphNeuralNetworks/examples/Project.toml
@@ -1,19 +1,16 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

[compat]
DiffEqFlux = "2"
Flux = "0.13"
GraphNeuralNetworks = "0.6"
Flux = "0.16"
GraphNeuralNetworks = "1"
Graphs = "1"
MLDatasets = "0.7"
julia = "1.9"
julia = "1.10"
5 changes: 1 addition & 4 deletions GraphNeuralNetworks/src/layers/pool.jl
@@ -155,9 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
return Set2Set(lstm, n_iters)
end

function (l::Set2Set)(g, x)
m = (; l.lstm, l.num_iters, Wh = l.lstm.Wh)
return GNNlib.set2set_pool(m, g, x)
end
(l::Set2Set)(g, x) = GNNlib.set2set_pool(l, g, x)

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
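
For context, a hedged sketch of how the simplified `Set2Set` layer is called after this change (the graph size and feature dimension below are made up for illustration):

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)    # random graph with 10 nodes and 30 edges
x = rand(Float32, 4, 10)  # 4 features per node
pool = Set2Set(4, 3)      # n_in = 4, 3 LSTM iterations
y = pool(g, x)            # 8×1 matrix: one 2*n_in summary vector per graph
```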
