From dd88349f0f180a57a5f24a5417a1506201897c8e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 11 Jan 2025 19:07:46 +0100 Subject: [PATCH] fix zygote error --- GNNGraphs/src/GNNGraphs.jl | 2 +- GNNGraphs/src/chainrules.jl | 11 ++-- GNNGraphs/src/convert.jl | 2 +- .../src/gnnheterograph/gnnheterograph.jl | 4 +- GNNGraphs/src/query.jl | 62 +++++++++++-------- GNNGraphs/src/transform.jl | 12 ++-- GNNGraphs/src/utils.jl | 12 ++-- GNNGraphs/test/Project.toml | 3 +- GraphNeuralNetworks/examples/Project.toml | 9 +-- 9 files changed, 62 insertions(+), 55 deletions(-) diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index af451918c..1c4947d45 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -8,7 +8,7 @@ import NearestNeighbors import NNlib import StatsBase import KrylovKit -using ChainRulesCore +import ChainRulesCore as CRC using LinearAlgebra, Random, Statistics import MLUtils using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like diff --git a/GNNGraphs/src/chainrules.jl b/GNNGraphs/src/chainrules.jl index 6ef0b65aa..a3634128d 100644 --- a/GNNGraphs/src/chainrules.jl +++ b/GNNGraphs/src/chainrules.jl @@ -1,15 +1,16 @@ # Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648 # Remove when merged -function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict} +function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict} ks = map(first, ps) - project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps) + project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo ∘ last, ps) function Dict_pullback(ȳ) + dy = CRC.unthunk(ȳ) dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v - dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent())) - Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) + dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent())) + CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) end - return (NoTangent(), dps...) + return (CRC.NoTangent(), dps...) end return T(ps...), Dict_pullback end diff --git a/GNNGraphs/src/convert.jl b/GNNGraphs/src/convert.jl index 3789309cb..ea7ae500f 100644 --- a/GNNGraphs/src/convert.jl +++ b/GNNGraphs/src/convert.jl @@ -78,7 +78,7 @@ function _findnz_idx(A) return s, t, nz end -@non_differentiable _findnz_idx(A) +CRC.@non_differentiable _findnz_idx(A) function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true) s, t, nz = _findnz_idx(A) diff --git a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl index b449b42dd..6594ca55f 100644 --- a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl +++ b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl @@ -273,7 +273,7 @@ end # TODO this is not correct but Zygote cannot differentiate # through dictionary generation -# @non_differentiable edge_type_subgraph(::Any...) +# CRC.@non_differentiable edge_type_subgraph(::Any...) function _ntypes_from_edges(edge_ts::AbstractVector{<:EType}) ntypes = Symbol[] @@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType}) return ntypes end -@non_differentiable _ntypes_from_edges(::Any...) +CRC.@non_differentiable _ntypes_from_edges(::Any...) function Base.getindex(g::GNNHeteroGraph, node_t::NType) return g.ndata[node_t] diff --git a/GNNGraphs/src/query.jl b/GNNGraphs/src/query.jl index 719cbfd17..76bc2dd28 100644 --- a/GNNGraphs/src/query.jl +++ b/GNNGraphs/src/query.jl @@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g return dir == :out ? A : A' end -function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; +function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType; dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}} A = adjacency_matrix(g, T; dir, weighted) if !weighted function adjacency_matrix_pullback_noweight(Δ) - return (NoTangent(), ZeroTangent(), NoTangent()) + return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent()) end return A, adjacency_matrix_pullback_noweight else function adjacency_matrix_pullback_weighted(Δ) - dg = Tangent{G}(; graph = Δ .* binarize(A)) - return (NoTangent(), dg, NoTangent()) + dy = CRC.unthunk(Δ) + dg = CRC.Tangent{G}(; graph = dy .* binarize(dy)) + return (CRC.NoTangent(), dg, CRC.NoTangent()) end return A, adjacency_matrix_pullback_weighted end end -function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; +function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType; dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}} A = adjacency_matrix(g, T; dir, weighted) w = get_edge_weight(g) if !weighted || w === nothing function adjacency_matrix_pullback_noweight(Δ) - return (NoTangent(), ZeroTangent(), NoTangent()) + return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent()) end return A, adjacency_matrix_pullback_noweight else function adjacency_matrix_pullback_weighted(Δ) + dy = CRC.unthunk(Δ) s, t = edge_index(g) - dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t))) - return (NoTangent(), dg, NoTangent()) + @show dy s t + #TODO using CRC.@thunk gives an error + #TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed + dw = zeros_like(w) + idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed + NNlib.gather!(dw, dy, idx) + @show dw + dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw)) + return (CRC.NoTangent(), dg, CRC.NoTangent()) end return A, adjacency_matrix_pullback_weighted end @@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2)) end -function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes) +function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes) degs = _degree(graph, T, dir, edge_weight, num_nodes) function _degree_pullback(Δ) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) + return ntuple(i -> (CRC.NoTangent(),), 6) end return degs, _degree_pullback end -function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes) +function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes) degs = _degree(A, T, dir, edge_weight, num_nodes) if edge_weight === false function _degree_pullback_noweights(Δ) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) + return ntuple(i -> (CRC.NoTangent(),), 6) end return degs, _degree_pullback_noweights else function _degree_pullback_weights(Δ) + dy = CRC.unthunk(Δ) # We propagate the gradient only to the non-zero elements # of the adjacency matrix. bA = binarize(A) if dir == :in - dA = bA .* Δ' + dA = bA .* dy' elseif dir == :out - dA = Δ .* bA + dA = dy .* bA else # dir == :both - dA = Δ .* bA + Δ' .* bA + dA = dy .* bA + dy' .* bA end - return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent()) + return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent()) end return degs, _degree_pullback_weights end @@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32; A = A + I end degs = vec(sum(A; dims = 2)) - ChainRulesCore.ignore_derivatives() do + CRC.ignore_derivatives() do @assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`." end inv_sqrtD = Diagonal(inv.(sqrt.(degs))) @@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32; end end -@non_differentiable edge_index(x...) -@non_differentiable adjacency_list(x...) -@non_differentiable graph_indicator(x...) -@non_differentiable has_multi_edges(x...) -@non_differentiable Graphs.has_self_loops(x...) -@non_differentiable is_bidirected(x...) -@non_differentiable normalized_adjacency(x...) # TODO remove this in the future -@non_differentiable normalized_laplacian(x...) # TODO remove this in the future -@non_differentiable scaled_laplacian(x...) # TODO remove this in the future +CRC.@non_differentiable edge_index(x...) +CRC.@non_differentiable adjacency_list(x...) +CRC.@non_differentiable graph_indicator(x...) +CRC.@non_differentiable has_multi_edges(x...) +CRC.@non_differentiable Graphs.has_self_loops(x...) +CRC.@non_differentiable is_bidirected(x...) +CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future +CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future +CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index dd7ef52f2..ea1fe9583 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes) return edgemasks end -@non_differentiable _unbatch_nodemasks(::Any...) -@non_differentiable _unbatch_edgemasks(::Any...) +CRC.@non_differentiable _unbatch_nodemasks(::Any...) +CRC.@non_differentiable _unbatch_edgemasks(::Any...) """ getgraph(g::GNNGraph, i; nmap=false) @@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz) # """ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims) -@non_differentiable negative_sample(x...) -@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule -@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule -@non_differentiable dense_zeros_like(x...) +CRC.@non_differentiable negative_sample(x...) +CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule +CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule +CRC.@non_differentiable dense_zeros_like(x...) """ ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index a5a8bae67..a6e96a3ab 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -292,9 +292,9 @@ end binarize(x) = map(>(0), x) -@non_differentiable binarize(x...) -@non_differentiable edge_encoding(x...) -@non_differentiable edge_decoding(x...) +CRC.@non_differentiable binarize(x...) +CRC.@non_differentiable edge_encoding(x...) +CRC.@non_differentiable edge_decoding(x...) ### PRINTING ##### @@ -330,11 +330,11 @@ function dims2string(d) join(map(string, d), '×') end -@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}}) -@non_differentiable normalize_graphdata(::Nothing) +CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}}) +CRC.@non_differentiable normalize_graphdata(::Nothing) iscuarray(x::AbstractArray) = false -@non_differentiable iscuarray(::Any) +CRC.@non_differentiable iscuarray(::Any) @doc raw""" diff --git a/GNNGraphs/test/Project.toml b/GNNGraphs/test/Project.toml index d104c951a..f18c35628 100644 --- a/GNNGraphs/test/Project.toml +++ b/GNNGraphs/test/Project.toml @@ -1,5 +1,4 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" @@ -9,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -20,7 +20,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] GPUArraysCore = "0.1" diff --git a/GraphNeuralNetworks/examples/Project.toml b/GraphNeuralNetworks/examples/Project.toml index 2e56578bd..ba7f4ff2a 100644 --- a/GraphNeuralNetworks/examples/Project.toml +++ b/GraphNeuralNetworks/examples/Project.toml @@ -1,6 +1,5 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" @@ -8,12 +7,10 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" [compat] -DiffEqFlux = "2" -Flux = "0.13" -GraphNeuralNetworks = "0.6" +Flux = "0.16" +GraphNeuralNetworks = "1" Graphs = "1" MLDatasets = "0.7" -julia = "1.9" +julia = "1.10"