From b71214cb5172a1305f1d0e030768974812ce8fcc Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 15:12:40 +0100 Subject: [PATCH] update comment --- thunder/core/transforms.py | 2 +- thunder/torch/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 94596b732..9ebfd66ca 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2486,7 +2486,7 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool: bool: True if the symbol is constant, False otherwise. """ are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) - # `no_grad_detach_graph_pass` tags output of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. + # Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. # These are treated as constant for VJP. # NOTE - `any(()) is False` output_disconnected_from_graph = any( diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 4f497ecaf..c15fd8b35 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5228,7 +5228,7 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: # Tag to use on Proxies created in `no_grad` regions. -# VJP transform will treat BOundSymbol's whose output has these tags +# VJP transform will treat BoundSymbol's whose output has these tags # as constant. ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")