diff --git a/eagerpy/framework.py b/eagerpy/framework.py index 206a303..4f1727b 100644 --- a/eagerpy/framework.py +++ b/eagerpy/framework.py @@ -172,7 +172,7 @@ def transpose(t: TensorType, axes: Optional[Axes] = None) -> TensorType: def diag(t: TensorType, k: int = 0) -> TensorType: - return t._diag(k) + return t.diag(k) @overload diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py index 90532ca..71e322a 100644 --- a/eagerpy/tensor/jax.py +++ b/eagerpy/tensor/jax.py @@ -242,7 +242,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(np.transpose(self.raw, axes=axes)) - def _diag(self: TensorType, k: int = 0) -> TensorType: + def diag(self: TensorType, k: int = 0) -> TensorType: return type(self)(np.diag(self.raw, k=k)) def all( diff --git a/eagerpy/tensor/numpy.py b/eagerpy/tensor/numpy.py index fa927d5..17df429 100644 --- a/eagerpy/tensor/numpy.py +++ b/eagerpy/tensor/numpy.py @@ -186,7 +186,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(np.transpose(self.raw, axes=axes)) - def _diag(self: TensorType, k: int = 0) -> TensorType: + def diag(self: TensorType, k: int = 0) -> TensorType: return type(self)(np.diag(self.raw, k=k)) def all( diff --git a/eagerpy/tensor/pytorch.py b/eagerpy/tensor/pytorch.py index 60aebeb..8fa8b5f 100644 --- a/eagerpy/tensor/pytorch.py +++ b/eagerpy/tensor/pytorch.py @@ -275,7 +275,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(self.raw.permute(*axes)) - def _diag(self: TensorType, k: int = 0) -> TensorType: + def diag(self: TensorType, k: int = 0) -> TensorType: return type(self)(torch.diag(self.raw, diagonal=k)) def all( diff --git a/eagerpy/tensor/tensor.py b/eagerpy/tensor/tensor.py index 6040bf3..21daa83 100644 --- a/eagerpy/tensor/tensor.py +++ b/eagerpy/tensor/tensor.py @@ -368,7 +368,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: ... @abstractmethod - def _diag(self: TensorType, k: int = 0) -> TensorType: + def diag(self: TensorType, k: int = 0) -> TensorType: ... @abstractmethod diff --git a/eagerpy/tensor/tensorflow.py b/eagerpy/tensor/tensorflow.py index 362cc74..f878f7b 100644 --- a/eagerpy/tensor/tensorflow.py +++ b/eagerpy/tensor/tensorflow.py @@ -257,7 +257,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType: axes = tuple(range(self.ndim - 1, -1, -1)) return type(self)(tf.transpose(self.raw, perm=axes)) - def _diag(self: TensorType, k: int = 0) -> TensorType: + def diag(self: TensorType, k: int = 0) -> TensorType: if len(self.shape) == 1: return type(self)(tf.linalg.diag(self.raw, k=k)) else: