Sequential.eval() does not put model into eval mode #1426

Open
brianberns opened this issue Dec 19, 2024 · 2 comments

brianberns commented Dec 19, 2024

Calling eval() should make training false. However, this does not work for Sequential modules.

Example F# program:

open type TorchSharp.torch.nn

let linear = Linear(10, 10)
linear.eval()
assert(not linear.training)       // succeeds

let sequential = Sequential(Linear(10, 10))
sequential.eval()
assert(not sequential.training)   // fails

I think the problem is that Sequential.train() should call base.train() in addition to calling train() for each submodule:

public override void train(bool on = true)
{
    foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
    base.train(on);
}
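The suspected cause can be reproduced without TorchSharp. The following is only a minimal sketch: `ToyModule`, `BuggyContainer`, and `FixedContainer` are hypothetical stand-ins invented for illustration, not TorchSharp types, and the sketch assumes the base `train()` is what sets the `training` flag, as the proposed fix implies.

```csharp
using System;
using System.Collections.Generic;

// Demo: eval() on the buggy container leaves its own flag untouched.
var buggy = new BuggyContainer(new ToyModule());
buggy.eval();
Console.WriteLine($"buggy.training = {buggy.training}");   // buggy.training = True

var repaired = new FixedContainer(new ToyModule());
repaired.eval();
Console.WriteLine($"fixed.training = {repaired.training}"); // fixed.training = False

// Hypothetical stand-in for a module base class: train() sets the flag,
// and eval() is simply train(false).
class ToyModule
{
    public bool training { get; private set; } = true;
    public virtual void train(bool on = true) { training = on; }
    public void eval() => train(false);
}

// Mirrors the reported behavior: the override forwards to the children
// but never calls base.train(), so the container's own flag stays true.
class BuggyContainer : ToyModule
{
    public readonly List<ToyModule> Children = new();
    public BuggyContainer(params ToyModule[] modules) => Children.AddRange(modules);

    public override void train(bool on = true)
    {
        foreach (var m in Children) m.train(on);
    }
}

// Same override with the proposed fix: also update the container itself.
class FixedContainer : ToyModule
{
    public readonly List<ToyModule> Children = new();
    public FixedContainer(params ToyModule[] modules) => Children.AddRange(modules);

    public override void train(bool on = true)
    {
        foreach (var m in Children) m.train(on);
        base.train(on);
    }
}
```

In both containers the children end up in eval mode; only the container's own `training` flag differs, which matches the failing assertion in the F# example.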

yueyinqiu commented Dec 19, 2024

This may not be directly related to this issue, but I think we should reconsider how submodules are handled, especially the way we register them. We have discussed this before in #1272 (comment). I think the best approach would be to use source generators, but that would add a lot of complexity, so we previously treated it as a last resort and put it on hold.

alinpahontu2912 commented

Hey @brianberns, thanks for the issue. I tested it myself and you are right: there seems to be a problem with the Sequential module. For the moment, could you try wrapping the Sequential module in your own custom module, as described in the wiki? Something like this:

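using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

public class CustomModel : Module<Tensor, Tensor>
{
    // The wrapped Sequential; registered as a submodule by RegisterComponents().
    private readonly Module<Tensor, Tensor> layers;

    public CustomModel()
        : base("CustomModel")
    {
        var modules = new List<(string, Module<Tensor, Tensor>)>();
        modules.Add(("lin1", Linear(10, 10)));
        layers = Sequential(modules);

        RegisterComponents();
    }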

    public override Tensor forward(Tensor input)
    {
        return layers.forward(input);
    }

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            layers.Dispose();
        }
        base.Dispose(disposing);
    }
}

alinpahontu2912 self-assigned this Jan 20, 2025