From 8f5026f91a1f85f858aa8ffd042af067961fe623 Mon Sep 17 00:00:00 2001
From: beverlylytle <57254617+beverlylytle@users.noreply.github.com>
Date: Wed, 13 Nov 2024 13:32:04 +0100
Subject: [PATCH] add elu op (#1417)

---
 thunder/executors/torchex.py       |  2 ++
 thunder/tests/opinfos.py           | 16 ++++++++++++++--
 thunder/tests/test_grad.py         |  4 ++--
 thunder/torch/__init__.py          | 12 ++++++++++++
 thunder/torch/default_torch_ops.py |  1 -
 5 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index afff715728..7b363606e8 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -803,6 +803,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
 
 # nn.functional elementwise unary
 celu = _register_torch_operation("celu", module=torch.nn.functional)
+elu = _register_torch_operation("elu", module=torch.nn.functional)
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
@@ -815,6 +816,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
     return isinstance(a, TensorProxy) and not inplace
 
 
+_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 6f8dfcbf5e..b8daca7de2 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -1633,7 +1633,7 @@ def _abs_torch(x: torch.Tensor | Number):
 elementwise_unary_ops.append(reciprocal_opinfo)
 
 
-def celu_sample_generator(op, device, dtype, requires_grad):
+def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
     alphas = (None, -1.0, 0.5)
     samples = elementwise_unary_generator(op, device, dtype, requires_grad)
     for alpha, sample in itertools.product(alphas, samples):
@@ -1646,13 +1646,25 @@ def celu_sample_generator(op, device, dtype, requires_grad):
 celu_opinfo = OpInfo(
     ltorch.celu,
     dtypes=(datatypes.floating,),
-    sample_input_generator=celu_sample_generator,
+    sample_input_generator=elementwise_unary_with_alpha_generator,
     torch_reference=_elementwise_unary_torch(torch.celu),
     test_directives=(),
 )
 elementwise_unary_ops.append(celu_opinfo)
 
 
+elu_opinfo = OpInfo(
+    ltorch.elu,
+    dtypes=(datatypes.floating,),
+    sample_input_generator=elementwise_unary_with_alpha_generator,
+    torch_reference=torch.nn.functional.elu,
+    # fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function
+    singularity_fn=lambda x: x,
+    test_directives=(),
+)
+elementwise_unary_ops.append(elu_opinfo)
+
+
 relu_opinfo = OpInfo(
     ltorch.relu,
     sample_input_generator=elementwise_unary_generator,
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 8f64ea4493..a79ce2d80b 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -281,7 +281,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals
     """
     # Let f be a function from vectors of size n to vectors of size m.
    # Its Jacobian is a matrix J of size m x n.
-    # The adjoint property is J^* J = I, where J^* is the conjugate transpose (adjoint) of J.
+    # Represent by J^* the conjugate transpose (adjoint) of J.
     # J^* is a matrix of size n x m.
     # For any vector v of size m, J^* v is a vector of size n.
     # For any vector u of size n, J u is a vector of size m.
@@ -296,7 +296,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals
 
     u = tree_map(make, primals)
 
-    comp_f = thunder.jit(f)
+    comp_f = thunder.jit(f, disable_torch_autograd=True)
 
     outs_p, J_u = numerical_jvp(comp_f)(primals, u)
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 7047485ff6..b216d2a684 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1774,6 +1774,18 @@ def celu(a: TensorLike, /, alpha: float = 1.0, inplace: bool = False) -> TensorL
 _inplace_to_out_of_place[celu] = celu, 2
 
 
+@torchsymbol(torch.nn.functional.elu, is_method=False)
+def elu(a: TensorProxy, /, alpha: float = 1.0, inplace: bool = False) -> TensorLike:
+    negative_domain_value = alpha * expm1(a)
+    out = where(a > 0, a, negative_domain_value)
+    if inplace:
+        return prims.copy_(out, a)
+    return out
+
+
+_inplace_to_out_of_place[elu] = elu, 2
+
+
 @torchsymbol(torch.nn.functional.gelu, is_method=False)
 def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike:
     if approximate == "none":
diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py
index 35d5170508..84e0ae0f90 100644
--- a/thunder/torch/default_torch_ops.py
+++ b/thunder/torch/default_torch_ops.py
@@ -338,7 +338,6 @@
     torch.nn.functional.dropout1d,
     torch.nn.functional.dropout2d,
     torch.nn.functional.dropout3d,
-    torch.nn.functional.elu,
     torch.nn.functional.embedding_bag,
     torch.nn.functional.feature_alpha_dropout,
     torch.nn.functional.fold,
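Note (not part of the patch): a minimal standalone sanity check of the decomposition this patch uses for elu in thunder/torch/__init__.py, assuming plain PyTorch is available. The helper name elu_reference and the script itself are hypothetical and only illustrate that where(a > 0, a, alpha * expm1(a)) agrees with torch.nn.functional.elu.

    # Hypothetical sanity-check script: compares the elu decomposition from the
    # patch (where(x > 0, x, alpha * expm1(x))) against torch.nn.functional.elu.
    import torch

    def elu_reference(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
        # elu(x) = x for x > 0, alpha * (exp(x) - 1) otherwise
        return torch.where(x > 0, x, alpha * torch.expm1(x))

    x = torch.randn(1000, dtype=torch.float64)
    for alpha in (1.0, -1.0, 0.5):  # same alpha values as the opinfo sample generator
        assert torch.allclose(elu_reference(x, alpha), torch.nn.functional.elu(x, alpha))
    print("elu decomposition matches torch.nn.functional.elu")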