Skip to content

Commit

Permalink
Rename _diag -> diag.
Browse files Browse the repository at this point in the history
  • Loading branch information
Holt59 committed Oct 24, 2020
1 parent 8b20f6b commit 3970044
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion eagerpy/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def transpose(t: TensorType, axes: Optional[Axes] = None) -> TensorType:


def diag(t: TensorType, k: int = 0) -> TensorType:
return t._diag(k)
return t.diag(k)


@overload
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
axes = tuple(range(self.ndim - 1, -1, -1))
return type(self)(np.transpose(self.raw, axes=axes))

def _diag(self: TensorType, k: int = 0) -> TensorType:
def diag(self: TensorType, k: int = 0) -> TensorType:
return type(self)(np.diag(self.raw, k=k))

def all(
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
axes = tuple(range(self.ndim - 1, -1, -1))
return type(self)(np.transpose(self.raw, axes=axes))

def _diag(self: TensorType, k: int = 0) -> TensorType:
def diag(self: TensorType, k: int = 0) -> TensorType:
return type(self)(np.diag(self.raw, k=k))

def all(
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
axes = tuple(range(self.ndim - 1, -1, -1))
return type(self)(self.raw.permute(*axes))

def _diag(self: TensorType, k: int = 0) -> TensorType:
def diag(self: TensorType, k: int = 0) -> TensorType:
return type(self)(torch.diag(self.raw, diagonal=k))

def all(
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
...

@abstractmethod
def _diag(self: TensorType, k: int = 0) -> TensorType:
def diag(self: TensorType, k: int = 0) -> TensorType:
...

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def transpose(self: TensorType, axes: Optional[Axes] = None) -> TensorType:
axes = tuple(range(self.ndim - 1, -1, -1))
return type(self)(tf.transpose(self.raw, perm=axes))

def _diag(self: TensorType, k: int = 0) -> TensorType:
def diag(self: TensorType, k: int = 0) -> TensorType:
if len(self.shape) == 1:
return type(self)(tf.linalg.diag(self.raw, k=k))
else:
Expand Down

0 comments on commit 3970044

Please sign in to comment.