forward_prop fails when the input to the forward pass is supposed to be a list or dict #104

Open
Erotemic opened this issue Sep 2, 2023 · 0 comments


Erotemic commented Sep 2, 2023

Describe the bug
Currently, if torchview gets a list or dict as the input to forward, it assumes that the model's forward function accepts the mapping as keyword arguments or the list as a set of positional arguments. This is not always true. For instance, I often do not collate my batches; instead, I pass a list of dictionaries carrying a lot of information about how the forward pass should proceed. Here is an MWE that demonstrates this:

To Reproduce

import torch

class MyMWE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stem1 = torch.nn.Linear(2, 3)
        self.stem2 = torch.nn.Linear(2, 3)
        self.body = torch.nn.Linear(3, 3)
        self.head1 = torch.nn.Linear(3, 5)
        self.head2 = torch.nn.Linear(3, 7)

    def forward(self, batch):
        batch_outputs = []
        for item in batch:
            if 'domain1' in item:
                feat = self.stem1(item['domain1'])
            elif 'domain2' in item:
                feat = self.stem2(item['domain2'])
            else:
                raise ValueError

            hidden = self.body(feat)
            logits1 = self.head1(hidden)
            logits2 = self.head2(hidden)
            output = {
                'logits1': logits1,
                'logits2': logits2,
            }
            batch_outputs.append(output)
        # Return the outputs for every item, not just the last one.
        return batch_outputs

def main():
    from torchview import draw_graph
    model = MyMWE()

    batch = [
        {'domain1': torch.Tensor([1, 2])},
        {'domain2': torch.Tensor([1, 2])},
        {'domain1': torch.Tensor([1, 2])},
    ]

    # Verify a normal forward pass works
    output = model(batch)

    # Check if draw_graph works
    model_graph = draw_graph(
        model,
        input_data=batch,
        expand_nested=True,
        hide_inner_tensors=True,
        device='meta', depth=9001)
    model_graph.visual_graph.view()


if __name__ == '__main__':
    main()
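The error most likely comes from plain argument unpacking rather than anything specific to torch. A minimal pure-Python sketch of the mismatch: `forward` below is a hypothetical stand-in for `MyMWE.forward` that, like the real one, takes a single argument holding the whole uncollated batch. Calling it with the batch works; unpacking the list with `*`, as the forward_prop lines quoted below do for list inputs, passes each dict as a separate positional argument and raises a `TypeError`.

```python
def forward(batch):
    """Hypothetical stand-in for MyMWE.forward: takes ONE argument, a list of dicts."""
    # Report which domain key each item carries.
    return [next(iter(item)) for item in batch]


batch = [{'domain1': 1.0}, {'domain2': 2.0}, {'domain1': 3.0}]

# Passing the list as-is: one positional argument, as the model expects.
print(forward(batch))  # ['domain1', 'domain2', 'domain1']

# Unpacking the list: three positional arguments instead of one.
unpacked_error = None
try:
    forward(*batch)
except TypeError as err:
    unpacked_error = err
    print(f"TypeError: {err}")
```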

Expected behavior

Running the MWE above raises an error. However, if I remove the * and ** unpacking from these lines:

                if isinstance(x, (list, tuple)):
                    _ = model.to(device)(*x, **kwargs)
                elif isinstance(x, Mapping):
                    _ = model.to(device)(**x, **kwargs)

in forward_prop, it works fine and I get a nice graph of the model.
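The quoted branch also suggests a possible workaround that needs no change to torchview: wrap the batch in a one-element list, so that `*x` unpacks it back to the original batch. The sketch below replays that dispatch in plain Python. `dispatch` is a hypothetical helper mirroring the quoted lines (its final fallback branch is an assumption), not torchview's actual function, and the workaround has not been verified against torchview itself.

```python
from collections.abc import Mapping


def dispatch(model, x, **kwargs):
    """Hypothetical replica of the input-dispatch branch quoted from forward_prop."""
    if isinstance(x, (list, tuple)):
        return model(*x, **kwargs)   # list/tuple -> positional arguments
    elif isinstance(x, Mapping):
        return model(**x, **kwargs)  # mapping -> keyword arguments
    return model(x, **kwargs)        # assumed fallback for a single input


def forward(batch):
    """Hypothetical stand-in for MyMWE.forward: expects the whole uncollated batch."""
    return len(batch)


batch = [{'domain1': 1.0}, {'domain2': 2.0}, {'domain1': 3.0}]

# Wrapping the batch in a one-element list: `*[batch]` unpacks to `batch`.
print(dispatch(forward, [batch]))  # 3
```

In the MWE this would correspond to passing input_data=[batch] to draw_graph, assuming torchview routes input_data into this branch unchanged.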

Additional context

I'm not sure what the correct fix for this is, but an easy one would be to let the user pass a flag specifying that the input data should be passed to forward as-is, without torchview making assumptions about its structure.
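A minimal sketch of such a flag, again replaying the dispatch in plain Python. The `call_model` helper and its `pass_as_is` parameter are hypothetical names chosen for illustration; torchview does not necessarily expose anything like them. With the flag set, the input reaches the model unchanged; otherwise the existing list/tuple and mapping unpacking applies.

```python
from collections.abc import Mapping
from typing import Any, Callable


def call_model(model: Callable[..., Any], x: Any, pass_as_is: bool = False, **kwargs: Any) -> Any:
    """Hypothetical dispatch with an opt-out flag for input unpacking."""
    if pass_as_is:
        # No assumptions: forward receives x exactly as the user supplied it.
        return model(x, **kwargs)
    if isinstance(x, (list, tuple)):
        return model(*x, **kwargs)   # list/tuple -> positional arguments
    if isinstance(x, Mapping):
        return model(**x, **kwargs)  # mapping -> keyword arguments
    return model(x, **kwargs)        # assumed fallback for a single input


def forward(batch):
    """Hypothetical stand-in for MyMWE.forward: one argument, a list of dicts."""
    return [next(iter(item)) for item in batch]


def forward_kw(*, image, label):
    """Hypothetical stand-in for a model whose forward takes keyword arguments."""
    return (image, label)


batch = [{'domain1': 1.0}, {'domain2': 2.0}]
print(call_model(forward, batch, pass_as_is=True))       # ['domain1', 'domain2']
print(call_model(forward_kw, {'image': 1, 'label': 0}))  # (1, 0)
```

Defaulting the flag to False keeps the current unpacking behavior for existing callers, so only users with uncollated batches like the one above would need to opt in.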
