Skip to content

Commit

Permalink
fixed the bug with numpy() return values sharing memory and being wri…
Browse files Browse the repository at this point in the history
…teable
  • Loading branch information
Jonas Rauber committed Feb 12, 2020
1 parent 94ccc1b commit 19bce20
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
4 changes: 3 additions & 1 deletion eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def _get_subkey(cls) -> Any:
return subkey

def numpy(self) -> Any:
return onp.asarray(self.raw)
a = onp.asarray(self.raw)
assert a.flags.writeable is False
return a

def item(self) -> Union[int, float, bool]:
return self.raw.item() # type: ignore
Expand Down
7 changes: 6 additions & 1 deletion eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def raw(self) -> "np.ndarray": # type: ignore
return super().raw

def numpy(self: TensorType) -> Any:
return self.raw
a = self.raw.view()
if a.flags.writeable:
# without the check, we would attempt to set it on array
# scalars, and that would fail
a.flags.writeable = False
return a

def item(self) -> Union[int, float, bool]:
return self.raw.item() # type: ignore
Expand Down
7 changes: 6 additions & 1 deletion eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ def tanh(self: TensorType) -> TensorType:
return type(self)(torch.tanh(self.raw))

def numpy(self: TensorType) -> Any:
return self.raw.detach().cpu().numpy()
a = self.raw.detach().cpu().numpy()
if a.flags.writeable:
# without the check, we would attempt to set it on array
# scalars, and that would fail
a.flags.writeable = False
return a

def item(self) -> Union[int, float, bool]:
return self.raw.item()
Expand Down
7 changes: 6 additions & 1 deletion eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ def raw(self) -> "tf.Tensor": # type: ignore
return super().raw

def numpy(self: TensorType) -> Any:
return self.raw.numpy()
a = self.raw.numpy()
if a.flags.writeable:
# without the check, we would attempt to set it on array
# scalars, and that would fail
a.flags.writeable = False
return a

def item(self: TensorType) -> Union[int, float, bool]:
return self.numpy().item() # type: ignore
Expand Down

0 comments on commit 19bce20

Please sign in to comment.