Skip to content

Commit

Permalink
Merge pull request #40 from radionets-project/fix_image_rot
Browse files Browse the repository at this point in the history
Fix image rotation
  • Loading branch information
aknierim authored Oct 10, 2024
2 parents 4978574 + e93ae00 commit 05c9809
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
3 changes: 3 additions & 0 deletions docs/changes/40.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Fix image rotation caused by bug in rd/lm grid computation in ``pyvisgen.simulation.observation.Obseravtion``
- Fix field order in ``pyvisgen.simulation.observation.ValidBaselineSubset`` data class
- Flip input image at the beginning of ``pyvisgen.simulation.visibility.vis_loop`` to ensure correct indexing, e.g. for plotting
23 changes: 12 additions & 11 deletions pyvisgen/simulation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def get_valid_subset(self, num_baselines, device):
date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device)

return ValidBaselineSubset(
baseline_nums,
u_start,
u_stop,
u_valid,
Expand All @@ -74,13 +73,13 @@ def get_valid_subset(self, num_baselines, device):
w_start,
w_stop,
w_valid,
baseline_nums,
date,
)


@dataclass()
class ValidBaselineSubset:
baseline_nums: torch.tensor
u_start: torch.tensor
u_stop: torch.tensor
u_valid: torch.tensor
Expand All @@ -90,6 +89,7 @@ class ValidBaselineSubset:
w_start: torch.tensor
w_stop: torch.tensor
w_valid: torch.tensor
baseline_nums: torch.tensor
date: torch.tensor

def __getitem__(self, i):
Expand Down Expand Up @@ -350,7 +350,7 @@ def create_rd_grid(self):
Returns
-------
3d array
rd_grid : 3d array
Returns a 3d array with every pixel containing a RA and Dec value
"""
# transform to rad
Expand All @@ -370,9 +370,10 @@ def create_rd_grid(self):
- self.img_size / 2
) * res + dec

_, R = torch.meshgrid((r, r), indexing="ij")
D, _ = torch.meshgrid((d, d), indexing="ij")
R, _ = torch.meshgrid((r, r), indexing="ij")
_, D = torch.meshgrid((d, d), indexing="ij")
rd_grid = torch.cat([R[..., None], D[..., None]], dim=2)

return rd_grid

def create_lm_grid(self):
Expand All @@ -387,17 +388,17 @@ def create_lm_grid(self):
Returns
-------
3d array
lm_grid : 3d array
Returns a 3d array with every pixel containing a l and m value
"""
dec = torch.deg2rad(self.dec)

lm_grid = torch.zeros(self.rd.shape, device=self.device, dtype=torch.float64)
lm_grid[:, :, 0] = (torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])).T
lm_grid[:, :, 1] = (
torch.sin(self.rd[..., 1]) * torch.cos(dec)
- torch.cos(self.rd[..., 1]) * torch.sin(dec) * torch.cos(self.rd[..., 0])
).T
lm_grid[..., 0] = torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])
lm_grid[..., 1] = torch.sin(self.rd[..., 1]) * torch.cos(dec) - torch.cos(
self.rd[..., 1]
) * torch.sin(dec) * torch.cos(self.rd[..., 0])

return lm_grid

def get_baselines(self, times):
Expand Down
2 changes: 2 additions & 0 deletions pyvisgen/simulation/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def vis_loop(
torch.set_num_threads(num_threads)
torch._dynamo.config.suppress_errors = True

SI = torch.flip(SI, dims=[1])

# define unpolarized sky distribution
SI = SI.permute(dims=(1, 2, 0))
I = torch.zeros((SI.shape[0], SI.shape[1], 4), dtype=torch.cdouble)
Expand Down

0 comments on commit 05c9809

Please sign in to comment.