Skip to content

Commit

Permalink
use torch.OP##_ if available not torch.Tensor.OP##_ (Lightning-AI…
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored and shino16 committed Sep 5, 2024
1 parent 6c7bfbb commit edd0685
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def abs(a: NumberLike | TensorLike, /) -> Number | TensorLike:
return clang.abs(a)


@torchsymbol(torch.Tensor.abs_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.abs_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def abs_(a: NumberLike | TensorLike, /) -> Number | TensorLike:
    """In-place variant of :func:`abs`: write ``|a|`` back into ``a`` and return it."""
    result = abs(a)
    return prims.copy_(result, a)

Expand All @@ -1341,7 +1341,7 @@ def acos(a: NumberLike | TensorLike, /) -> Number | TensorLike:
return clang.acos(a)


@torchsymbol(torch.Tensor.acos_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.acos_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def acos_(a: TensorLike, /) -> TensorLike:
    """In-place variant of :func:`acos`: write ``acos(a)`` back into ``a`` and return it."""
    result = acos(a)
    return prims.copy_(result, a)

Expand All @@ -1351,7 +1351,7 @@ def acosh(a):
return clang.acosh(a)


@torchsymbol(torch.Tensor.acosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.acosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def acosh_(a):
    """In-place variant of :func:`acosh`: write ``acosh(a)`` back into ``a`` and return it."""
    result = acosh(a)
    return prims.copy_(result, a)

Expand All @@ -1361,7 +1361,7 @@ def asin(a):
return clang.asin(a)


@torchsymbol(torch.Tensor.asin_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.asin_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def asin_(a):
    """In-place variant of :func:`asin`: write ``asin(a)`` back into ``a`` and return it."""
    result = asin(a)
    return prims.copy_(result, a)

Expand All @@ -1371,7 +1371,7 @@ def asinh(a):
return clang.asinh(a)


@torchsymbol(torch.Tensor.asinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.asinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def asinh_(a):
    """In-place variant of :func:`asinh`: write ``asinh(a)`` back into ``a`` and return it."""
    result = asinh(a)
    return prims.copy_(result, a)

Expand All @@ -1381,7 +1381,7 @@ def atan(a):
return clang.atan(a)


@torchsymbol(torch.Tensor.atan_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.atan_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def atan_(a):
    """In-place variant of :func:`atan`: write ``atan(a)`` back into ``a`` and return it."""
    result = atan(a)
    return prims.copy_(result, a)

Expand All @@ -1391,7 +1391,7 @@ def atanh(a):
return clang.atanh(a)


@torchsymbol(torch.Tensor.atanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.atanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def atanh_(a):
    """In-place variant of :func:`atanh`: write ``atanh(a)`` back into ``a`` and return it."""
    result = atanh(a)
    return prims.copy_(result, a)

Expand All @@ -1411,7 +1411,7 @@ def ceil(a):
return clang.ceil(a)


@torchsymbol(torch.Tensor.ceil_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.ceil_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def ceil_(a):
    """In-place variant of :func:`ceil`: write ``ceil(a)`` back into ``a`` and return it."""
    result = ceil(a)
    return prims.copy_(result, a)

Expand All @@ -1421,7 +1421,7 @@ def cos(a):
return clang.cos(a)


@torchsymbol(torch.Tensor.cos_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.cos_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def cos_(a):
    """In-place variant of :func:`cos`: write ``cos(a)`` back into ``a`` and return it."""
    result = cos(a)
    return prims.copy_(result, a)

Expand All @@ -1431,7 +1431,7 @@ def cosh(a):
return clang.cosh(a)


@torchsymbol(torch.Tensor.cosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.cosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def cosh_(a):
    """In-place variant of :func:`cosh`: write ``cosh(a)`` back into ``a`` and return it."""
    result = cosh(a)
    return prims.copy_(result, a)

Expand All @@ -1451,7 +1451,7 @@ def erf(a):
return clang.erf(a)


@torchsymbol(torch.Tensor.erf_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.erf_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def erf_(a):
    """In-place variant of :func:`erf`: write ``erf(a)`` back into ``a`` and return it."""
    result = erf(a)
    return prims.copy_(result, a)

Expand All @@ -1461,7 +1461,7 @@ def erfc(a):
return clang.erfc(a)


@torchsymbol(torch.Tensor.erfc_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.erfc_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def erfc_(a):
    """In-place variant of :func:`erfc`: write ``erfc(a)`` back into ``a`` and return it."""
    result = erfc(a)
    return prims.copy_(result, a)

Expand All @@ -1481,7 +1481,7 @@ def exp(a):
return clang.exp(a)


@torchsymbol(torch.Tensor.exp_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.exp_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def exp_(a):
    """In-place variant of :func:`exp`: write ``exp(a)`` back into ``a`` and return it."""
    result = exp(a)
    return prims.copy_(result, a)

Expand All @@ -1491,7 +1491,7 @@ def exp2(a):
return clang.exp2(a)


@torchsymbol(torch.Tensor.exp2_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.exp2_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def exp2_(a):
    """In-place variant of :func:`exp2`: write ``exp2(a)`` back into ``a`` and return it."""
    result = exp2(a)
    return prims.copy_(result, a)

Expand All @@ -1501,7 +1501,7 @@ def expm1(a):
return clang.expm1(a)


@torchsymbol(torch.Tensor.expm1_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.expm1_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def expm1_(a):
    """In-place variant of :func:`expm1`: write ``expm1(a)`` back into ``a`` and return it."""
    result = expm1(a)
    return prims.copy_(result, a)

Expand All @@ -1511,7 +1511,7 @@ def floor(a):
return clang.floor(a)


@torchsymbol(torch.Tensor.floor_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.floor_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def floor_(a):
    """In-place variant of :func:`floor`: write ``floor(a)`` back into ``a`` and return it."""
    result = floor(a)
    return prims.copy_(result, a)

Expand All @@ -1536,7 +1536,7 @@ def log(a):
return clang.log(a)


@torchsymbol(torch.Tensor.log_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.log_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def log_(a):
    """In-place variant of :func:`log`: write ``log(a)`` back into ``a`` and return it."""
    result = log(a)
    return prims.copy_(result, a)

Expand All @@ -1546,7 +1546,7 @@ def log10(a):
return clang.log10(a)


@torchsymbol(torch.Tensor.log10_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.log10_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def log10_(a):
    """In-place variant of :func:`log10`: write ``log10(a)`` back into ``a`` and return it."""
    result = log10(a)
    return prims.copy_(result, a)

Expand All @@ -1556,7 +1556,7 @@ def log1p(a):
return clang.log1p(a)


@torchsymbol(torch.Tensor.log1p_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.log1p_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def log1p_(a):
    """In-place variant of :func:`log1p`: write ``log1p(a)`` back into ``a`` and return it."""
    result = log1p(a)
    return prims.copy_(result, a)

Expand All @@ -1566,7 +1566,7 @@ def log2(a):
return clang.log2(a)


@torchsymbol(torch.Tensor.log2_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.log2_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def log2_(a):
    """In-place variant of :func:`log2`: write ``log2(a)`` back into ``a`` and return it."""
    result = log2(a)
    return prims.copy_(result, a)

Expand All @@ -1582,7 +1582,7 @@ def neg(a):
return clang.neg(a)


@torchsymbol(torch.Tensor.neg_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.neg_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def neg_(a):
    """In-place variant of :func:`neg`: write ``-a`` back into ``a`` and return it."""
    result = neg(a)
    return prims.copy_(result, a)

Expand All @@ -1592,7 +1592,7 @@ def reciprocal(a):
return clang.reciprocal(a)


@torchsymbol(torch.Tensor.reciprocal_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.reciprocal_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def reciprocal_(a):
    """In-place variant of :func:`reciprocal`: write ``1/a`` back into ``a`` and return it."""
    result = reciprocal(a)
    return prims.copy_(result, a)

Expand All @@ -1602,7 +1602,7 @@ def round(a):
return clang.round(a)


@torchsymbol(torch.Tensor.round_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.round_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def round_(a):
    """In-place variant of :func:`round`: write ``round(a)`` back into ``a`` and return it."""
    result = round(a)
    return prims.copy_(result, a)

Expand All @@ -1612,7 +1612,7 @@ def rsqrt(a):
return clang.rsqrt(a)


@torchsymbol(torch.Tensor.rsqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.rsqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def rsqrt_(a):
    """In-place variant of :func:`rsqrt`: write ``1/sqrt(a)`` back into ``a`` and return it."""
    result = rsqrt(a)
    return prims.copy_(result, a)

Expand All @@ -1639,7 +1639,7 @@ def sin(a):
return clang.sin(a)


@torchsymbol(torch.Tensor.sin_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.sin_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def sin_(a):
    """In-place variant of :func:`sin`: write ``sin(a)`` back into ``a`` and return it."""
    result = sin(a)
    return prims.copy_(result, a)

Expand All @@ -1649,7 +1649,7 @@ def sinh(a):
return clang.sinh(a)


@torchsymbol(torch.Tensor.sinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.sinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def sinh_(a):
    """In-place variant of :func:`sinh`: write ``sinh(a)`` back into ``a`` and return it."""
    result = sinh(a)
    return prims.copy_(result, a)

Expand All @@ -1659,7 +1659,7 @@ def sqrt(a):
return clang.sqrt(a)


@torchsymbol(torch.Tensor.sqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.sqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def sqrt_(a):
    """In-place variant of :func:`sqrt`: write ``sqrt(a)`` back into ``a`` and return it."""
    result = sqrt(a)
    return prims.copy_(result, a)

Expand All @@ -1669,7 +1669,7 @@ def tan(a):
return clang.tan(a)


@torchsymbol(torch.Tensor.tan_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.tan_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def tan_(a):
    """In-place variant of :func:`tan`: write ``tan(a)`` back into ``a`` and return it."""
    result = tan(a)
    return prims.copy_(result, a)

Expand All @@ -1679,7 +1679,7 @@ def tanh(a):
return clang.tanh(a)


@torchsymbol(torch.Tensor.tanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.tanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def tanh_(a):
    """In-place variant of :func:`tanh`: write ``tanh(a)`` back into ``a`` and return it."""
    result = tanh(a)
    return prims.copy_(result, a)

Expand All @@ -1689,7 +1689,7 @@ def trunc(a):
return clang.trunc(a)


@torchsymbol(torch.Tensor.trunc_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.trunc_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def trunc_(a):
    """In-place variant of :func:`trunc`: write ``trunc(a)`` back into ``a`` and return it."""
    result = trunc(a)
    return prims.copy_(result, a)

Expand Down Expand Up @@ -2210,7 +2210,7 @@ def clamp(
return a


@torchsymbol(torch.Tensor.clamp_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.clamp_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def clamp_(
a: TensorLike, /, min: None | Number | TensorLike = None, max: None | Number | TensorLike = None
) -> TensorLike:
Expand Down Expand Up @@ -2873,7 +2873,7 @@ def index_put(
return clang.index_put(a, indices, values, accumulate)


@torchsymbol(torch.Tensor.index_put_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
@torchsymbol(torch.index_put_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def index_put_(
a: TensorLike,
/,
Expand Down

0 comments on commit edd0685

Please sign in to comment.