
Created a copy of input tensor for evaluation #60

Open · wants to merge 1 commit into main
Conversation

vkovinicTT
Member

Summary

This PR ensures that a copy of the input tensors is created when running forward during compilation (e.g. when verifying outputs via the framework model's forward pass). This prevents unintended modifications to the original input tensors if in-place operations are performed during the forward pass.

Changes

  • Added logic to create copies of input tensors before passing them to the model (see the sketch after this list).
  • Ensured that the copying applies to the supported tensor types across different frameworks.
  • Prevented unexpected side effects when the model modifies its inputs in-place.
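
A minimal sketch of what the copying step might look like; the helper name `clone_inputs` and the handled types are illustrative assumptions, not the exact code in this PR:

    import torch

    def clone_inputs(inputs):
        """Return copies of the supported input tensors so that in-place ops
        inside the framework model's forward pass cannot mutate the originals.
        Illustrative helper only; the actual PR logic may differ."""
        copied = []
        for inp in inputs:
            if isinstance(inp, torch.Tensor):
                # detach().clone() keeps the values but breaks aliasing with the original
                copied.append(inp.detach().clone())
            else:
                # other framework tensor types could be handled here;
                # fall back to passing the object through unchanged
                copied.append(inp)
        return copied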

Why is this needed?

Some models perform in-place operations on input tensors during the forward pass, which can lead to unintended changes in the original inputs. By making a copy of the inputs, we ensure correctness and avoid potential issues when running forward during compilation.

Example test:

    import torch
    import forge
    from forge.verify.compare import compare_with_golden  # import path assumed; adjust to your forge version


    class Inplace(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            y = x + 1
            x += 2  # in-place operation on the input; mutates the inputs during forge.compile(...)
            return x + y


    shape = (1, 32, 32)  # arbitrary example shape
    input = torch.zeros(shape, requires_grad=False)
    framework_input = input.detach().clone()
    tt_inputs = [input]

    framework_model = Inplace()
    y = framework_model(framework_input)

    compiled_model = forge.compile(framework_model, sample_inputs=tt_inputs, module_name="inplace")
    tty = compiled_model(*tt_inputs)[0]

    # Without the input copy this fails: forge.compile runs forward on tt_inputs,
    # so x += 2 mutates them before compiled_model consumes them.
    compare_with_golden(golden=y, calculated=tty)

@vkovinicTT
Member Author

vkovinicTT commented Feb 19, 2025

I was thinking of trying out functionalize for torch, as Knezevic pointed out that this is what they do on the torch frontend. From what I've seen, other frameworks (JAX and, I think, TensorFlow) are functional, meaning no side effects and no mutation of the inputs.

Currently, if we use torch and there is an in-place operation on an input, torch will modify that input, but our compiled torch module won't, which might be a problem if we want to imitate torch's behaviour.
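
For reference, a minimal sketch of what functionalize does to a mutating function; this is illustrative only, assumes PyTorch 2.x with torch.func available, and is not part of this PR:

    import torch
    from torch.func import functionalize

    def f(x):
        y = x + 1
        x += 2  # in-place mutation of the input
        return x + y

    x = torch.zeros(4)
    out = functionalize(f)(x)
    # Inside the transform, the in-place add is rewritten into out-of-place ops.
    # To preserve eager torch semantics, functionalize copies the mutated value
    # back into x after the call, so x still ends up modified here and `out`
    # matches what eager f(x) would return.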
