diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index ead4044b5..e5688a013 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -328,6 +328,8 @@ def __call__(self, *args, **kwargs): # vjp transform (applied later). def tag_tensorproxy_output_as_detached(proxy): if isinstance(proxy, TensorProxy): + # We need to remove name from trace, otherwise replace will return a proxy with new name. + trace.names.remove(proxy.name) return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) return proxy