
Commit

add elu op (#1417)
beverlylytle authored Nov 13, 2024
1 parent 7912102 commit 8f5026f
Showing 5 changed files with 30 additions and 5 deletions.
2 changes: 2 additions & 0 deletions thunder/executors/torchex.py
@@ -803,6 +803,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:

 # nn.functional elementwise unary
 celu = _register_torch_operation("celu", module=torch.nn.functional)
+elu = _register_torch_operation("elu", module=torch.nn.functional)
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
@@ -815,6 +816,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
     return isinstance(a, TensorProxy) and not inplace


+_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
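With the executor mappings above in place, a jitted function that calls torch.nn.functional.elu should route to the torch executor. A minimal usage sketch (the helper fn and the input shape are illustrative, not from the commit):

import thunder
import torch

def fn(x):
    return torch.nn.functional.elu(x, alpha=0.5)

jfn = thunder.jit(fn)      # compile with Thunder
x = torch.randn(4)
torch.testing.assert_close(jfn(x), fn(x))  # jitted result should match eager PyTorch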
16 changes: 14 additions & 2 deletions thunder/tests/opinfos.py
@@ -1633,7 +1633,7 @@ def _abs_torch(x: torch.Tensor | Number):
 elementwise_unary_ops.append(reciprocal_opinfo)


-def celu_sample_generator(op, device, dtype, requires_grad):
+def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
     alphas = (None, -1.0, 0.5)
     samples = elementwise_unary_generator(op, device, dtype, requires_grad)
     for alpha, sample in itertools.product(alphas, samples):
@@ -1646,13 +1646,25 @@ def celu_sample_generator(op, device, dtype, requires_grad):
 celu_opinfo = OpInfo(
     ltorch.celu,
     dtypes=(datatypes.floating,),
-    sample_input_generator=celu_sample_generator,
+    sample_input_generator=elementwise_unary_with_alpha_generator,
     torch_reference=_elementwise_unary_torch(torch.celu),
     test_directives=(),
 )
 elementwise_unary_ops.append(celu_opinfo)


+elu_opinfo = OpInfo(
+    ltorch.elu,
+    dtypes=(datatypes.floating,),
+    sample_input_generator=elementwise_unary_with_alpha_generator,
+    torch_reference=torch.nn.functional.elu,
+    # fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function
+    singularity_fn=lambda x: x,
+    test_directives=(),
+)
+elementwise_unary_ops.append(elu_opinfo)
+
+
 relu_opinfo = OpInfo(
     ltorch.relu,
     sample_input_generator=elementwise_unary_generator,
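The singularity_fn=lambda x: x entry above keeps test samples away from x = 0, where ELU has a kink whenever alpha != 1 and a central finite difference (as used by fdm.jvp) lands between the two one-sided slopes. A small illustration in plain PyTorch (the step size h and the alpha value are illustrative):

import torch

alpha, h = 0.5, 1e-6
f = lambda x: torch.nn.functional.elu(x, alpha=alpha)
x = torch.tensor(0.0, dtype=torch.double)

central = (f(x + h) - f(x - h)) / (2 * h)  # finite-difference estimate straddling the kink
# slope for x > 0 is 1.0, slope approaching 0 from below is alpha;
# the central estimate is roughly (1 + alpha) / 2 = 0.75, matching neither side
print(central.item())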
4 changes: 2 additions & 2 deletions thunder/tests/test_grad.py
@@ -281,7 +281,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals
"""
# Let f be a function from vectors of size n to vectors of size m.
# Its Jacobian is a matrix J of size m x n.
# The adjoint property is J^* J = I, where J^* is the conjugate transpose (adjoint) of J.
# Represent by J^* the conjugate transpose (adjoint) of J.
# J^* is a matrix of size n x m.
# For any vector v of size m, J^* v is a vector of size n.
# For any vector u of size n, J u is a vector of size m.
@@ -296,7 +296,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals

     u = tree_map(make, primals)

-    comp_f = thunder.jit(f)
+    comp_f = thunder.jit(f, disable_torch_autograd=True)

     outs_p, J_u = numerical_jvp(comp_f)(primals, u)

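The comments above introduce the adjoint J^*; the identity a VJP check of this kind relies on is <J u, v> = <u, J^* v>. A standalone sketch of that identity with plain PyTorch autograd on the new op (the toy function and shapes are illustrative, not from the test):

import torch

def f(x):
    return torch.nn.functional.elu(x, alpha=0.5)

x = torch.randn(5, dtype=torch.double)
u = torch.randn(5, dtype=torch.double)  # tangent, same shape as the input
v = torch.randn(5, dtype=torch.double)  # cotangent, same shape as the output

_, Ju = torch.autograd.functional.jvp(f, x, u)   # J u
_, JTv = torch.autograd.functional.vjp(f, x, v)  # J^* v

# <J u, v> == <u, J^* v> holds for a correct derivative implementation
assert torch.allclose(torch.dot(Ju, v), torch.dot(u, JTv))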
12 changes: 12 additions & 0 deletions thunder/torch/__init__.py
@@ -1774,6 +1774,18 @@ def celu(a: TensorLike, /, alpha: float = 1.0, inplace: bool = False) -> TensorL
 _inplace_to_out_of_place[celu] = celu, 2


+@torchsymbol(torch.nn.functional.elu, is_method=False)
+def elu(a: TensorProxy, /, alpha: float = 1.0, inplace: bool = False) -> TensorLike:
+    negative_domain_value = alpha * expm1(a)
+    out = where(a > 0, a, negative_domain_value)
+    if inplace:
+        return prims.copy_(out, a)
+    return out
+
+
+_inplace_to_out_of_place[elu] = elu, 2
+
+
 @torchsymbol(torch.nn.functional.gelu, is_method=False)
 def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike:
     if approximate == "none":
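The new symbol computes elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise, using expm1 for accuracy near zero. A quick standalone check, under the assumption that plain PyTorch ops mirror the traced where/expm1 definition above (elu_ref is an illustrative helper, not part of the commit):

import torch

def elu_ref(a: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # mirrors the traced definition: where(a > 0, a, alpha * expm1(a))
    return torch.where(a > 0, a, alpha * torch.expm1(a))

x = torch.randn(1000)
for alpha in (1.0, 0.5, -1.0):  # default plus the alphas used by the sample generator
    torch.testing.assert_close(elu_ref(x, alpha), torch.nn.functional.elu(x, alpha))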
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
@@ -338,7 +338,6 @@
     torch.nn.functional.dropout1d,
     torch.nn.functional.dropout2d,
     torch.nn.functional.dropout3d,
-    torch.nn.functional.elu,
     torch.nn.functional.embedding_bag,
     torch.nn.functional.feature_alpha_dropout,
     torch.nn.functional.fold,
