Generating and Feeding a Custom Dataset into T-GCN #583
This is an example with fixed graph topology and time-varying node features:

julia> using GraphNeuralNetworks
julia> num_nodes, num_edges = 5, 10;
julia> d_in, d_out = 2, 3;
julia> timesteps = 5;
julia> g = rand_graph(num_nodes, num_edges);
julia> x = rand(Float32, d_in, timesteps, num_nodes);
julia> layer = TGCN(d_in => d_out)
GNNRecurrence(
  TGCNCell(2 => 3), # 126 parameters
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
julia> y = layer(g, x);
julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
julia> all(isfinite, y)
true

Normally the layer would be part of a larger model, with some node embedding/projection at the beginning and a classification/regression head at the end. The fact that you observe NaNs at initialization is very weird. Can you provide an example?
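For illustration, here is a minimal sketch of such a larger model, following the same pattern as the training example later in this thread (the hidden size of 16 and output size of 1 are arbitrary choices, not from the original discussion):

using GraphNeuralNetworks, Flux

d_in, d_hidden, d_out = 2, 16, 1                  # sizes chosen only for illustration
model = GNNChain(Dense(d_in => d_hidden, relu),   # node projection / embedding
                 TGCN(d_hidden => d_hidden),      # temporal graph convolution
                 x -> relu.(x),
                 Dense(d_hidden => d_out))        # regression head
g = rand_graph(5, 10)
x = rand(Float32, d_in, 4, 5)                     # (features, timesteps, nodes)
y = model(g, x)                                   # size (d_out, 4, 5)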
Actually, I want to generate a dataset with the same structure as the following code:
In this case, the shapes of the features and targets are as follows:
For example, in the METRLA dataset:
To generate a random dataset with the same structure, I used the following function:
However, despite following this structure, I am still encountering NaN values in the model output. Could you confirm whether this dataset structure is correctly aligned with what T-GCN expects? Additionally, are there any specific pre-processing steps or normalizations required to avoid NaN values?
Let me first say that in GNN.jl v1.0 the temporal convolutions have been changed, and the temporal dimension is now the second-to-last dimension. The tutorials have been updated only recently, so one has to look at the latest version. For the dataset, I would create a custom dataset type. Also, I don't see NaNs in the following training code:

using GraphNeuralNetworks, Flux
struct Dataset{D}
    data::D
    length::Int
    num_timesteps::Int
end
function Dataset(; num_features, num_nodes, total_length, num_timesteps)
    data = rand(Float32, num_features, total_length, num_nodes)
    # each sample needs num_timesteps input steps plus one more step for the target
    length = total_length - num_timesteps
    return Dataset(data, length, num_timesteps)
end
Base.length(d::Dataset) = d.length
function Base.getindex(d::Dataset, i::Int)
    (1 <= i <= length(d)) || throw(ArgumentError("Index out of bounds."))
    x = d.data[:, i:i+d.num_timesteps-1, :]
    # the target is the first feature at next time step
    y = d.data[1, i+1:i+d.num_timesteps, :]
    return x, y
end
Base.getindex(d::Dataset, is::AbstractVector) = [d[i] for i in is]
num_timesteps = 2
num_nodes = 3
num_features = 2
d = Dataset(; num_features, num_nodes, total_length=1000, num_timesteps)
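# Quick shape check (not in the original snippet): getindex should return
# x of size (num_features, num_timesteps, num_nodes) and y of size (num_timesteps, num_nodes)
x1, y1 = d[1]
@assert size(x1) == (num_features, num_timesteps, num_nodes)
@assert size(y1) == (num_timesteps, num_nodes)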
g = rand_graph(num_nodes, 6)
train_data = d[1:100]
val_data = d[101+num_timesteps:200]
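# batchsize=-1 makes the DataLoader yield individual (x, y) observations rather than mini-batches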
train_loader = DataLoader(train_data, shuffle=true, batchsize=-1)
val_loader = DataLoader(val_data, batchsize=-1)
model = GNNChain(Dense(num_features => 64, relu),
                 TGCN(64 => 64),
                 x -> relu.(x),
                 TGCN(64 => 64),
                 x -> relu.(x),
                 Dense(64 => 1),
                 flatten)
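# Shape flow for one sample x of size (num_features, num_timesteps, num_nodes):
#   Dense(num_features => 64)  -> (64, num_timesteps, num_nodes)
#   TGCN(64 => 64) blocks      -> (64, num_timesteps, num_nodes)
#   Dense(64 => 1)             -> (1, num_timesteps, num_nodes)
#   flatten                    -> (num_timesteps, num_nodes), matching the target y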
opt_state = Flux.setup(Flux.AdamW(0.001), model)
for epoch in 1:10
    for (x, y) in train_loader
        # differentiate with respect to the model: use `m` (not the outer `model`) inside the closure
        grads = Flux.gradient(m -> Flux.mse(m(g, x), y), model)
        Flux.update!(opt_state, model, grads[1])
    end
end
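Regarding the normalization question above: the snippet trains on unnormalized random data, but with real data (e.g. traffic speeds) a common pre-processing step is per-feature z-score normalization, with statistics computed on the training portion only. A minimal sketch under those assumptions (variable names are illustrative):

using Statistics

# per-feature mean and std over time and nodes, computed on the training slice only
train_slice = d.data[:, 1:100, :]
mu = mean(train_slice, dims=(2, 3))
sigma = std(train_slice, dims=(2, 3))
normalized_data = (d.data .- mu) ./ (sigma .+ eps(Float32))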
Hello,
Thank you for this implementation of T-GCN! I am trying to generate my own random dataset and feed it into the TGCN network instead of the METRLA dataset used in the traffic_prediction.jl example. However, I am running into an issue where the model's output is NaN, both before and after training.
Could you please provide guidance on how to properly format a custom dataset and ensure it works correctly with the model?
Specifically, any function, advice, or example code for using a different dataset would be greatly appreciated.
Thanks in advance!