Support case where "num_layers" of 'subnetwork' > 1 in torch_helpers #148

Open · nvssynthesis opened this issue Oct 26, 2024 · 1 comment
@nvssynthesis

Not sure of the best way to describe this succinctly, but: in PyTorch there are modules with a "num_layers" parameter, e.g. nn.GRU and nn.LSTM. However, the helpers in torch_helpers seem to assume that this parameter is set to 1.

For example, torch_helpers::loadGRU only loads weights and biases whose layer suffix is 0 (weight_ih_l0, bias_hh_l0, etc.). When num_layers > 1, PyTorch increments the suffix for each additional layer (weight_ih_l1, and so on), so those layers' parameters are never picked up. Handling this seems doable on the RTNeural end, unless there are extra complications I'm not aware of.
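
For reference, a minimal sketch (layer sizes are arbitrary) showing how PyTorch names the per-layer parameters:

```python
import torch.nn as nn

# A two-layer GRU; sizes are arbitrary.
gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2)

# Prints weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0,
# then weight_ih_l1, weight_hh_l1, bias_ih_l1, bias_hh_l1.
for name in gru.state_dict():
    print(name)
```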

I believe I can work around this by setting up the network differently in PyTorch, manually stacking single-layer GRU modules if necessary (sketched below). Still, this seems important to support: I spent a few hours thinking everything was 'working' while the model sounded really bad, when in fact a full layer's worth of weights and biases had never been loaded into the RTNeural model.
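
A minimal sketch of that workaround, assuming a small custom wrapper (the StackedGRU name is illustrative; nn.Sequential alone won't chain GRUs, since nn.GRU returns an (output, hidden) tuple). This is equivalent to a multi-layer GRU without inter-layer dropout:

```python
import torch.nn as nn

class StackedGRU(nn.Module):
    """Hypothetical wrapper: n single-layer GRUs instead of one n-layer GRU."""
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.GRU(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        )

    def forward(self, x):
        for gru in self.layers:
            x, _ = gru(x)  # discard each layer's final hidden state
        return x

model = StackedGRU(input_size=8, hidden_size=16, num_layers=2)
# Every sub-GRU's parameters now end in _l0, which loadGRU can read:
print([name for name, _ in model.named_parameters()])
```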

@jatinchowdhury18 (Owner)

Thanks for bringing this up!

It seems like the simplest thing we could do would be to add an "index" parameter to the torch_helpers::loadGRU() and torch_helpers::loadLSTM() methods. Alternatively, it might be possible to create a kind of "wrapper" layer that contains several GRU/LSTM layers internally, to streamline the process of using/loading weights for those layer types, although that's a little outside the scope of what RTNeural has supported up to this point.

I'll have a think about ways to detect errors of that sort in the model-loading process... I'm sure there's more we can do, but I don't think it's realistic to expect RTNeural to catch all such errors.
