
Commit: add comment
kshitij12345 committed Nov 13, 2024
1 parent 3f7a25d commit a082ba3
Showing 3 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions thunder/__init__.py
@@ -442,6 +442,8 @@ def get_computation_and_inputs(*args, **kwargs):
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
cd.is_grad_enabled = pytorch.is_grad_enabled()

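For illustration, a minimal usage sketch (the function and tensor shapes below are made up; only thunder.jit, torch.no_grad, and torch.is_grad_enabled are taken from the real APIs): the value recorded here is simply torch.is_grad_enabled() at call time, so a call made under torch.no_grad() is traced with the flag set to False.

import torch
import thunder

def fn(x, w):
    return x @ w.t()

jfn = thunder.jit(fn)
x = torch.randn(4, 8)
w = torch.randn(16, 8, requires_grad=True)

# Grad mode is on here, so cache_info["is_grad_enabled"] records True.
out = jfn(x, w)

# Under torch.no_grad() the recorded value is False, which is what allows the
# vjp transform to treat the traced symbols as constants.
with torch.no_grad():
    out_inference = jfn(x, w)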
2 changes: 2 additions & 0 deletions thunder/common.py
@@ -221,6 +221,8 @@ def __init__(
# State for pytorch autocast context managers.
self.autocast_stack: AutocastStack = AutocastStack()

# State to query whether grad is enabled or disabled using
# torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled
self.is_grad_enabled: bool = True

#
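For reference, the three entry points named in the comment are plain PyTorch and all toggle the same global flag that torch.is_grad_enabled() reports; a short sketch:

import torch

assert torch.is_grad_enabled()             # grad mode is on by default

with torch.no_grad():                      # context manager: disables grad mode
    assert not torch.is_grad_enabled()
    with torch.enable_grad():              # re-enables it inside the outer block
        assert torch.is_grad_enabled()

torch._C._set_grad_enabled(False)          # low-level setter used by the context managers
assert not torch.is_grad_enabled()
torch._C._set_grad_enabled(True)           # restore the default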
4 changes: 3 additions & 1 deletion thunder/core/symbol.py
@@ -323,7 +323,9 @@ def __call__(self, *args, **kwargs):

cd = get_compile_data()
if cd is not None and not cd.is_grad_enabled:

# If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`,
# tag the results with `DETACHED_AUTOGRAD_GRAPH` which makes this Symbol a constant for
# vjp transform (applied later).
def tag_tensorproxy_output_as_detached(proxy):
if isinstance(proxy, TensorProxy):
return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
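The hunk is truncated before showing how tag_tensorproxy_output_as_detached is applied, so here is only a hedged, user-level sketch of the behaviour the check aims at (the function fn is made up, and it assumes grad-mode changes inside a jitted function are reflected in cd.is_grad_enabled during tracing): symbols recorded while grad is disabled get DETACHED_AUTOGRAD_GRAPH-tagged outputs and are treated as constants by the vjp transform, while the rest of the computation keeps its gradient.

import torch
import thunder

def fn(x, w):
    # Assumed behaviour: grad is disabled only for this region, so the symbols
    # recorded here would produce proxies tagged with DETACHED_AUTOGRAD_GRAPH
    # and be treated as constants by the vjp transform.
    with torch.no_grad():
        centered = x - x.mean()
    # Grad is enabled again here; this matmul participates in the backward pass.
    return centered @ w.t()

jfn = thunder.jit(fn)
x = torch.randn(4, 8)
w = torch.randn(16, 8, requires_grad=True)
jfn(x, w).sum().backward()   # w.grad is populated via the matmul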
