Skip to content

Commit

Permalink
update comment
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Nov 13, 2024
1 parent a082ba3 commit b71214c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2486,7 +2486,7 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
bool: True if the symbol is constant, False otherwise.
"""
are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
# `no_grad_detach_graph_pass` tags output of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
# Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
# These are treated as constant for VJP.
# NOTE - `any(()) is False`
output_disconnected_from_graph = any(
Expand Down
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5228,7 +5228,7 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:


# Tag to use on Proxies created in `no_grad` regions.
# VJP transform will treat BOundSymbol's whose output has these tags
# VJP transform will treat BoundSymbol's whose output has these tags
# as constant.
ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")

Expand Down

0 comments on commit b71214c

Please sign in to comment.